Powered by AppSignal & Oban Pro

Chapter 19: beyond vanilla networks

19_beyond/beyond_vanilla_networks.livemd

Chapter 19: beyond vanilla networks

Mix.install(
  [
    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:axon, "~> 0.5"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:vega_lite, "~> 0.1.6"},
    {:scidata, "~> 0.1"},
    {:nx_image, "~> 0.1.0"},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

The CIFAR-10 Dataset

What CIFAR-10 looks like

# Download the CIFAR-10 training set. Each element of the returned pair is a
# `{binary, type, shape}` triple describing one tensor.
{
  {images_binary, images_type, images_shape},
  {labels_binary, labels_type, labels_shape}
} = Scidata.CIFAR10.download()

# Number of images, i.e. the first dimension of the reported shape.
image_count = elem(images_shape, 0)

images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({image_count, 3, 32, 32}, names: [:n, :c, :h, :w])

# Add a trailing axis so that each label sits in its own row.
labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)

columns = 8
rows = 4

key = Nx.Random.key(42)

# Compute random indices: shuffle *all* valid indices (0..image_count - 1) and keep
# the first `columns * rows` of them. The iota must have `image_count` entries,
# otherwise the last image could never be picked.
{shuffled_indices, _new_key} = Nx.Random.shuffle(key, Nx.iota({image_count}))
indices = Nx.slice_along_axis(shuffled_indices, 0, columns * rows)

selected_images = Nx.take(images, indices)
selected_labels = Nx.take(labels, indices)

# Render the sampled images in a `columns`-wide grid; every cell is itself a small
# one-column grid with the class label on top and the picture below.
0..(columns * rows - 1)
|> Enum.map(fn i ->
  class_id = Nx.to_number(selected_labels[i][0])
  caption = Kino.Markdown.new("class: #{class_id}")

  # transpose the image since `Kino.Image.new` requires
  # the following shape `{:h, :w, :c}`, while the original one is `{:c, :h, :w}`
  picture =
    selected_images[i]
    |> Nx.transpose(axes: [:h, :w, :c])
    |> NxImage.resize({100, 100}, method: :nearest)
    |> Kino.Image.new()

  Kino.Layout.grid([caption, picture], boxed: true, columns: 1)
end)
|> Kino.Layout.grid(boxed: true, columns: columns)

Falling short of CIFAR

Prepare the data

defmodule Chapter19.CIFAR10 do
  @moduledoc """
  Loads CIFAR-10 through `Scidata.CIFAR10` and prepares it for a fully connected
  network: each image is flattened into a single row and each label is one-hot
  encoded over the classes 0..9.
  """

  @doc """
  Downloads the train and test sets and returns the prepared tensors.

  The downloaded test set is split in two: the first half becomes the validation
  set and the second half becomes the test set.

  Returns a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and `:test_labels`.
  """
  @spec load_data() :: %{required(atom()) => Nx.Tensor.t()}
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    # Images and labels go through the same `split/1`, so the halves stay aligned.
    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Turns the raw `{binary, type, shape}` triple into a `{n, features}` tensor
  # (one flattened image per row) and divides every value by 255.0.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: each label (as a `{n, 1}` column) is compared
  # against the class ids 0..9, giving a `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits a rank-2 tensor in two along its first axis.
  #
  # The second part starts exactly where the first one ends, so no row is skipped
  # or requested past the end of the tensor. When the number of rows is odd, the
  # second part receives the extra row.
  defp split(tensor) do
    {total, _} = Nx.shape(tensor)
    len = div(total, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, total - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

# Unpack every prepared tensor returned by the loader into its own binding.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10.load_data()

# Lazily pair image batches with label batches, 32 samples per batch.
train_batches =
  train_images
  |> Nx.to_batched(32)
  |> Stream.zip(Nx.to_batched(train_labels, 32))

# The whole validation set is passed as a single `{inputs, targets}` batch.
validation_data = [{validation_images, validation_labels}]

Build the model and train it

epochs = 25

# Three hidden blocks (dense -> relu -> batch norm) with 1200, 500 and 200 units,
# followed by a 10-unit softmax output layer.
model =
  [1200, 500, 200]
  |> Enum.reduce(Axon.input("data"), fn units, graph ->
    graph
    |> Axon.dense(units)
    |> Axon.relu()
    |> Axon.batch_norm()
  end)
  |> Axon.dense(10, activation: :softmax)

# Train with categorical cross-entropy and Adam, tracking accuracy and running the
# validation set through the same model.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(
  train_batches,
  %{},
  epochs: epochs,
  compiler: EXLA
)

Training completed in about 2250 seconds.

Results after 25 epochs:

  • accuracy: 0.8000934 - loss: 1.0527945
  • validation accuracy: 0.4242000 - validation loss: 5.1675453

Running on Convolutions

Prepare the data

Compared to the previous implementation, the images (inputs) are not flattened: all three dimensions (channel, height, width) are kept.

