Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

MNIST with Axon

16_deeper/mnist_with_axon.livemd

MNIST with Axon

# Livebook setup: install dependencies and make EXLA the default Nx backend
# so tensor operations are JIT-compiled instead of using the pure-Elixir backend.
Mix.install(
  [
    # :axon was missing from the original deps even though this notebook
    # builds and trains an Axon model below — the 0.5 series matches nx/exla 0.5.
    {:axon, "~> 0.5"},
    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:vega_lite, "~> 0.1.6"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Prepare and load MNIST dataset

Inspired by https://hexdocs.pm/axon/mnist.html#introduction.

defmodule C16.MNISTDataset do
  @moduledoc """
  Use this Module to load the MNIST database (training and validation sets) with
  normalized inputs and one-hot encoded labels.

  MNIST dataset specifications can be found here: http://yann.lecun.com/exdb/mnist/
  """

  # Data files are expected at ../data/mnist relative to this file,
  # in the original gzipped IDX format from the MNIST distribution.
  @data_path Path.join(__DIR__, "../data/mnist") |> Path.expand()

  @train_images_filename Path.join(@data_path, "train-images-idx3-ubyte.gz")
  @test_images_filename Path.join(@data_path, "t10k-images-idx3-ubyte.gz")
  @train_labels_filename Path.join(@data_path, "train-labels-idx1-ubyte.gz")
  @test_labels_filename Path.join(@data_path, "t10k-labels-idx1-ubyte.gz")

  @type t :: %__MODULE__{
          x_train: Nx.Tensor.t(),
          x_validation: Nx.Tensor.t(),
          y_train: Nx.Tensor.t(),
          y_validation: Nx.Tensor.t()
        }
  defstruct [
    :x_train,
    :x_validation,
    :y_train,
    :y_validation
  ]

  @doc """
  Load the MNIST database and return a struct with train and validation images/labels.

  * `x_train` and `x_validation` hold the images, normalized to the 0..1 range
  * `y_train` and `y_validation` hold the labels, one-hot encoded
  """
  @spec load() :: t()
  def load() do
    # Training set: 60000 images; validation (MNIST "t10k") set: 10000 images.
    # Each image is 1 channel, 28 pixel width, 28 pixel height.
    train_images = load_images(@train_images_filename)
    validation_images = load_images(@test_images_filename)

    # 60000 training labels and 10000 validation labels, one-hot encoded
    train_labels = load_labels(@train_labels_filename)
    validation_labels = load_labels(@test_labels_filename)

    %__MODULE__{
      x_train: train_images,
      x_validation: validation_images,
      y_train: train_labels,
      y_validation: validation_labels
    }
  end

  # Reads a gzipped IDX label file and returns an {n_labels, 10} one-hot tensor.
  # NOTE(review): the `with` has no `else`, so if File.read/1 fails this
  # returns `{:error, reason}` instead of a tensor — callers crash later
  # rather than here. Intentional "let it crash" style, presumably.
  defp load_labels(filename) do
    # Open and unzip the file of labels.
    # IDX1 header: 32-bit magic number, then 32-bit label count, then the raw bytes.
    with {:ok, binary} <- File.read(filename) do
      <<_::32, n_labels::32, labels_binary::binary>> = :zlib.gunzip(binary)

      # Nx.from_binary/2 returns a flat tensor.
      # With Nx.reshape/3 we can manipulate this flat tensor
      # and reshape it: 1 row for each image, each row composed by
      # one column:
      # [
      #   [1],
      #   [4],
      #   [9],
      #   …
      # ]
      # Nx.equal/2 then broadcasts each {n, 1} label against the {10} row
      # [0, 1, ..., 9], producing an {n, 10} one-hot matrix (1 where equal).
      labels_binary
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_labels, 1})
      |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
    end
  end

  # Reads a gzipped IDX image file and returns a normalized
  # {n_images, 1, 28, 28} float tensor with named axes.
  defp load_images(filename) do
    # Open and unzip the file of images.
    # IDX3 header: 32-bit magic number, then image count, row count, column count.
    with {:ok, binary} <- File.read(filename) do
      <<_::32, n_images::32, n_rows::32, n_cols::32, images_binary::binary>> =
        :zlib.gunzip(binary)

      # Nx.from_binary/2 returns a flat tensor.
      # Using Nx.reshape/3 we can manipulate this flat tensor into meaningful dimensions.
      # Notice we also normalized the tensor by dividing the input data by 255.
      # This squeezes the data between 0 and 1 which often leads to better behavior when
      # training models.
      # https://hexdocs.pm/axon/mnist.html#introduction
      images_binary
      |> Nx.from_binary({:u, 8})
      |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:images, :channels, :height, :width])
      |> Nx.divide(255)
    end
  end
end

Visualize the dataset via heatmap

We slice the `:images` dimension of the training tensor to obtain the first 5 training images. Then we convert them to a heatmap for easy visualization.

# Load the dataset and keep only the normalized training images for this cell.
x_train = C16.MNISTDataset.load().x_train

# Render the first five training digits as heatmaps (last expression is
# what Livebook displays).
Nx.to_heatmap(x_train[[images: 0..4]])

Build the model and train the network

# The `Axon.flatten` layer will flatten all but the batch dimensions of the
# input into a single layer. Typically called to flatten the output
# of a convolution for use with a dense layer.
#
# https://hexdocs.pm/axon/Axon.html#flatten/2
#
# Flattening is converting the data into a 1-dimensional array for
# inputting it to the next layer.
# From `{60_000, 1, 28, 28}` to `{60_000, 784}`

# A simple multilayer perceptron: the {nil, 1, 28, 28} input is flattened
# to {batch, 784}, passed through a 1200-unit sigmoid hidden layer, and
# projected onto 10 softmax outputs — one probability per digit class.
model =
  "input"
  |> Axon.input(shape: {nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(1200, activation: :sigmoid)
  |> Axon.dense(10, activation: :softmax)

# A template carries only the shape/type of x_train, which is all the
# display helpers need to trace the layer dimensions.
template = Nx.to_template(x_train)

model |> Axon.Display.as_table(template) |> IO.puts()

# Last expression: Livebook renders the model graph inline.
Axon.Display.as_graph(model, template)
# Reload the dataset, asserting we really got the dataset struct, so all
# four tensors are bound in this cell.
%C16.MNISTDataset{
  x_train: x_train,
  x_validation: x_validation,
  y_train: y_train,
  y_validation: y_validation
} = C16.MNISTDataset.load()

# Lazily pair up image/label mini-batches of 32 for the training loop.
train_data =
  Stream.zip([
    Nx.to_batched(x_train, 32),
    Nx.to_batched(y_train, 32)
  ])

# The whole validation set is evaluated as a single batch.
validation_data = [{x_validation, y_validation}]

# Train for 50 epochs with SGD (learning rate 0.1) on categorical
# cross-entropy, reporting accuracy and running validation each epoch.
# The result is the trained parameter map.
params =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.1))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, validation_data)
  |> Axon.Loop.run(train_data, %{}, epochs: 50, compiler: EXLA)