Axon examples
Mix.install([
{:axon, "~> 0.6"},
{:nx, "~> 0.7"},
{:exla, "~> 0.5"},
{:scidata, "~> 0.1"},
{:kino, "~> 0.8"},
{:table_rex, "~> 3.1"}
])
Example 1
This example comes from Sean Moriarity.
Download the samples and convert them to tensors.
# Use the XLA backend for tensor operations in this process.
Nx.default_backend(EXLA.Backend)

# Download MNIST: 60k 28x28 greyscale images and their digit labels.
# Each half of the result is a {binary, type, shape} tuple.
{{image_data, image_type, image_shape}, {label_data, label_type, label_shape}} =
  Scidata.MNIST.download()

# The sample count is the first dimension of the downloaded shape, so it is
# read from there instead of being hard-coded.
num_images = elem(image_shape, 0)

# Convert the raw pixels to a tensor, scale them into [0, 1], and flatten
# every image into a single row.
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.divide(255)
  |> Nx.reshape({num_images, :auto})

# One-hot encode the labels: comparing a {n, 1} column of digits against
# iota({1, 10}) broadcasts to {n, 10} with a 1 at the matching digit.
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))

{images, labels}
Split the data into two parts: one for training and one for testing.
# Mini-batch size shared by the training and test streams.
batch_size = 64

# The first 50k samples are used for training, the remaining 10k for testing.
train_range = 0..49_999//1
test_range = 50_000..59_999//1

train_images = images[train_range]
test_images = images[test_range]
train_labels = labels[train_range]
test_labels = labels[test_range]

# Pair image batches with their label batches as a lazy stream of
# {image_batch, label_batch} tuples.
to_batches = fn inputs, targets ->
  Stream.zip(Nx.to_batched(inputs, batch_size), Nx.to_batched(targets, batch_size))
end

train_data = to_batches.(train_images, train_labels)
test_data = to_batches.(test_images, test_labels)
Create the neural network model.
# Input: batches of flattened images with 784 features each (any batch size).
input_layer = Axon.input("images", shape: {nil, 784})

# Hidden layer: 128 units with ReLU activation.
hidden_layer = Axon.dense(input_layer, 128, activation: :relu)

# Output layer: 10 units with softmax, one per digit class.
model = Axon.dense(hidden_layer, 10, activation: :softmax)
Display the model as a graph to make it easier to understand.
# A one-row f32 template matching the 784-feature model input.
template = Nx.template({1, 784}, :f32)

model |> Axon.Display.as_graph(template)
Train the model with the train_data built in the previous cell.
# Supervised training loop: categorical cross-entropy loss, SGD optimizer,
# with accuracy tracked as an extra metric.
training_loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)

# Run 10 epochs from an empty initial state, compiling with EXLA.
trained_model_state =
  Axon.Loop.run(training_loop, train_data, %{}, epochs: 10, compiler: EXLA)
Verify the model after training.
# Evaluation loop that reports accuracy for the model.
evaluation_loop =
  model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

# Evaluate on the held-out test batches using the trained state.
Axon.Loop.run(evaluation_loop, test_data, trained_model_state, compiler: EXLA)
Test the model with a single image.
# Take the image half of the first {images, labels} test batch.
test_batch = test_data |> Enum.at(0) |> elem(0)

# Pick the image at index 31 of that batch.
test_image = test_batch[31]

# Restore the 28x28 layout of the flattened row and render it as a heatmap.
Nx.to_heatmap(Nx.reshape(test_image, {28, 28}))
Run the image through the model.
# Build the model with EXLA; only the prediction function is needed here.
{_init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

# Add a leading batch axis so the single image becomes a batch of one.
batched_image = Nx.new_axis(test_image, 0)

prediction = predict_fn.(trained_model_state, batched_image)

# Index of the largest of the 10 outputs, i.e. the predicted digit.
result = Nx.argmax(prediction)