defmodule Chapter19.CIFAR10Cnn do
  @moduledoc """
  Loads CIFAR-10 through `Scidata.CIFAR10` and prepares it for a convolutional
  network: images keep their spatial structure and are laid out as
  `{:n, :h, :w, :c}` (channels last); labels are one-hot encoded over 0..9.
  """

  @doc """
  Downloads the train and test sets and returns the prepared tensors.

  The downloaded test set is split in two: the first half becomes the validation
  set and the second half becomes the test set.

  Returns a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and `:test_labels`.
  """
  @spec load_data() :: %{required(atom()) => Nx.Tensor.t()}
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    # Images and labels go through the same `split/1`, so the halves stay aligned.
    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Reads the binary as `{:n, :c, :h, :w}`, moves the channels axis to the end
  # (`{:n, :h, :w, :c}`) and divides every value by 255.0.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
    |> Nx.transpose(axes: [:n, :h, :w, :c])
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: each label (as a `{n, 1}` column) is compared
  # against the class ids 0..9, giving a `{n, 10}` tensor of 0/1 flags.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits a tensor in two along its first (sample) axis. The same clause serves
  # the rank-4 image tensor and the rank-2 label tensor, since both keep the
  # samples on axis 0.
  #
  # The second part starts exactly where the first one ends, so no sample is
  # skipped or requested past the end of the tensor. When the number of samples
  # is odd, the second part receives the extra one.
  defp split(tensor) do
    total = Nx.axis_size(tensor, 0)
    len = div(total, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, total - len, axis: 0)
    {first_half, second_half}
  end
end

Load the data and prepare the train batches and validation dataset.

# Unpack every prepared tensor returned by the CNN loader into its own binding.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10Cnn.load_data()

# Lazily pair image batches with label batches, 32 samples per batch.
train_batches =
  train_images
  |> Nx.to_batched(32)
  |> Stream.zip(Nx.to_batched(train_labels, 32))

# The whole validation set is passed as a single `{inputs, targets}` batch.
validation_data = [{validation_images, validation_labels}]

Build the CNN and train it

epochs = 20

# Two convolutional blocks (16 then 32 filters), a flatten, two dense blocks
# (1000 then 512 units) and a 10-unit softmax output. Every block is followed by
# batch normalization and 50% dropout.
model =
  [16, 32]
  |> Enum.reduce(Axon.input("data", shape: {nil, 32, 32, 3}), fn filters, graph ->
    graph
    |> Axon.conv(filters, kernel_size: 3, activation: :relu)
    |> Axon.batch_norm()
    |> Axon.dropout(rate: 0.5)
  end)
  |> Axon.flatten()
  |> then(fn flattened ->
    Enum.reduce([1000, 512], flattened, fn units, graph ->
      graph
      |> Axon.dense(units, activation: :relu)
      |> Axon.batch_norm()
      |> Axon.dropout(rate: 0.5)
    end)
  end)
  |> Axon.dense(10, activation: :softmax)

# Print the layer-by-layer summary, using the validation images as the input template.
model
|> Axon.Display.as_table(Nx.to_template(validation_images))
|> IO.puts()

# Train with categorical cross-entropy and Adam, tracking accuracy and running the
# validation set through the same model.
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(
  train_batches,
  %{},
  epochs: epochs,
  compiler: EXLA
)

Training completed in about 8400 seconds.

Results after 20 epochs:

  • accuracy: 0.8096849 - loss: 0.8324687
  • validation accuracy: 0.6632000 - validation loss: 1.5468998

Channels :first VS :last

In the Chapter19.CIFAR10Cnn the inputs extracted from the binary are transposed to have the channels axis as last:

# Builds the image tensor from the raw `{binary, type, shape}` triple.
defp transform_images({bin, type, shape}) do
  bin
  |> Nx.from_binary(type)
  # The binary is read with the channels axis first: `{:n, :c, :h, :w}`.
  |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
  # Move the channels axis to the end: `{:n, :h, :w, :c}`.
  |> Nx.transpose(axes: [:n, :h, :w, :c])
  |> Nx.divide(255.0)
end

The Axon.conv API already expects the channels axis to be the last one (`channels: :last`), so there is no need to set the option explicitly.

It is worth mentioning, however, that keeping the channels as the first axis (by skipping the transposition) is also possible.

I tried both approaches, and having the channels axis last led to better accuracy overall:

| Channels position | Training accuracy | Validation accuracy |
| ----------------- | ----------------- | ------------------- |
| `:first`          | 0.7166188         | 0.6012000           |
| `:last`           | 0.8096849         | 0.6632000           |

To conclude, the Keras implementation in the book is equivalent to the Axon implementation with the channels as the :last axis.

