Classifying handwritten digits
# from https://raw.githubusercontent.com/elixir-nx/axon/v0.6.0/notebooks/vision/mnist.livemd
Mix.install(
  [
    {:axon, "~> 0.6"},
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:req, "~> 0.3.1"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ],
    exla: [
      default_client: :cuda,
      clients: [
        host: [platform: :host],
        cuda: [platform: :cuda]
      ]
    ]
  ],
  system_env: [
    XLA_TARGET: "cuda12"
  ]
)
Introduction
This livebook will walk you through training a basic neural network using Axon, accelerated by the EXLA compiler. We’ll be working with the MNIST dataset, which consists of handwritten digits and their corresponding labels. The goal is to train a model that correctly classifies each handwritten digit with a single label [0-9].
Retrieving and exploring the dataset
The MNIST dataset is available for free online. Using Req we’ll download both training images and training labels. Both train_images and train_labels are compressed binary data. Fortunately, Req takes care of the decompression for us.
You can read more about the format of the ubyte files here. Each file starts with a magic number and some metadata. We can use binary pattern matching to extract the information we want. In this case we extract the raw binary images and labels.
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
%{body: train_images} = Req.get!(base_url <> "train-images-idx3-ubyte.gz")
%{body: train_labels} = Req.get!(base_url <> "train-labels-idx1-ubyte.gz")
<<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = train_images
<<_::32, n_labels::32, labels::binary>> = train_labels
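We can peek at the header fields pulled out by the pattern matches to confirm the files were read correctly; for the MNIST training set this should report 60,000 images of 28x28 pixels and 60,000 labels:
# Inspect the metadata extracted from the ubyte headers above.
{n_images, n_rows, n_cols, n_labels}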
We can easily read that binary data into a tensor using Nx.from_binary/2. Nx.from_binary/2 expects a raw binary and a data type. In this case, both images and labels are stored as unsigned 8-bit integers. We can start by parsing our images:
images =
  images
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:images, :channels, :height, :width])
  |> Nx.divide(255)
Nx.from_binary/2 returns a flat tensor. Using Nx.reshape/3 we can manipulate this flat tensor into meaningful dimensions. Notice we also normalized the tensor by dividing the input data by 255. This squeezes the data between 0 and 1, which often leads to better behavior when training models. Now, let’s see what these images look like:
images[[images: 0..4]] |> Nx.to_heatmap()
In the reshape operation above, we give each dimension of the tensor a name. This makes it much easier to do things like slicing, and helps make your code easier to understand. Here we slice the images dimension of the images tensor to obtain the first 5 training images. Then, we convert them to a heatmap for easy visualization.
It’s common to train neural networks in batches (more precisely called minibatches, but you’ll see batch and minibatch used interchangeably). We can “batch” our images into batches of 32 like this:
images = Nx.to_batched(images, 32)
Now, we’ll need to get our labels into batches as well, but first we need to one-hot encode the labels. One-hot encoding converts input data from labels such as 3, 5, 7, etc. into vectors of 0’s with a single 1 at the correct label’s index. As an example, a label of 3 gets converted to [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
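As a quick sketch of that mapping, here is the single label 3 pushed through the same comparison trick we use in the next cell:
# Compare a single label against every class index 0..9; equality yields the one-hot vector.
Nx.tensor([3])
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
The result is a {1, 10} tensor with a 1 in the fourth position, exactly the vector described above.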
targets =
  labels
  |> Nx.from_binary({:u, 8})
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(32)
Defining the model
Let’s start by defining a simple model:
model =
  Axon.input("input", shape: {nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
All Axon models start with an input layer to tell subsequent layers what shapes to expect. We then use Axon.flatten/2, which flattens the previous layer by squeezing all dimensions but the first dimension into a single dimension. Our model consists of 2 fully connected layers with 128 and 10 units respectively. The first layer uses :relu activation, which returns max(0, input) element-wise. The final layer uses :softmax activation to return a probability distribution over the 10 labels [0 - 9].
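If you want to sanity-check the shapes flowing through the model before training, one option (just a sketch; Axon.predict used later wraps these same steps) is to build the model into its init and predict functions and run a dummy input through it:
# Build the model into an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from a template matching the input layer's shape.
initial_params = init_fn.(Nx.template({1, 1, 28, 28}, :f32), %{})

# A single all-zero "image" should produce one probability per digit class.
predict_fn.(initial_params, Nx.broadcast(0.0, {1, 1, 28, 28})) |> Nx.shape()
The output shape is {1, 10}: after flattening, the 1x28x28 input becomes a 784-element vector, which the two dense layers map down to 10 class probabilities.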
Training
In Axon we express the task of training using a declarative loop API. First, we need to specify a loss function and an optimizer; there are many built-in variants to choose from. In this example, we’ll use categorical cross-entropy and the Adam optimizer. We will also keep track of the accuracy metric. Finally, we run the training loop, passing our batched images and labels. We’ll train for 10 epochs using the EXLA compiler.
params =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(Stream.zip(images, targets), %{}, epochs: 10, compiler: EXLA)
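With the trained parameters in hand, we can also measure accuracy outside the training loop using Axon’s evaluation loop. The sketch below reuses the training batches, since that is the only data we downloaded here; normally you would run it on a held-out test set:
# Evaluate the trained parameters; the training batches stand in for test data in this sketch.
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(Stream.zip(images, targets), params, compiler: EXLA)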
Prediction
Now that we have the parameters from the training step, we can use them for predictions. For this, we can use Axon.predict.
first_batch = Enum.at(images, 0)
output = Axon.predict(model, params, first_batch)
For each image, the model outputs a probability distribution. This tells us how confident the model is in its prediction. Let’s see the most probable digit for each image:
Nx.argmax(output, axis: 1)
If you look at the original images, you will see the predictions match the data!
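One quick way to check is to recover the true labels for the same batch from the one-hot targets we built earlier and compare them with the predictions above:
# Argmax over the one-hot targets gives back the original labels for the first batch.
targets
|> Enum.at(0)
|> Nx.argmax(axis: 1)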