Chapter 19: beyond vanilla networks


    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:axon, "~> 0.5"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:vega_lite, "~> 0.1.6"},
    {:scidata, "~> 0.1"},
    {:nx_image, "~> 0.1.0"},
    {:table_rex, "~> 3.1.1"}
  config: [nx: [default_backend: EXLA.Backend]]

The CIFAR-10 Dataset

What CIFAR-10 looks like

  # Download the CIFAR-10 training set. Each part is destructured as a
  # `{binary, type, shape}` triple.
  {
    {images_binary, images_type, images_shape},
    {labels_binary, labels_type, labels_shape}
  } = Scidata.CIFAR10.download()

  # Decode the raw bytes and name every axis so later cells can refer to
  # `:n`, `:c`, `:h` and `:w` instead of positional indices.
  images =
    images_binary
    |> Nx.from_binary(images_type)
    |> Nx.reshape({elem(images_shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])

  # Add a trailing axis so each label sits in its own row.
  labels =
    labels_binary
    |> Nx.from_binary(labels_type)
    |> Nx.new_axis(-1)

columns = 8
rows = 4

# Fixed seed so the same sample of images is shown on every run.
key = Nx.Random.key(42)

# Compute random indices: shuffle every valid image index (0..n-1) and keep
# the first `columns * rows` of them. The iota must span the full image count,
# otherwise the last image could never be selected.
indices =
  {elem(images_shape, 0)}
  |> Nx.iota()
  |> then(fn data ->
    {shuffled_data, _new_key} = Nx.Random.shuffle(key, data)
    shuffled_data
  end)
  |> Nx.slice_along_axis(0, columns * rows)

selected_images = Nx.take(images, indices)
selected_labels = Nx.take(labels, indices)

  # Render the sampled images as a `columns`-wide grid; every cell is itself a
  # one-column grid holding the class label above the picture.
  Kino.Layout.grid(
    Enum.map(0..(columns * rows - 1), fn i ->
      Kino.Layout.grid(
        [
          Kino.Markdown.new("class: #{selected_labels[i][0] |> Nx.to_number()}"),
          # transpose the image since `Kino.Image.new` requires
          # the following shape `{:h, :w, :c}`, while the original one is `{:c, :h, :w}`
          selected_images[i]
          |> Nx.transpose(axes: [:h, :w, :c])
          |> NxImage.resize({100, 100}, method: :nearest)
          |> Kino.Image.new()
        ],
        boxed: true,
        columns: 1
      )
    end),
    boxed: true,
    columns: columns
  )

Falling short of CIFAR

Prepare the data

defmodule Chapter19.CIFAR10 do
  @moduledoc """
  Loads CIFAR-10 through `Scidata` and prepares it for a fully connected network.

  Images are flattened to one row per sample and divided by 255; labels are
  one-hot encoded against the classes `0..9`.
  """

  @doc """
  Downloads the train and test sets and returns the prepared tensors.

  The result is a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and `:test_labels`.
  The downloaded test set is cut in two: the first half becomes the validation
  set, the second half the test set.
  """
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Decodes the image binary, flattens every sample into a single row and
  # divides the values by 255.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: comparing the `{n, 1}` label column against
  # the row `[0, 1, ..., 9]` broadcasts to an `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits a tensor in two along its first axis. The second part starts exactly
  # at `len`, so the two parts are contiguous and together cover every row
  # (the second part gets the extra row when the row count is odd).
  defp split(tensor) do
    {x, _} = Nx.shape(tensor)
    len = div(x, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, x - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

  # Unpack every prepared split returned by `load_data/0` into its own binding.
  %{
    train_images: train_images,
    train_labels: train_labels,
    validation_images: validation_images,
    validation_labels: validation_labels,
    test_images: test_images,
    test_labels: test_labels
  } = Chapter19.CIFAR10.load_data()

# Pair image batches with their label batches (32 samples each) lazily.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The validation set is passed as a single `{inputs, targets}` batch.
validation_data = [{validation_images, validation_labels}]

Build the model and train it

epochs = 25

# Fully connected classifier: three dense + ReLU + batch-norm stages followed
# by a 10-way softmax output.
model =
  Axon.input("data")
  |> Axon.dense(1200)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(500)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(200)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(10, activation: :softmax)

# Train with Adam on categorical cross-entropy, tracking accuracy and running
# the validation set through the model as part of the loop.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

Training completed in about 2250 seconds.

Results after 25 epochs:

  • accuracy: 0.8000934 - loss: 1.0527945
  • validation accuracy: 0.4242000 - validation loss: 5.1675453

Running on Convolutions

Prepare the data

Compared to the previous implementation, the images (inputs) are not flattened: all three dimensions (channel, height, width) are kept.

defmodule Chapter19.CIFAR10Cnn do
  @moduledoc """
  Loads CIFAR-10 through `Scidata` and prepares it for a convolutional network.

  Images keep their spatial structure: they are reshaped to `{n, 3, 32, 32}`,
  transposed to channels-last `{n, h, w, c}` and divided by 255. Labels are
  one-hot encoded against the classes `0..9`.
  """

  @doc """
  Downloads the train and test sets and returns the prepared tensors.

  The result is a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and `:test_labels`.
  The downloaded test set is cut in two: the first half becomes the validation
  set, the second half the test set.
  """
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Decodes the image binary as `{n, c, h, w}`, moves the channels axis to the
  # end and divides the values by 255.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
    |> Nx.transpose(axes: [:n, :h, :w, :c])
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: comparing the `{n, 1}` label column against
  # the row `[0, 1, ..., 9]` broadcasts to an `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits a tensor in two along its first axis, whatever its rank, so the same
  # clause serves both the `{n, h, w, c}` images and the `{n, 10}` labels.
  # The second part starts exactly at `len`, so the two parts are contiguous
  # and together cover every sample (the second part gets the extra sample
  # when the count is odd).
  defp split(tensor) do
    n = elem(Nx.shape(tensor), 0)
    len = div(n, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, n - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

  # Unpack every prepared split returned by `load_data/0` into its own binding.
  %{
    train_images: train_images,
    train_labels: train_labels,
    validation_images: validation_images,
    validation_labels: validation_labels,
    test_images: test_images,
    test_labels: test_labels
  } = Chapter19.CIFAR10Cnn.load_data()

# Pair image batches with their label batches (32 samples each) lazily.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The validation set is passed as a single `{inputs, targets}` batch.
validation_data = [{validation_images, validation_labels}]

Build the CNN and train it

epochs = 20

# CNN: two conv + batch-norm + dropout stages on channels-last 32x32x3 input,
# then two dense + batch-norm + dropout stages and a 10-way softmax output.
model =
  Axon.input("data", shape: {nil, 32, 32, 3})
  |> Axon.conv(16, kernel_size: 3, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.conv(32, kernel_size: 3, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.flatten()
  |> Axon.dense(1000, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(512, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

# Print the layer-by-layer summary, using the validation images as the shape template.
Axon.Display.as_table(model, Nx.to_template(validation_images)) |> IO.puts()

# Train with Adam on categorical cross-entropy, tracking accuracy and running
# the validation set through the model as part of the loop.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

Training completed in about 8400 seconds.

Results after 20 epochs:

  • accuracy: 0.8096849 - loss: 0.8324687
  • validation accuracy: 0.6632000 - validation loss: 1.5468998

Channels :first VS :last

In `Chapter19.CIFAR10Cnn`, the inputs extracted from the binary are transposed so that the channels axis comes last:

# Decodes the image binary as `{n, c, h, w}`, then moves the channels axis to
# the end (`{n, h, w, c}`) and divides the values by 255.
defp transform_images({bin, type, shape}) do
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
  |> Nx.transpose(axes: [:n, :h, :w, :c])
  |> Nx.divide(255.0)
end

By default, `Axon.conv` expects the channels axis to be the last one, so there is no need to set the `:channels` option explicitly.

It is worth mentioning, however, that keeping the channels as the first axis (by skipping the transposition) is also possible.

I tried both approaches, and having the channels axis last led to better accuracy overall:

| Channels position | Training accuracy | Validation accuracy |
| ----------------- | ----------------- | ------------------- |
| `:first`          | 0.7166188         | 0.6012000           |
| `:last`           | 0.8096849         | 0.6632000           |

To conclude, the Keras implementation in the book is equivalent to the Axon implementation with the channels as the `:last` axis.

Keras model summary

The model

# Keras counterpart of the Axon CNN: two conv blocks, a flatten, two dense
# blocks and a softmax output. Each block is followed by batch normalization
# and dropout, matching the layers listed in the model summary.
# NOTE(review): the dropout rate of 0.5 mirrors the Axon model -- confirm against the book.
model = Sequential()

model.add(Conv2D(16, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))

model.add(Flatten())

model.add(Dense(1000, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))

model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))

model.add(Dense(10, activation='softmax'))

Model: "sequential"
 Layer (type)                Output Shape              Param #
 conv2d (Conv2D)             (50000, 30, 30, 16)       448

 batch_normalization (BatchN  (50000, 30, 30, 16)      64

 dropout (Dropout)           (50000, 30, 30, 16)       0

 conv2d_1 (Conv2D)           (50000, 28, 28, 32)       4640

 batch_normalization_1 (Batc  (50000, 28, 28, 32)      128

 dropout_1 (Dropout)         (50000, 28, 28, 32)       0

 flatten (Flatten)           (50000, 25088)            0

 dense (Dense)               (50000, 1000)             25089000

 batch_normalization_2 (Batc  (50000, 1000)            4000

 dropout_2 (Dropout)         (50000, 1000)             0

 dense_1 (Dense)             (50000, 512)              512512

 batch_normalization_3 (Batc  (50000, 512)             2048

 dropout_3 (Dropout)         (50000, 512)              0

 dense_2 (Dense)             (50000, 10)               5130

Total params: 25,617,970
Trainable params: 25,614,850
Non-trainable params: 3,120

Axon model summary

|                                                                  Model                                                                  |
| Layer                                 | Input Shape          | Output Shape       | Options                 | Parameters                |
| data ( input )                        | []                   | {5000, 32, 32, 3}  | shape: {nil, 32, 32, 3} |                           |
|                                       |                      |                    | optional: false         |                           |
| conv_0 ( conv["data"] )               | [{5000, 32, 32, 3}]  | {5000, 30, 30, 16} | strides: 1              | kernel: f32[3][3][3][16]  |
|                                       |                      |                    | padding: :valid         | bias: f32[16]             |
|                                       |                      |                    | input_dilation: 1       |                           |
|                                       |                      |                    | kernel_dilation: 1      |                           |
|                                       |                      |                    | feature_group_size: 1   |                           |
|                                       |                      |                    | channels: :last         |                           |
| relu_0 ( relu["conv_0"] )             | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} |                         |                           |
| batch_norm_0 ( batch_norm["relu_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | epsilon: 1.0e-5         | gamma: f32[16]            |
|                                       |                      |                    | channel_index: -1       | beta: f32[16]             |
|                                       |                      |                    | momentum: 0.1           | mean: f32[16]             |
|                                       |                      |                    |                         | var: f32[16]              |
| dropout_0 ( dropout["batch_norm_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | rate: 0.5               | key: f32[2]               |
| conv_1 ( conv["dropout_0"] )          | [{5000, 30, 30, 16}] | {5000, 28, 28, 32} | strides: 1              | kernel: f32[3][3][16][32] |
|                                       |                      |                    | padding: :valid         | bias: f32[32]             |
|                                       |                      |                    | input_dilation: 1       |                           |
|                                       |                      |                    | kernel_dilation: 1      |                           |
|                                       |                      |                    | feature_group_size: 1   |                           |
|                                       |                      |                    | channels: :last         |                           |
| relu_1 ( relu["conv_1"] )             | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} |                         |                           |
| batch_norm_1 ( batch_norm["relu_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | epsilon: 1.0e-5         | gamma: f32[32]            |
|                                       |                      |                    | channel_index: -1       | beta: f32[32]             |
|                                       |                      |                    | momentum: 0.1           | mean: f32[32]             |
|                                       |                      |                    |                         | var: f32[32]              |
| dropout_1 ( dropout["batch_norm_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | rate: 0.5               | key: f32[2]               |
| flatten_0 ( flatten["dropout_1"] )    | [{5000, 28, 28, 32}] | {5000, 25088}      |                         |                           |
| dense_0 ( dense["flatten_0"] )        | [{5000, 25088}]      | {5000, 1000}       |                         | kernel: f32[25088][1000]  |
|                                       |                      |                    |                         | bias: f32[1000]           |
| relu_2 ( relu["dense_0"] )            | [{5000, 1000}]       | {5000, 1000}       |                         |                           |
| batch_norm_2 ( batch_norm["relu_2"] ) | [{5000, 1000}]       | {5000, 1000}       | epsilon: 1.0e-5         | gamma: f32[1000]          |
|                                       |                      |                    | channel_index: -1       | beta: f32[1000]           |
|                                       |                      |                    | momentum: 0.1           | mean: f32[1000]           |
|                                       |                      |                    |                         | var: f32[1000]            |
| dropout_2 ( dropout["batch_norm_2"] ) | [{5000, 1000}]       | {5000, 1000}       | rate: 0.5               | key: f32[2]               |
| dense_1 ( dense["dropout_2"] )        | [{5000, 1000}]       | {5000, 512}        |                         | kernel: f32[1000][512]    |
|                                       |                      |                    |                         | bias: f32[512]            |
| relu_3 ( relu["dense_1"] )            | [{5000, 512}]        | {5000, 512}        |                         |                           |
| batch_norm_3 ( batch_norm["relu_3"] ) | [{5000, 512}]        | {5000, 512}        | epsilon: 1.0e-5         | gamma: f32[512]           |
|                                       |                      |                    | channel_index: -1       | beta: f32[512]            |
|                                       |                      |                    | momentum: 0.1           | mean: f32[512]            |
|                                       |                      |                    |                         | var: f32[512]             |
| dropout_3 ( dropout["batch_norm_3"] ) | [{5000, 512}]        | {5000, 512}        | rate: 0.5               | key: f32[2]               |
| dense_2 ( dense["dropout_3"] )        | [{5000, 512}]        | {5000, 10}         |                         | kernel: f32[512][10]      |
|                                       |                      |                    |                         | bias: f32[10]             |
| softmax_0 ( softmax["dense_2"] )      | [{5000, 10}]         | {5000, 10}         |                         |                           |
Total Parameters: 25617978
Total Parameters Memory: 102471912 bytes