Chapter 19: beyond vanilla networks
Mix.install(
[
{:exla, "~> 0.5"},
{:nx, "~> 0.5"},
{:axon, "~> 0.5"},
{:kino, "~> 0.8.1"},
{:kino_vega_lite, "~> 0.1.7"},
{:vega_lite, "~> 0.1.6"},
{:scidata, "~> 0.1"},
{:nx_image, "~> 0.1.0"},
{:table_rex, "~> 3.1.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
The CIFAR-10 Dataset
What CIFAR-10 looks like
# Download the CIFAR-10 training set. Each element is a `{binary, type, shape}`
# tuple that still has to be turned into an Nx tensor.
{
  {images_binary, images_type, images_shape},
  {labels_binary, labels_type, labels_shape}
} = Scidata.CIFAR10.download()

# Number of samples, read from the first dimension of the reported shape.
images_count = elem(images_shape, 0)

# Images are reshaped to `{n, 3, 32, 32}` with named axes so that later
# steps can refer to the axes by name.
images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({images_count, 3, 32, 32}, names: [:n, :c, :h, :w])

# One label per row: shape `{n, 1}`.
labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)

# Size of the preview grid.
columns = 8
rows = 4

# Fixed key so the same preview is produced on every evaluation.
key = Nx.Random.key(42)

# Compute random indices: shuffle every valid index `0..images_count - 1`
# and keep the first `columns * rows` of them. The iota must span the whole
# dataset, otherwise the last image can never be selected.
indices =
  {images_count}
  |> Nx.iota()
  |> then(fn all_indices ->
    {shuffled_indices, _new_key} = Nx.Random.shuffle(key, all_indices)
    shuffled_indices
  end)
  |> Nx.slice_along_axis(0, columns * rows)

selected_images = Nx.take(images, indices)
selected_labels = Nx.take(labels, indices)

# Render the selection as a grid of boxed cells, each holding the class id
# and the corresponding image.
Kino.Layout.grid(
  Enum.map(0..(columns * rows - 1), fn i ->
    Kino.Layout.grid(
      [
        Kino.Markdown.new("class: #{Nx.to_number(selected_labels[i][0])}"),
        selected_images[i]
        # transpose the image since `Kino.Image.new` requires
        # the following shape `{:h, :w, :c}`, while the original one is `{:c, :h, :w}`
        |> Nx.transpose(axes: [:h, :w, :c])
        # enlarge the 32x32 picture; `:nearest` is the resize method used
        |> NxImage.resize({100, 100}, method: :nearest)
        |> Kino.Image.new()
      ],
      boxed: true,
      columns: 1
    )
  end),
  boxed: true,
  columns: columns
)
Falling short of CIFAR
Prepare the data
defmodule Chapter19.CIFAR10 do
  @moduledoc """
  Prepares the CIFAR-10 dataset for a fully connected network.

  Images are flattened to one row per sample and divided by 255.0; labels are
  one-hot encoded against the class ids 0..9.
  """

  @doc """
  Downloads CIFAR-10 through `Scidata` and returns the prepared tensors.

  The downloaded test set is cut in two along the first axis: the first half
  becomes the validation set, the remaining samples become the test set.

  Returns a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and
  `:test_labels`.
  """
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)

    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    # Images and labels are split the same way, so rows stay aligned.
    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Builds a `{n, features}` tensor (one flattened image per row) from the
  # `{binary, type, shape}` tuple and divides every value by 255.0.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: each label becomes a column (`{n, 1}`) that is
  # compared against the row `[0, 1, ..., 9]`, which broadcasts to `{n, 10}`.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits `tensor` in two along its first axis.
  #
  # The second part starts exactly where the first one ends (`len`, not
  # `len + 1`) and takes all the remaining rows, so no sample is skipped or
  # dropped, whatever the parity of the row count.
  defp split(tensor) do
    total = elem(Nx.shape(tensor), 0)
    len = div(total, 2)

    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, total - len, axis: 0)

    {first_half, second_half}
  end
end
Load the data and prepare the train batches and validation dataset.
# Fetch the prepared dataset once, then bind each split to its own variable.
dataset = Chapter19.CIFAR10.load_data()

%{train_images: train_images, train_labels: train_labels} = dataset
%{validation_images: validation_images, validation_labels: validation_labels} = dataset
%{test_images: test_images, test_labels: test_labels} = dataset

# Lazily pair mini-batches of 32 images with the matching 32 labels.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The whole validation set is passed as a single {inputs, targets} batch.
validation_data = [{validation_images, validation_labels}]
Build the model and train it
epochs = 25

# Fully connected network: three hidden stages (dense -> ReLU -> batch norm)
# of decreasing width, followed by a 10-unit softmax output layer.
model =
  [1200, 500, 200]
  |> Enum.reduce(Axon.input("data"), fn units, network ->
    network
    |> Axon.dense(units, activation: :relu)
    |> Axon.batch_norm()
  end)
  |> Axon.dense(10, activation: :softmax)

