Ch 6.2: Axon
Mix.install([
{:axon, github: "elixir-nx/axon"},
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:scidata, "~> 0.1"},
{:kino, "~> 0.8"},
{:table_rex, "~> 3.1.1"}
])
Nx.default_backend(EXLA.Backend)
Set up data
# Get MNIST images and labels
{images, labels} = Scidata.MNIST.download()
Both images and labels are tuples of the form {data, type, shape}. For the images, the type is {:u, 8} and the shape is {60_000, 1, 28, 28}.
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels
images =
image_data
|> Nx.from_binary(image_type)
# divide by 255 to normalise pixel values into [0, 1]
|> Nx.divide(255)
|> Nx.reshape({60000, :auto})
labels =
label_data
|> Nx.from_binary(label_type)
|> Nx.reshape(label_shape)
|> Nx.new_axis(-1)
# one-hot encode: broadcast-compare each label against the digits 0..9
|> Nx.equal(Nx.iota({1, 10}))
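The broadcasted Nx.equal/2 against Nx.iota({1, 10}) one-hot encodes each label. The same idea in plain Elixir, as a conceptual sketch rather than part of the pipeline:

```elixir
# One-hot encode a digit label into a length-10 vector of 0s and 1s,
# mirroring what the broadcasted Nx.equal/Nx.iota comparison computes per row.
one_hot = fn label ->
  for class <- 0..9, do: if(class == label, do: 1, else: 0)
end

one_hot.(3)
# a label of 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```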
# Note that the features (dimension 1) will match model input size.
images.shape
Split into test and training sets
train_range = 0..49_999//1
test_range = 50_000..-1//1
train_images = images[train_range]
train_labels = labels[train_range]
test_images = images[test_range]
test_labels = labels[test_range]
Split into batches.
This code creates train and test datasets consisting of minibatches of {input, target} tuples, which is the format Axon's training loop expects.
Note the use of Stream to lazily zip the tensors.
batch_size = 64
train_data =
train_images
|> Nx.to_batched(batch_size)
|> Stream.zip(Nx.to_batched(train_labels, batch_size))
test_data =
test_images
|> Nx.to_batched(batch_size)
|> Stream.zip(Nx.to_batched(test_labels, batch_size))
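Because Stream.zip/2 is lazy, no batch is materialised until the training loop asks for it. The same batching-and-zipping pattern in plain Elixir, with hypothetical toy lists standing in for the tensors:

```elixir
# Toy stand-ins for the image and label tensors.
inputs = Enum.to_list(1..10)
targets = Enum.to_list(11..20)

batched =
  inputs
  |> Enum.chunk_every(4)                       # like Nx.to_batched/2
  |> Stream.zip(Enum.chunk_every(targets, 4))  # pair each input batch with its target batch

# Nothing has been computed yet; a batch is realised only on demand:
Enum.at(batched, 0)
# {[1, 2, 3, 4], [11, 12, 13, 14]}
```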
Build model
The model takes an input shape of {nil, 784}. Axon lets you use nil as a placeholder for dimensions, here the batch size, that are only known at run time. The input is passed through a hidden dense layer with 128 units and a :relu activation, then through a final dense layer with 10 units and a :softmax activation.
The input layer with shape {nil, 784} maps directly to input images that are batches of vectors of dimensionality 784.
model =
Axon.input("images", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)
template = Nx.template({1, 784}, :f32)
Axon.Display.as_graph(model, template)
Axon.Display.as_table(model, template)
|> IO.puts()
Training
Axon’s training abstraction is provided by the Axon.Loop module.
trained_model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)
Evaluating the Model
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
Executing
{test_batch, _} = Enum.at(test_data, 0)
test_image = test_batch[0]
test_image
|> Nx.reshape({28, 28})
|> Nx.to_heatmap()
The easiest way to query your model for predictions is to first build your model using Axon.build/2 and then call the returned predict function. Axon.build/2 converts your model into a tuple of {init_fn, predict_fn}. init_fn is an arity-2 function that can be used to initialize your model’s parameters. predict_fn is an arity-2 function that takes model parameters and a tensor or collection of tensors as input and returns the result of running the full model.
{_, predict_fn} = Axon.build(model, compiler: EXLA)
probabilities =
test_image
|> Nx.new_axis(0) # add an axis to match input shape
|> then(&predict_fn.(trained_model_state, &1))
# Not required: just eyeballing the returned probabilities mapped to their digit indices
p = Nx.to_flat_list(probabilities)
Enum.max(p) |> IO.inspect()
Enum.zip(0..9, p)
|> Enum.into(%{})
probabilities
|> Nx.argmax()
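Nx.argmax/1 returns the index of the highest probability, which here is the predicted digit. In plain Elixir terms, the same selection over a flat probability list looks like this (toy probabilities for illustration; `p` above holds the real ones):

```elixir
# Pick the index of the maximum probability, as Nx.argmax/1 does.
probs = [0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]

predicted_digit =
  probs
  |> Enum.with_index()
  |> Enum.max_by(fn {prob, _index} -> prob end)
  |> elem(1)
# predicted_digit is 2
```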