MNIST

mnist.livemd

Introduction

This livebook walks you through training a basic neural network with Axon, accelerated by the EXLA compiler. We’ll work with the MNIST dataset, a collection of handwritten digits paired with their labels. The goal is to train a model that assigns each handwritten digit its correct label, from 0 to 9.

Dependencies

First, we’ll install our dependencies using Mix.install. We need Axon and its dependencies, plus the Req library for downloading the dataset.

Mix.install([
  {:req, "~> 0.3.0-dev", github: "wojtekmach/req", branch: "main"},
  {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main", override: true},
  {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla", override: true}
])

Retrieving and Exploring Dataset

The MNIST dataset is freely available online. Using Req, we’ll download both the training images and the training labels. Both files are gzip-compressed binary data; fortunately, Req takes care of the decompression for us.

The ubyte files use the IDX format. Each file starts with a 32-bit magic number followed by 32-bit big-endian integers describing its dimensions: the image file stores the number of images, rows, and columns, while the label file stores the number of labels. We can use binary pattern matching to pull out this metadata along with the raw image and label bytes that follow it.

base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
%{body: train_images} = Req.get!(base_url <> "train-images-idx3-ubyte.gz")
%{body: train_labels} = Req.get!(base_url <> "train-labels-idx1-ubyte.gz")

<<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = train_images
<<_::32, n_labels::32, labels::binary>> = train_labels
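To see what these header patterns do without downloading anything, here is a minimal sketch on a hand-built binary for a hypothetical file holding two 2x3 images. The magic number 2051 is the one MNIST image files use (label files use 2049); the `demo_` names are illustrative and chosen so they don’t overwrite the real bindings above. Elixir’s `::32` segments are big-endian by default, which matches the IDX header byte order.

```elixir
# Hypothetical IDX image file: magic number, image count, rows, cols,
# then one unsigned byte per pixel (2 images x 2 rows x 3 cols = 12 bytes).
demo_file =
  <<2051::32, 2::32, 2::32, 3::32>> <>
    <<0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110>>

# Same pattern shape as above: four 32-bit big-endian header fields,
# with the remaining bytes bound as the pixel payload.
<<demo_magic::32, demo_n_images::32, demo_n_rows::32, demo_n_cols::32,
  demo_pixels::binary>> = demo_file

IO.inspect({demo_magic, demo_n_images, demo_n_rows, demo_n_cols})
# prints {2051, 2, 2, 3}

# The payload holds exactly n_images * n_rows * n_cols bytes.
byte_size(demo_pixels) == demo_n_images * demo_n_rows * demo_n_cols
```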

We can read that binary data into a tensor with Nx.from_binary/2, which takes a raw binary and a data type. Both the images and the labels are stored as unsigned 8-bit integers. Let’s start by parsing the images:

images =
  images
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:images, :channels, :height, :width])
  |> Nx.divide(255)

Nx.from_binary/2 returns a flat tensor, and Nx.reshape/3 turns it into meaningful dimensions: images, channels, height, and width. We also normalize the tensor by dividing by 255, which scales pixel values from the range 0 to 255 into the range 0 to 1. Keeping inputs on a small, consistent scale usually makes gradient-based training more stable. Now, let’s see what these images look like:

images[[images: 0..4]] |> Nx.to_heatmap()

In the reshape operation above, we gave each dimension of the tensor a name. Named dimensions make operations like slicing easier to write and the code easier to read. Here we slice the images dimension of the images tensor with the range 0..4 to obtain the first 5 training images, then convert them to a heatmap for easy visualization.
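Reshaping keeps the elements in the same row-major order and only changes how they are indexed, so each (image, row, col) position maps to a fixed offset in the flat pixel data. The plain-Elixir sketch below uses a hypothetical 8-byte buffer to show that offset arithmetic and the divide-by-255 scaling; it illustrates the math, not how Nx stores tensors internally.

```elixir
# Hypothetical flat buffer: 2 images of 2x2 u8 pixels, stored row-major
# like the MNIST pixel bytes that follow the header.
demo_flat = <<0, 51, 102, 255, 255, 204, 153, 0>>
demo_rows = 2
demo_cols = 2

# Flat offset of pixel (image, row, col) in a {n, 1, rows, cols} layout;
# the single channel axis adds nothing to the offset.
demo_offset = fn image, row, col ->
  image * demo_rows * demo_cols + row * demo_cols + col
end

# Read one byte and scale it from 0..255 into 0.0..1.0, mirroring Nx.divide(255).
demo_pixel = fn image, row, col ->
  :binary.at(demo_flat, demo_offset.(image, row, col)) / 255
end

demo_pixel.(0, 1, 1)
# → 1.0 (offset 3 holds 255)
demo_pixel.(1, 0, 1)
# → 0.8 (offset 5 holds 204)
```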

It’s common to train neural networks on batches of examples (strictly speaking these are minibatches, though the two terms are used interchangeably). We can split our images into batches of 32 like this:

images =
  images
  |> Nx.to_batched_list(32)
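Conceptually, batching splits the examples along the first axis into consecutive groups. The plain-Elixir sketch below uses `Enum.chunk_every/2` on a list as an analogy for `Nx.to_batched_list/2` on a tensor. With a list, a leftover group can be smaller than the batch size; for MNIST that doesn’t come up, because the 60,000 training images divide evenly by 32.

```elixir
# Ten hypothetical examples grouped into batches of 4.
demo_batches = Enum.chunk_every(1..10, 4)
IO.inspect(demo_batches)
# prints [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]

# MNIST's 60,000 training images split evenly into batches of 32.
{div(60_000, 32), rem(60_000, 32)}
# → {1875, 0}
```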

Next, we need to batch our labels as well, but first we have to one-hot encode them. One-hot encoding converts each integer label, such as 3, 5, or 7, into a vector of zeros with a single 1 at the index matching the label. For example, the label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].

targets =
  labels
  |> Nx.from_binary({:u, 8})
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched_list(32)
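The Nx pipeline above relies on broadcasting: `Nx.new_axis(-1)` turns the `{60000}` label vector into a `{60000, 1}` column, and `Nx.equal/2` compares that column against the `{10}` row `[0, 1, ..., 9]`, producing a `{60000, 10}` tensor of 1s and 0s. The nested comprehension below is a plain-Elixir sketch of the same comparison on three hypothetical labels.

```elixir
# Three hypothetical labels standing in for the MNIST label vector.
demo_labels = [3, 0, 9]

# For each label (one row of the {n, 1} column), compare it against every
# class 0..9 (the {10} row); matching positions become 1, the rest 0.
demo_targets =
  for label <- demo_labels do
    for class <- 0..9, do: if(label == class, do: 1, else: 0)
  end

IO.inspect(hd(demo_targets))
# prints [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```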

Defining the Model

Let’s start by defining a simple model:

model =
  Axon.input({nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

All Axon models start with an input layer, which tells subsequent layers what shapes to expect; the nil in the input shape leaves the batch size unspecified. We then use Axon.flatten/2, which collapses every dimension except the first (batch) dimension into a single one, so each 1x28x28 image becomes a vector of 784 values. The rest of the model consists of 2 fully connected layers with 128 and 10 units, respectively. The first uses :relu activation, which computes max(0, input) element-wise. The final layer uses :softmax activation to return a probability distribution over the 10 labels [0-9].
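To make the layer descriptions concrete, the plain-Elixir sketch below computes the flattened input size and the number of trainable parameters in each dense layer (one weight per input-output pair plus one bias per unit), then applies textbook relu and softmax to small hypothetical vectors. It illustrates the math only; Axon’s implementations operate on tensors and may differ in numerical details.

```elixir
# Flatten keeps the batch dimension and collapses {1, 28, 28} into one axis.
demo_flat_size = 1 * 28 * 28
# → 784

# Dense layer parameters: inputs * units weights, plus one bias per unit.
demo_dense_params = fn inputs, units -> inputs * units + units end
demo_hidden_params = demo_dense_params.(demo_flat_size, 128)
# → 100480
demo_output_params = demo_dense_params.(128, 10)
# → 1290
demo_hidden_params + demo_output_params
# → 101770

# relu: max(0, x) element-wise.
demo_relu = fn xs -> Enum.map(xs, &max(&1, 0.0)) end

# softmax: exponentiate, then normalize so the outputs sum to 1.
# Subtracting the max first avoids overflow without changing the result.
demo_softmax = fn xs ->
  m = Enum.max(xs)
  exps = Enum.map(xs, &:math.exp(&1 - m))
  total = Enum.sum(exps)
  Enum.map(exps, &(&1 / total))
end

demo_relu.([-2.0, 0.0, 3.5])
# → [0.0, 0.0, 3.5]
demo_softmax.([1.0, 1.0])
# → [0.5, 0.5]
```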

Training

Axon reduces training to two pieces: a training step, and a training loop that runs it. Axon.Training.step/3 builds a generic training step from a model, a loss function, and an optimizer; in this example, we use categorical cross-entropy and the Adam optimizer with a learning rate of 0.01. We then pass the step, together with the batched images and targets, to Axon.Training.train/4, which runs the training loop. It also accepts additional options, such as the Nx compiler to use. Here we train for 10 epochs with the EXLA compiler, logging metrics every 100 training steps.
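For a one-hot target, the textbook categorical cross-entropy, the negated sum of target_i * log(prob_i), reduces to the negative log of the probability the model assigns to the correct class. Below is a minimal plain-Elixir sketch on one hypothetical prediction, using 3 classes instead of 10 for brevity. Axon’s loss function works on batched tensors and may add details such as clipping or averaging over the batch.

```elixir
# Textbook categorical cross-entropy for one example:
# -sum(target_i * log(prob_i)). All probabilities must be > 0 here,
# since :math.log/1 raises on 0.0.
demo_cross_entropy = fn target, probs ->
  target
  |> Enum.zip(probs)
  |> Enum.map(fn {t, p} -> t * :math.log(p) end)
  |> Enum.sum()
  |> Kernel.-()
end

# Hypothetical 3-class case: the true class is index 1.
demo_target = [0, 1, 0]

# A confident, correct prediction gives a small loss, about -log(0.7).
demo_cross_entropy.(demo_target, [0.1, 0.7, 0.2])

# A prediction that favors the wrong class gives a larger loss, about -log(0.1).
demo_cross_entropy.(demo_target, [0.7, 0.1, 0.2])
```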

model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.01))
|> Axon.Training.train(images, targets, compiler: EXLA, epochs: 10, log_every: 100)
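The 0.01 passed to Axon.Optimizers.adam is the learning rate. For a rough sense of its scale, here is a plain-Elixir sketch of one step of the textbook Adam update for a single scalar parameter. The hyperparameters b1 = 0.9, b2 = 0.999, and eps = 1.0e-8 are commonly used defaults assumed for illustration; Axon’s defaults and implementation may differ. On the first step, the bias-corrected ratio is close to the sign of the gradient, so each parameter moves by roughly the learning rate regardless of the gradient’s magnitude.

```elixir
# One textbook Adam step for a scalar parameter. The hyperparameter
# values (b1, b2, eps) are assumptions, not read from Axon.
demo_adam_step = fn theta, grad, {m, v, t}, lr ->
  b1 = 0.9
  b2 = 0.999
  eps = 1.0e-8
  t = t + 1
  # Exponential moving averages of the gradient and the squared gradient.
  m = b1 * m + (1 - b1) * grad
  v = b2 * v + (1 - b2) * grad * grad
  # Bias correction for the zero-initialized averages.
  m_hat = m / (1 - :math.pow(b1, t))
  v_hat = v / (1 - :math.pow(b2, t))
  theta = theta - lr * m_hat / (:math.sqrt(v_hat) + eps)
  {theta, {m, v, t}}
end

# From zeroed optimizer state, a gradient of 0.5 moves the parameter
# from 1.0 to about 0.99, i.e. by roughly lr = 0.01.
{demo_theta, _demo_state} = demo_adam_step.(1.0, 0.5, {0.0, 0.0, 0}, 0.01)
demo_theta
```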