# Supervised training loop that tracks accuracy and evaluates the model on
# the validation data.
loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, validation_data)

Axon.Loop.run(loop, train_batches, %{}, epochs: epochs, compiler: EXLA)
Training completed in approximately 2250 seconds.
Results after 25 epochs:
- accuracy: 0.8000934 - loss: 1.0527945
- validation accuracy: 0.4242000 - validation loss: 5.1675453
Running on Convolutions
Prepare the data
Compared to the previous implementation, the images (inputs) are not flattened: all three dimensions (channel, height, width) are kept.
defmodule Chapter19.CIFAR10Cnn do
  @moduledoc """
  Prepares the CIFAR-10 dataset for a convolutional network.

  Images keep their spatial structure and are laid out as
  `{n, height, width, channels}`, with values divided by 255.0; labels are
  one-hot encoded against the class ids 0..9.
  """

  @doc """
  Downloads CIFAR-10 through `Scidata` and returns the prepared tensors.

  The downloaded test set is cut in two along the first axis: the first half
  becomes the validation set, the remaining samples become the test set.

  Returns a map with the keys `:train_images`, `:train_labels`,
  `:validation_images`, `:validation_labels`, `:test_images` and
  `:test_labels`.
  """
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)

    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    # Images and labels are split the same way, so samples stay aligned.
    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Builds the image tensor from the `{binary, type, shape}` tuple: the data
  # is read as `{n, 3, 32, 32}` (channels first), transposed so the channels
  # axis comes last, and every value is divided by 255.0.
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
    |> Nx.transpose(axes: [:n, :h, :w, :c])
    |> Nx.divide(255.0)
  end

  # One-hot encodes the labels: each label becomes a column (`{n, 1}`) that is
  # compared against the row `[0, 1, ..., 9]`, which broadcasts to `{n, 10}`.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Splits `tensor` in two along its first axis; works for both the 4-D image
  # tensor (whose first axis is `:n`) and the 2-D label tensor.
  #
  # The second part starts exactly where the first one ends (`len`, not
  # `len + 1`) and takes all the remaining samples, so nothing is skipped or
  # dropped, whatever the parity of the sample count.
  defp split(tensor) do
    total = elem(Nx.shape(tensor), 0)
    len = div(total, 2)

    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, total - len, axis: 0)

    {first_half, second_half}
  end
end
Load the data and prepare the train batches and validation dataset.
# Fetch the CNN-ready dataset once, then bind each split to its own variable.
dataset = Chapter19.CIFAR10Cnn.load_data()

%{train_images: train_images, train_labels: train_labels} = dataset
%{validation_images: validation_images, validation_labels: validation_labels} = dataset
%{test_images: test_images, test_labels: test_labels} = dataset

# Lazily pair mini-batches of 32 images with the matching 32 labels.
train_batches =
  [train_images, train_labels]
  |> Enum.map(&Nx.to_batched(&1, 32))
  |> Stream.zip()

# The whole validation set is passed as a single {inputs, targets} batch.
validation_data = [{validation_images, validation_labels}]
Build the CNN and train it
epochs = 20

# Input: batches of 32x32 images with the 3 channels on the last axis.
inputs = Axon.input("data", shape: {nil, 32, 32, 3})

# Convolutional stages: conv (3x3 kernel) -> ReLU -> batch norm -> dropout.
features =
  Enum.reduce([16, 32], inputs, fn filters, network ->
    network
    |> Axon.conv(filters, kernel_size: 3, activation: :relu)
    |> Axon.batch_norm()
    |> Axon.dropout(rate: 0.5)
  end)

# Classifier head: flatten, two dense stages with the same
# ReLU -> batch norm -> dropout pattern, then a 10-unit softmax output.
model =
  [1000, 512]
  |> Enum.reduce(Axon.flatten(features), fn units, network ->
    network
    |> Axon.dense(units, activation: :relu)
    |> Axon.batch_norm()
    |> Axon.dropout(rate: 0.5)
  end)
  |> Axon.dense(10, activation: :softmax)

# Print the layer-by-layer summary, using the validation images as the
# input template.
summary = Axon.Display.as_table(model, Nx.to_template(validation_images))
IO.puts(summary)

# Supervised training loop that tracks accuracy and evaluates the model on
# the validation data.
loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, validation_data)

