MNIST with Axon
Mix.install(
[
{:axon, "~> 0.5"},
{:exla, "~> 0.5"},
{:nx, "~> 0.5"},
{:vega_lite, "~> 0.1.6"},
{:kino, "~> 0.8.1"},
{:kino_vega_lite, "~> 0.1.7"},
{:table_rex, "~> 3.1.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Prepare and load the MNIST dataset
Inspired by https://hexdocs.pm/axon/mnist.html#introduction
defmodule C16.MNISTDataset do
@moduledoc """
Use this module to load the MNIST database (train and validation sets) with
normalized inputs.
MNIST dataset specifications can be found here: http://yann.lecun.com/exdb/mnist/
"""
@data_path Path.join(__DIR__, "../data/mnist") |> Path.expand()
@train_images_filename Path.join(@data_path, "train-images-idx3-ubyte.gz")
@test_images_filename Path.join(@data_path, "t10k-images-idx3-ubyte.gz")
@train_labels_filename Path.join(@data_path, "train-labels-idx1-ubyte.gz")
@test_labels_filename Path.join(@data_path, "t10k-labels-idx1-ubyte.gz")
@type t :: %__MODULE__{
x_train: Nx.Tensor.t(),
x_validation: Nx.Tensor.t(),
y_train: Nx.Tensor.t(),
y_validation: Nx.Tensor.t()
}
defstruct [
:x_train,
:x_validation,
:y_train,
:y_validation
]
@doc """
Load the MNIST database and return a struct with train and validation images/labels:
* `x_train` and `x_validation` contain the normalized images
* `y_train` and `y_validation` contain the one-hot encoded labels
"""
@spec load() :: t()
def load() do
# 60000 training images and 10000 validation images; 1 channel, 28x28 pixels
train_images = load_images(@train_images_filename)
validation_images = load_images(@test_images_filename)
# 60000 training labels and 10000 validation labels, one-hot encoded
train_labels = load_labels(@train_labels_filename)
validation_labels = load_labels(@test_labels_filename)
%__MODULE__{
x_train: train_images,
x_validation: validation_images,
y_train: train_labels,
y_validation: validation_labels
}
end
defp load_labels(filename) do
# Open and unzip the file of labels
with {:ok, binary} <- File.read(filename) do
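# The IDX label file header: a 32-bit magic number followed by the 32-bit label count.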
<<_::32, n_labels::32, labels_binary::binary>> = :zlib.gunzip(binary)
# Nx.from_binary/2 returns a flat tensor.
# With Nx.reshape/2 we can manipulate this flat tensor
# and reshape it: one row for each label, each row composed of
# a single column:
# [
# [1],
# [4],
# [9],
# …
# ]
labels_binary
|> Nx.from_binary({:u, 8})
|> Nx.reshape({n_labels, 1})
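# Broadcasting the {n_labels, 1} tensor against the ten digits 0..9
# one-hot encodes each label: each row is 1 at the label's index
# and 0 elsewhere.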
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
end
end
defp load_images(filename) do
# Open and unzip the file of images
with {:ok, binary} <- File.read(filename) do
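# The IDX image file header: a 32-bit magic number, then the image count, row count, and column count.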
<<_::32, n_images::32, n_rows::32, n_cols::32, images_binary::binary>> =
:zlib.gunzip(binary)
# Nx.from_binary/2 returns a flat tensor.
# Using Nx.reshape/3 we can manipulate this flat tensor into meaningful dimensions.
# Notice we also normalize the tensor by dividing the input data by 255.
# This squeezes the data between 0 and 1 which often leads to better behavior when
# training models.
# https://hexdocs.pm/axon/mnist.html#introduction
images_binary
|> Nx.from_binary({:u, 8})
|> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:images, :channels, :height, :width])
|> Nx.divide(255)
end
end
end
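To see the one-hot trick used in `load_labels/1` in isolation (a minimal sketch with made-up labels, not part of the dataset code): broadcasting a `{n, 1}` column of labels against the row of digits `0..9` compares every label with every digit, yielding a `{n, 10}` one-hot matrix.
labels = Nx.tensor([[1], [4], [9]])
# Comparing the {3, 1} column against the {10} row broadcasts to {3, 10};
# each row is 1 only at its label's index.
Nx.equal(labels, Nx.tensor(Enum.to_list(0..9)))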
Visualize the dataset via heatmap
We slice the `:images` dimension of the tensor to obtain the first 5 training images. Then, we convert them to heatmaps for easy visualization.
%{x_train: x_train} = C16.MNISTDataset.load()
x_train[[images: 0..4]] |> Nx.to_heatmap()
Build the model and train the network
# The `Axon.flatten` layer will flatten all but the batch dimensions of the
# input into a single dimension. Typically called to flatten the output
# of a convolution for use with a dense layer.
#
# https://hexdocs.pm/axon/Axon.html#flatten/2
#
# Flattening converts each sample's data into a 1-dimensional array
# to feed into the next (dense) layer.
# From `{60_000, 1, 28, 28}` to `{60_000, 784}`
model =
Axon.input("input", shape: {nil, 1, 28, 28})
|> Axon.flatten()
|> Axon.dense(1200, activation: :sigmoid)
|> Axon.dense(10, activation: :softmax)
Axon.Display.as_table(model, Nx.to_template(x_train)) |> IO.puts()
Axon.Display.as_graph(model, Nx.to_template(x_train))
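As a quick sanity check on what `Axon.flatten` does to the shapes above, the same transformation can be sketched with plain `Nx.reshape/2` (a minimal illustration, not part of the training pipeline; `:auto` tells Nx to infer the collapsed size):
t = Nx.iota({2, 1, 28, 28})
# Collapse all but the first (batch) dimension: {2, 1, 28, 28} -> {2, 784}
Nx.shape(Nx.reshape(t, {2, :auto}))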
%{
x_train: x_train,
y_train: y_train,
x_validation: x_validation,
y_validation: y_validation
} = C16.MNISTDataset.load()
# Batch the training data
train_data = Stream.zip(Nx.to_batched(x_train, 32), Nx.to_batched(y_train, 32))
validation_data = [{x_validation, y_validation}]
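`Nx.to_batched/2` splits the first axis into a lazy stream of fixed-size batches; a minimal illustration (not part of the training code):
Nx.iota({6, 2})
|> Nx.to_batched(2)
|> Enum.map(&Nx.shape/1)
# => [{2, 2}, {2, 2}, {2, 2}]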
params =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.1))
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_data, %{}, epochs: 50, compiler: EXLA)
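With the trained `params` in hand, one way to measure accuracy on the held-out set is an evaluation loop (a minimal sketch using `Axon.Loop.evaluator/1`; not part of the original notebook):
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(validation_data, params, compiler: EXLA)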