Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapter 19: beyond vanilla networks

19_beyond/beyond_vanilla_networks.livemd

Chapter 19: beyond vanilla networks

Mix.install(
  [
    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:axon, "~> 0.5"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:vega_lite, "~> 0.1.6"},
    {:scidata, "~> 0.1"},
    {:nx_image, "~> 0.1.0"},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

The CIFAR-10 Dataset

What CIFAR-10 looks like

# Download the CIFAR-10 training set. Each element is a
# `{binary, type, shape}` triple describing a tensor.
{
  {images_binary, images_type, images_shape},
  {labels_binary, labels_type, labels_shape}
} = Scidata.CIFAR10.download()

# Rebuild the image tensor as `{n, 3, 32, 32}` and name its axes so later
# steps can refer to them symbolically.
images =
  Nx.from_binary(images_binary, images_type)
  |> Nx.reshape({elem(images_shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])

# Labels get a trailing axis of size 1, so `labels[i][0]` is the class of image `i`.
labels =
  Nx.from_binary(labels_binary, labels_type)
  |> Nx.new_axis(-1)

# Size of the preview grid: `columns * rows` images are sampled.
columns = 8
rows = 4

# Fixed seed so the same sample is shown on every evaluation.
key = Nx.Random.key(42)

# Compute random indices: shuffle every valid index `0..n-1` and keep the
# first `columns * rows` of them. The iota must have `n` elements (not
# `n - 1`), otherwise the last image could never be selected.
indices =
  {elem(images_shape, 0)}
  |> Nx.iota()
  |> then(fn all_indices ->
    {shuffled_indices, _new_key} = Nx.Random.shuffle(key, all_indices)
    shuffled_indices
  end)
  |> Nx.slice_along_axis(0, columns * rows)

selected_images = Nx.take(images, indices)
selected_labels = Nx.take(labels, indices)

# Render the sampled images as a `columns`-wide grid; every cell stacks the
# class label on top of the enlarged picture.
0..(columns * rows - 1)
|> Enum.map(fn idx ->
  class = Nx.to_number(selected_labels[idx][0])

  # `Kino.Image.new` requires the shape `{:h, :w, :c}`, while the stored
  # layout is `{:c, :h, :w}`, hence the transposition before resizing.
  picture =
    selected_images[idx]
    |> Nx.transpose(axes: [:h, :w, :c])
    |> NxImage.resize({100, 100}, method: :nearest)
    |> Kino.Image.new()

  Kino.Layout.grid(
    [Kino.Markdown.new("class: #{class}"), picture],
    boxed: true,
    columns: 1
  )
end)
|> Kino.Layout.grid(boxed: true, columns: columns)

Falling short of CIFAR

Prepare the data

defmodule Chapter19.CIFAR10 do
  @moduledoc """
  Loads CIFAR-10 through `Scidata.CIFAR10` and prepares it for a fully
  connected network: images are flattened to one row per sample and labels
  are one-hot encoded.
  """

  @doc """
  Downloads the dataset and returns a map with the keys `:train_images`,
  `:train_labels`, `:validation_images`, `:validation_labels`,
  `:test_images` and `:test_labels`.

  The training tensors come from `Scidata.CIFAR10.download/0`. The result of
  `Scidata.CIFAR10.download_test/0` is split in two: the first half becomes
  the validation set, the remainder the test set.
  """
  @spec load_data() :: %{atom() => Nx.Tensor.t()}
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Flattens every image into a single row (`{n, :auto}`) and divides by
  # 255.0, which maps pixel values to 0..1 assuming 8-bit input.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: the `{n, 1}` label column is compared against
  # the row `0..9`, broadcasting to an `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits `tensor` along its first axis into `{first_half, second_half}`.
  #
  # The second half starts exactly where the first one ends and runs to the
  # end of the tensor, so no sample is skipped or read out of range. The
  # previous `len + 1` start was out of range for even sizes (it appears to
  # have worked only because Nx clamps the slice start) and skipped the
  # middle sample for odd sizes.
  defp split(tensor) do
    n = Nx.axis_size(tensor, 0)
    len = div(n, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, n - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

# Bind every split returned by the loader to a notebook variable.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10.load_data()

# Lazily pair each batch of 32 images with the matching batch of 32 labels.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The whole validation split is evaluated as one single batch.
validation_data = [{validation_images, validation_labels}]

Build the model and train it

epochs = 25

# Three hidden blocks (dense -> ReLU -> batch norm) of decreasing width,
# followed by a 10-way softmax classifier.
model =
  [1200, 500, 200]
  |> Enum.reduce(Axon.input("data"), fn units, network ->
    network
    |> Axon.dense(units)
    |> Axon.relu()
    |> Axon.batch_norm()
  end)
  |> Axon.dense(10, activation: :softmax)

# Train with Adam on categorical cross-entropy, tracking accuracy and
# running the validation set through the same model after every epoch.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

Training completed in about 2250 seconds.

Results after 25 epochs:

  • accuracy: 0.8000934 - loss: 1.0527945
  • validation accuracy: 0.4242000 - validation loss: 5.1675453

Running on Convolutions

Prepare the data

Compared to the previous implementation, the images (inputs) are not flattened: all three dimensions (channel, height, width) are kept.

defmodule Chapter19.CIFAR10Cnn do
  @moduledoc """
  Loads CIFAR-10 through `Scidata.CIFAR10` and prepares it for a
  convolutional network: images keep their spatial structure (channels
  last) and labels are one-hot encoded.
  """

  @doc """
  Downloads the dataset and returns a map with the keys `:train_images`,
  `:train_labels`, `:validation_images`, `:validation_labels`,
  `:test_images` and `:test_labels`.

  The training tensors come from `Scidata.CIFAR10.download/0`. The result of
  `Scidata.CIFAR10.download_test/0` is split in two: the first half becomes
  the validation set, the remainder the test set.
  """
  @spec load_data() :: %{atom() => Nx.Tensor.t()}
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Rebuilds the images as `{n, 3, 32, 32}` (channels first), moves the
  # channel axis to the end (`{n, h, w, c}`) and divides by 255.0, which maps
  # pixel values to 0..1 assuming 8-bit input.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
    |> Nx.transpose(axes: [:n, :h, :w, :c])
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: the `{n, 1}` label column is compared against
  # the row `0..9`, broadcasting to an `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits `tensor` along its first (sample) axis into
  # `{first_half, second_half}`. A single clause serves both the 4-D image
  # tensor and the 2-D label tensor because only axis 0 is inspected.
  #
  # The second half starts exactly where the first one ends and runs to the
  # end of the tensor, so no sample is skipped or read out of range. The
  # previous `len + 1` start was out of range for even sizes (it appears to
  # have worked only because Nx clamps the slice start) and skipped the
  # middle sample for odd sizes.
  defp split(%Nx.Tensor{} = tensor) do
    n = Nx.axis_size(tensor, 0)
    len = div(n, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, n - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

# Bind every split returned by the CNN loader to a notebook variable.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10Cnn.load_data()

# Lazily pair each batch of 32 images with the matching batch of 32 labels.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The whole validation split is evaluated as one single batch.
validation_data = [{validation_images, validation_labels}]

Build the CNN and train it

epochs = 20

# Two convolutional blocks (3x3 conv + ReLU -> batch norm -> dropout),
# a flatten, two dense blocks with the same normalisation/dropout tail,
# and finally a 10-way softmax classifier.
model =
  [16, 32]
  |> Enum.reduce(Axon.input("data", shape: {nil, 32, 32, 3}), fn filters, network ->
    network
    |> Axon.conv(filters, kernel_size: 3, activation: :relu)
    |> Axon.batch_norm()
    |> Axon.dropout(rate: 0.5)
  end)
  |> Axon.flatten()
  |> then(fn flattened ->
    Enum.reduce([1000, 512], flattened, fn units, network ->
      network
      |> Axon.dense(units, activation: :relu)
      |> Axon.batch_norm()
      |> Axon.dropout(rate: 0.5)
    end)
  end)
  |> Axon.dense(10, activation: :softmax)

# Print the layer-by-layer summary, using the validation images as the
# input template.
IO.puts(Axon.Display.as_table(model, Nx.to_template(validation_images)))

# Train with Adam on categorical cross-entropy, tracking accuracy and
# running the validation set through the same model after every epoch.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

Training completed in about 8400 seconds.

Results after 20 epochs:

  • accuracy: 0.8096849 - loss: 0.8324687
  • validation accuracy: 0.6632000 - validation loss: 1.5468998

Channels :first VS :last

In the Chapter19.CIFAR10Cnn module, the inputs extracted from the binary are transposed so that the channels axis comes last:

# Excerpt: turns a `{binary, type, shape}` triple into a channels-last
# image tensor with every value divided by 255.0.
defp transform_images({bin, type, shape}) do
  bin
  |> Nx.from_binary(type)
  # Rebuild the tensor as {n, 3, 32, 32}, i.e. channels first, and name the axes.
  |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
  # Move the channels axis to the end: {n, h, w, c}.
  |> Nx.transpose(axes: [:n, :h, :w, :c])
  |> Nx.divide(255.0)
end

The Axon.conv API already expects the channels axis to be the last one by default, so there is no need to set the option explicitly.

It is worth mentioning, however, that keeping the channels as the first axis (by skipping the transposition) is also possible.

I tried both approaches, and having the channels axis last led to better accuracy overall:

Channels position training accuracy validation accuracy
:first 0.7166188 0.6012000
:last 0.8096849 0.6632000

To conclude, the Keras implementation in the book is equivalent to the Axon implementation with channels as the :last axis.

Keras model summary

The model

# Same architecture as the Axon CNN: two conv blocks, a flatten, two dense
# blocks, and a 10-way softmax classifier. Every block ends with batch
# normalization followed by 50% dropout.
model = Sequential([
    Conv2D(16, (3, 3), activation='relu'),
    BatchNormalization(),
    Dropout(0.5),

    Conv2D(32, (3, 3), activation='relu'),
    BatchNormalization(),
    Dropout(0.5),

    Flatten(),

    Dense(1000, activation='relu'),
    BatchNormalization(),
    Dropout(0.5),

    Dense(512, activation='relu'),
    BatchNormalization(),
    Dropout(0.5),

    Dense(10, activation='softmax'),
])

# Adam optimizer on categorical cross-entropy, reporting accuracy.
model.compile(
    optimizer=Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d (Conv2D)             (50000, 30, 30, 16)       448

 batch_normalization (BatchN  (50000, 30, 30, 16)      64
 ormalization)

 dropout (Dropout)           (50000, 30, 30, 16)       0

 conv2d_1 (Conv2D)           (50000, 28, 28, 32)       4640

 batch_normalization_1 (Batc  (50000, 28, 28, 32)      128
 hNormalization)

 dropout_1 (Dropout)         (50000, 28, 28, 32)       0

 flatten (Flatten)           (50000, 25088)            0

 dense (Dense)               (50000, 1000)             25089000

 batch_normalization_2 (Batc  (50000, 1000)            4000
 hNormalization)

 dropout_2 (Dropout)         (50000, 1000)             0

 dense_1 (Dense)             (50000, 512)              512512

 batch_normalization_3 (Batc  (50000, 512)             2048
 hNormalization)

 dropout_3 (Dropout)         (50000, 512)              0

 dense_2 (Dense)             (50000, 10)               5130

=================================================================
Total params: 25,617,970
Trainable params: 25,614,850
Non-trainable params: 3,120
_________________________________________________________________
None

Axon model summary

+-----------------------------------------------------------------------------------------------------------------------------------------+
|                                                                  Model                                                                  |
+=======================================+======================+====================+=========================+===========================+
| Layer                                 | Input Shape          | Output Shape       | Options                 | Parameters                |
+=======================================+======================+====================+=========================+===========================+
| data ( input )                        | []                   | {5000, 32, 32, 3}  | shape: {nil, 32, 32, 3} |                           |
|                                       |                      |                    | optional: false         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| conv_0 ( conv["data"] )               | [{5000, 32, 32, 3}]  | {5000, 30, 30, 16} | strides: 1              | kernel: f32[3][3][3][16]  |
|                                       |                      |                    | padding: :valid         | bias: f32[16]             |
|                                       |                      |                    | input_dilation: 1       |                           |
|                                       |                      |                    | kernel_dilation: 1      |                           |
|                                       |                      |                    | feature_group_size: 1   |                           |
|                                       |                      |                    | channels: :last         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_0 ( relu["conv_0"] )             | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_0 ( batch_norm["relu_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | epsilon: 1.0e-5         | gamma: f32[16]            |
|                                       |                      |                    | channel_index: -1       | beta: f32[16]             |
|                                       |                      |                    | momentum: 0.1           | mean: f32[16]             |
|                                       |                      |                    |                         | var: f32[16]              |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_0 ( dropout["batch_norm_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | rate: 0.5               | key: f32[2]               |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| conv_1 ( conv["dropout_0"] )          | [{5000, 30, 30, 16}] | {5000, 28, 28, 32} | strides: 1              | kernel: f32[3][3][16][32] |
|                                       |                      |                    | padding: :valid         | bias: f32[32]             |
|                                       |                      |                    | input_dilation: 1       |                           |
|                                       |                      |                    | kernel_dilation: 1      |                           |
|                                       |                      |                    | feature_group_size: 1   |                           |
|                                       |                      |                    | channels: :last         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_1 ( relu["conv_1"] )             | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_1 ( batch_norm["relu_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | epsilon: 1.0e-5         | gamma: f32[32]            |
|                                       |                      |                    | channel_index: -1       | beta: f32[32]             |
|                                       |                      |                    | momentum: 0.1           | mean: f32[32]             |
|                                       |                      |                    |                         | var: f32[32]              |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_1 ( dropout["batch_norm_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | rate: 0.5               | key: f32[2]               |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| flatten_0 ( flatten["dropout_1"] )    | [{5000, 28, 28, 32}] | {5000, 25088}      |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_0 ( dense["flatten_0"] )        | [{5000, 25088}]      | {5000, 1000}       |                         | kernel: f32[25088][1000]  |
|                                       |                      |                    |                         | bias: f32[1000]           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_2 ( relu["dense_0"] )            | [{5000, 1000}]       | {5000, 1000}       |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_2 ( batch_norm["relu_2"] ) | [{5000, 1000}]       | {5000, 1000}       | epsilon: 1.0e-5         | gamma: f32[1000]          |
|                                       |                      |                    | channel_index: -1       | beta: f32[1000]           |
|                                       |                      |                    | momentum: 0.1           | mean: f32[1000]           |
|                                       |                      |                    |                         | var: f32[1000]            |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_2 ( dropout["batch_norm_2"] ) | [{5000, 1000}]       | {5000, 1000}       | rate: 0.5               | key: f32[2]               |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_1 ( dense["dropout_2"] )        | [{5000, 1000}]       | {5000, 512}        |                         | kernel: f32[1000][512]    |
|                                       |                      |                    |                         | bias: f32[512]            |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_3 ( relu["dense_1"] )            | [{5000, 512}]        | {5000, 512}        |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_3 ( batch_norm["relu_3"] ) | [{5000, 512}]        | {5000, 512}        | epsilon: 1.0e-5         | gamma: f32[512]           |
|                                       |                      |                    | channel_index: -1       | beta: f32[512]            |
|                                       |                      |                    | momentum: 0.1           | mean: f32[512]            |
|                                       |                      |                    |                         | var: f32[512]             |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_3 ( dropout["batch_norm_3"] ) | [{5000, 512}]        | {5000, 512}        | rate: 0.5               | key: f32[2]               |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_2 ( dense["dropout_3"] )        | [{5000, 512}]        | {5000, 10}         |                         | kernel: f32[512][10]      |
|                                       |                      |                    |                         | bias: f32[10]             |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| softmax_0 ( softmax["dense_2"] )      | [{5000, 10}]         | {5000, 10}         |                         |                           |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
Total Parameters: 25617978
Total Parameters Memory: 102471912 bytes