Axon.Loop.run(loop, train_batches, %{}, epochs: epochs, compiler: EXLA)
Training completed in approximately 8400 seconds.
Results after 20 epochs:
- accuracy: 0.8096849 - loss: 0.8324687
- validation accuracy: 0.6632000 - validation loss: 1.5468998
Channels :first VS :last
In `Chapter19.CIFAR10Cnn`, the inputs extracted from the binary are transposed so that the channels axis comes last:
# Excerpt: turns the `{binary, type, shape}` tuple into the image tensor.
defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
# read the flat data as {n, 3, 32, 32}, i.e. with the channels axis first
|> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
# move the channels axis to the last position: {n, h, w, c}
|> Nx.transpose(axes: [:n, :h, :w, :c])
# divide every value by 255.0
|> Nx.divide(255.0)
end
By default, the `Axon.conv` API already expects the channels axis to be the last one (`channels: :last`), so there is no need to set the option explicitly.
But it is worth mentioning that keeping the channels as first axis (by skipping the transposition) is a possibility.
I tried both approaches, and having the channels axis last led to better accuracy overall:
| Channels position | training accuracy | validation accuracy |
|---|---|---|
| :first | 0.7166188 | 0.6012000 |
| :last | 0.8096849 | 0.6632000 |
To conclude, the Keras implementation in the book is equivalent to the Axon implementation with the channels axis set to :last.
Axon model summary ``` +-----------------------------------------------------------------------------------------------------------------------------------------+ | Model | +=======================================+======================+====================+=========================+===========================+ | Layer | Input Shape | Output Shape | Options | Parameters | +=======================================+======================+====================+=========================+===========================+ | data ( input ) | [] | {5000, 32, 32, 3} | shape: {nil, 32, 32, 3} | | | | | | optional: false | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | conv_0 ( conv["data"] ) | [{5000, 32, 32, 3}] | {5000, 30, 30, 16} | strides: 1 | kernel: f32[3][3][3][16] | | | | | padding: :valid | bias: f32[16] | | | | | input_dilation: 1 | | | | | | kernel_dilation: 1 | | | | | | feature_group_size: 1 | | | | | | channels: :last | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_0 ( relu["conv_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_0 ( batch_norm["relu_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | epsilon: 1.0e-5 | gamma: f32[16] | | | | | channel_index: -1 | beta: f32[16] | | | | | momentum: 0.1 | mean: f32[16] | | | | | | var: f32[16] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_0 ( dropout["batch_norm_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | rate: 0.5 | key: f32[2] | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | conv_1 ( conv["dropout_0"] ) | [{5000, 30, 30, 16}] | {5000, 28, 28, 32} | strides: 1 | kernel: f32[3][3][16][32] | | | | | padding: :valid | bias: f32[32] | | | | | input_dilation: 1 | | | | | | kernel_dilation: 1 | | | | | | feature_group_size: 1 | | | | | | channels: :last | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_1 ( relu["conv_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_1 ( batch_norm["relu_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | epsilon: 1.0e-5 | gamma: f32[32] | | | | | channel_index: -1 | beta: f32[32] | | | | | momentum: 0.1 | mean: f32[32] | | | | | | var: f32[32] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_1 ( dropout["batch_norm_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | flatten_0 ( flatten["dropout_1"] ) | [{5000, 28, 28, 32}] | {5000, 25088} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_0 ( dense["flatten_0"] ) | [{5000, 25088}] | {5000, 1000} | | kernel: f32[25088][1000] | | | | | | bias: f32[1000] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_2 ( relu["dense_0"] ) | [{5000, 1000}] | {5000, 1000} | | | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_2 ( batch_norm["relu_2"] ) | [{5000, 1000}] | {5000, 1000} | epsilon: 1.0e-5 | gamma: f32[1000] | | | | | channel_index: -1 | beta: f32[1000] | | | | | momentum: 0.1 | mean: f32[1000] | | | | | | var: f32[1000] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_2 ( dropout["batch_norm_2"] ) | [{5000, 1000}] | {5000, 1000} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_1 ( dense["dropout_2"] ) | [{5000, 1000}] | {5000, 512} | | kernel: f32[1000][512] | | | | | | bias: f32[512] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | relu_3 ( relu["dense_1"] ) | [{5000, 512}] | {5000, 512} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | batch_norm_3 ( batch_norm["relu_3"] ) | [{5000, 512}] | {5000, 512} | epsilon: 1.0e-5 | gamma: f32[512] | | | | | channel_index: -1 | beta: f32[512] | | | | | momentum: 0.1 | mean: f32[512] | | | | | | var: f32[512] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dropout_3 ( dropout["batch_norm_3"] ) | [{5000, 512}] | {5000, 512} | rate: 0.5 | key: f32[2] | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | dense_2 ( dense["dropout_3"] ) | [{5000, 512}] | {5000, 10} | | kernel: f32[512][10] | | | | | | bias: f32[10] | 
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ | softmax_0 ( softmax["dense_2"] ) | [{5000, 10}] | {5000, 10} | | | +---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+ Total Parameters: 25617978 Total Parameters Memory: 102471912 bytes ```