Keras model summary The model ```python model = Sequential() model.add(Conv2D(16, (3, 3), activation='relu')) model.add(BatchNormalization()) model.add(Dropout(0.5)) model.add(Conv2D(32, (3, 3), activation='relu')) model.add(BatchNormalization()) model.add(Dropout(0.5)) model.add(Flatten()) model.add(Dense(1000, activation='relu')) model.add(BatchNormalization()) model.add(Dropout(0.5)) model.add(Dense(512, activation='relu')) model.add(BatchNormalization()) model.add(Dropout(0.5)) model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy']) ``` ``` Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (50000, 30, 30, 16) 448 batch_normalization (BatchN (50000, 30, 30, 16) 64 ormalization) dropout (Dropout) (50000, 30, 30, 16) 0 conv2d_1 (Conv2D) (50000, 28, 28, 32) 4640 batch_normalization_1 (Batc (50000, 28, 28, 32) 128 hNormalization) dropout_1 (Dropout) (50000, 28, 28, 32) 0 flatten (Flatten) (50000, 25088) 0 dense (Dense) (50000, 1000) 25089000 batch_normalization_2 (Batc (50000, 1000) 4000 hNormalization) dropout_2 (Dropout) (50000, 1000) 0 dense_1 (Dense) (50000, 512) 512512 batch_normalization_3 (Batc (50000, 512) 2048 hNormalization) dropout_3 (Dropout) (50000, 512) 0 dense_2 (Dense) (50000, 10) 5130 ================================================================= Total params: 25,617,970 Trainable params: 25,614,850 Non-trainable params: 3,120 _________________________________________________________________ None ```
Axon model summary ``` +-----------------------------------------------------------------------------------------------------------------------------------------+ | Model | +=======================================+======================+====================+=========================+===========================+ | Layer | Input Shape | Output Shape | Options | Parameters | +=======================================+======================+====================+=========================+===========================+ | data ( input ) | [] | {5000, 32, 32, 3} | shape: {nil, 32, 32, 3} | | | | | | optional: false | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | conv_0 ( conv["data"] ) | [{5000, 32, 32, 3}] | {5000, 30, 30, 16} | strides: 1 | kernel: f32[3][3][3][16] | | | | | padding: :valid | bias: f32[16] | | | | | input_dilation: 1 | | | | | | kernel_dilation: 1 | | | | | | feature_group_size: 1 | | | | | | channels: :last | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_0 ( relu["conv_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_0 ( batch_norm["relu_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | epsilon: 1.0e-5 | gamma: f32[16] | | | | | channel_index: -1 | beta: f32[16] | | | | | momentum: 0.1 | mean: f32[16] | | | | | | var: f32[16] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_0 ( dropout["batch_norm_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | rate: 0.5 | key: f32[2] | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | conv_1 ( conv["dropout_0"] ) | [{5000, 30, 30, 16}] | {5000, 28, 28, 32} | strides: 1 | kernel: f32[3][3][16][32] | | | | | padding: :valid | bias: f32[32] | | | | | input_dilation: 1 | | | | | | kernel_dilation: 1 | | | | | | feature_group_size: 1 | | | | | | channels: :last | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_1 ( relu["conv_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_1 ( batch_norm["relu_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | epsilon: 1.0e-5 | gamma: f32[32] | | | | | channel_index: -1 | beta: f32[32] | | | | | momentum: 0.1 | mean: f32[32] | | | | | | var: f32[32] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_1 ( dropout["batch_norm_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | flatten_0 ( flatten["dropout_1"] ) | [{5000, 28, 28, 32}] | {5000, 25088} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_0 ( dense["flatten_0"] ) | [{5000, 25088}] | {5000, 1000} | | kernel: f32[25088][1000] | | | | | | bias: f32[1000] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_2 ( relu["dense_0"] ) | [{5000, 1000}] | {5000, 1000} | | | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_2 ( batch_norm["relu_2"] ) | [{5000, 1000}] | {5000, 1000} | epsilon: 1.0e-5 | gamma: f32[1000] | | | | | channel_index: -1 | beta: f32[1000] | | | | | momentum: 0.1 | mean: f32[1000] | | | | | | var: f32[1000] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_2 ( dropout["batch_norm_2"] ) | [{5000, 1000}] | {5000, 1000} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_1 ( dense["dropout_2"] ) | [{5000, 1000}] | {5000, 512} | | kernel: f32[1000][512] | | | | | | bias: f32[512] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_3 ( relu["dense_1"] ) | [{5000, 512}] | {5000, 512} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_3 ( batch_norm["relu_3"] ) | [{5000, 512}] | {5000, 512} | epsilon: 1.0e-5 | gamma: f32[512] | | | | | channel_index: -1 | beta: f32[512] | | | | | momentum: 0.1 | mean: f32[512] | | | | | | var: f32[512] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_3 ( dropout["batch_norm_3"] ) | [{5000, 512}] | {5000, 512} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_2 ( dense["dropout_3"] ) | [{5000, 512}] | {5000, 10} | | kernel: f32[512][10] | | | | | | bias: f32[10] | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | softmax_0 ( softmax["dense_2"] ) | [{5000, 10}] | {5000, 10} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ Total Parameters: 25617978 Total Parameters Memory: 102471912 bytes ```