Chapter 19: Beyond Vanilla Networks
Mix.install(
  [
    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:axon, "~> 0.5"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:vega_lite, "~> 0.1.6"},
    {:scidata, "~> 0.1"},
    {:nx_image, "~> 0.1.0"},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
The CIFAR-10 Dataset
What CIFAR-10 looks like
{
  {images_binary, images_type, images_shape},
  {labels_binary, labels_type, labels_shape}
} = Scidata.CIFAR10.download()

images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({elem(images_shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])

labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
columns = 8
rows = 4
key = Nx.Random.key(42)
# Compute random indices: shuffle all image indices and take the first `columns * rows`.
# Note the iota must cover all `elem(images_shape, 0)` indices; using `- 1` here
# would make the last image unselectable.
indices =
  {elem(images_shape, 0)}
  |> Nx.iota()
  |> then(fn data ->
    {shuffled_data, _new_key} = Nx.Random.shuffle(key, data)
    shuffled_data
  end)
  |> Nx.slice_along_axis(0, columns * rows)

selected_images = Nx.take(images, indices)
selected_labels = Nx.take(labels, indices)
Kino.Layout.grid(
  Enum.map(0..(columns * rows - 1), fn i ->
    Kino.Layout.grid(
      [
        Kino.Markdown.new("class: #{selected_labels[i][0] |> Nx.to_number()}"),
        selected_images[i]
        # transpose the image since `Kino.Image.new` requires
        # the shape `{:h, :w, :c}`, while the original one is `{:c, :h, :w}`
        |> Nx.transpose(axes: [:h, :w, :c])
        |> NxImage.resize({100, 100}, method: :nearest)
        |> Kino.Image.new()
      ],
      boxed: true,
      columns: 1
    )
  end),
  boxed: true,
  columns: columns
)
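The grid above shows each label as a bare integer. Per the standard CIFAR-10 class ordering (an assumption here; Scidata only returns the numeric labels), the integers map to human-readable names:

# Standard CIFAR-10 class names, indexed by label (assumed ordering)
class_names = ~w(airplane automobile bird cat deer dog frog horse ship truck)

# e.g. turn the first selected label into a readable name
selected_labels[0][0]
|> Nx.to_number()
|> then(&Enum.at(class_names, &1))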
Falling short of CIFAR
Prepare the data
defmodule Chapter19.CIFAR10 do
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)

    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    # Halve the downloaded test set into a validation set and a test set
    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Flatten each image into a single row and scale the pixel values to [0, 1]
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # One-hot encode the labels by broadcasting a comparison with the class indices 0..9
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Split a tensor into two halves along the first axis.
  # The second half must start at `len` (not `len + 1`), otherwise the element
  # at index `len` is skipped and the slice is silently clamped by Nx.
  defp split(tensor) do
    {x, _} = Nx.shape(tensor)
    len = div(x, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, len, axis: 0)
    {first_half, second_half}
  end
end
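The one-hot encoding in transform_labels relies on broadcasting: comparing an {n, 1} labels tensor against the {10} tensor of class indices yields an {n, 10} matrix of 0s and 1s. A quick check with two toy labels:

# Labels 3 and 0 become rows with a single 1 in columns 3 and 0
Nx.equal(Nx.tensor([[3], [0]]), Nx.tensor(Enum.to_list(0..9)))
#=> u8 tensor of shape {2, 10}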
Load the data and prepare the train batches and validation dataset.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10.load_data()
train_batches = Stream.zip(Nx.to_batched(train_images, 32), Nx.to_batched(train_labels, 32))
validation_data = [{validation_images, validation_labels}]
Build the model and train it
epochs = 25

model =
  Axon.input("data")
  |> Axon.dense(1200)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(500)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(200)
  |> Axon.relu()
  |> Axon.batch_norm()
  |> Axon.dense(10, activation: :softmax)

# Bind the trained parameters so they can be reused for evaluation below
trained_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, validation_data)
  |> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)
Training completed in roughly 2,250 seconds.
Results after 25 epochs:
- accuracy: 0.8000934 - loss: 1.0527945
- validation accuracy: 0.4242000 - validation loss: 5.1675453
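The test set is loaded but never used above. As a minimal sketch (assuming the trained parameters are bound to trained_state, as in the training cell), the accuracy on the held-out test images can be measured with Axon's evaluation loop:

test_data = [{test_images, test_labels}]

# Run a forward-only evaluation loop with the trained parameters
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_data, trained_state, compiler: EXLA)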
Running on Convolutions
Prepare the data
Compared to the previous implementation, the images (inputs) are not flattened; all three dimensions (channels, height, width) are kept.
defmodule Chapter19.CIFAR10Cnn do
  def load_data() do
    {raw_images, raw_labels} = Scidata.CIFAR10.download()
    {raw_test_images, raw_test_labels} = Scidata.CIFAR10.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)

    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Keep all three image dimensions and move the channels axis last: {n, h, w, c}
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
    |> Nx.transpose(axes: [:n, :h, :w, :c])
    |> Nx.divide(255.0)
  end

  # One-hot encode the labels by broadcasting a comparison with the class indices 0..9
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  # Split the rank-4 image tensors ({n, h, w, c} after the transposition) in half.
  # As in Chapter19.CIFAR10, the second half starts at `len`, not `len + 1`.
  defp split(%Nx.Tensor{shape: {n, _h, _w, _c}} = tensor) do
    len = div(n, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: :n)
    second_half = Nx.slice_along_axis(tensor, len, len, axis: :n)
    {first_half, second_half}
  end

  # Split the rank-2 label tensors in half
  defp split(%Nx.Tensor{shape: {n, _}} = tensor) do
    len = div(n, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    second_half = Nx.slice_along_axis(tensor, len, len, axis: 0)
    {first_half, second_half}
  end
end
Load the data and prepare the train batches and validation dataset.
%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter19.CIFAR10Cnn.load_data()
train_batches = Stream.zip(Nx.to_batched(train_images, 32), Nx.to_batched(train_labels, 32))
validation_data = [{validation_images, validation_labels}]
Build the CNN and train it
epochs = 20

model =
  Axon.input("data", shape: {nil, 32, 32, 3})
  |> Axon.conv(16, kernel_size: 3, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.conv(32, kernel_size: 3, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.flatten()
  |> Axon.dense(1000, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(512, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

Axon.Display.as_table(model, Nx.to_template(validation_images)) |> IO.puts()

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)
Training completed in roughly 8,400 seconds.
Results after 20 epochs:
- accuracy: 0.8096849 - loss: 0.8324687
- validation accuracy: 0.6632000 - validation loss: 1.5468998
Channels :first vs :last

In Chapter19.CIFAR10Cnn, the inputs extracted from the binary are transposed so that the channels axis comes last:
defp transform_images({bin, type, shape}) do
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape({elem(shape, 0), 3, 32, 32}, names: [:n, :c, :h, :w])
  |> Nx.transpose(axes: [:n, :h, :w, :c])
  |> Nx.divide(255.0)
end
The Axon.conv API already expects the channels axis to be the last one, so there is no need to set the option explicitly. It is worth mentioning, though, that keeping the channels as the first axis (by skipping the transposition) is also possible. I tried both approaches, and having the channels axis last led to better accuracy overall:
| Channels position | Training accuracy | Validation accuracy |
| --- | --- | --- |
| :first | 0.7166188 | 0.6012000 |
| :last | 0.8096849 | 0.6632000 |
To conclude, the Keras implementation in the book is equivalent to the Axon implementation with the channels axis as :last.
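For reference, here is a minimal sketch of the channels-first variant I compared against. It assumes the transposition in transform_images is skipped, so the inputs stay {n, 3, 32, 32}; channels: :first then has to be passed to each conv layer, and the batch-norm channel index for the convolutional blocks moves to axis 1:

# Channels-first variant (sketch): inputs keep the {n, c, h, w} layout
model_channels_first =
  Axon.input("data", shape: {nil, 3, 32, 32})
  |> Axon.conv(16, kernel_size: 3, activation: :relu, channels: :first)
  |> Axon.batch_norm(channel_index: 1)
  |> Axon.dropout(rate: 0.5)
  |> Axon.conv(32, kernel_size: 3, activation: :relu, channels: :first)
  |> Axon.batch_norm(channel_index: 1)
  |> Axon.dropout(rate: 0.5)
  |> Axon.flatten()
  |> Axon.dense(1000, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(512, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)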
Keras model summary
The model
model = Sequential()
model.add(Conv2D(16, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(),
              metrics=['accuracy'])
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (50000, 30, 30, 16) 448
batch_normalization (BatchN (50000, 30, 30, 16) 64
ormalization)
dropout (Dropout) (50000, 30, 30, 16) 0
conv2d_1 (Conv2D) (50000, 28, 28, 32) 4640
batch_normalization_1 (Batc (50000, 28, 28, 32) 128
hNormalization)
dropout_1 (Dropout) (50000, 28, 28, 32) 0
flatten (Flatten) (50000, 25088) 0
dense (Dense) (50000, 1000) 25089000
batch_normalization_2 (Batc (50000, 1000) 4000
hNormalization)
dropout_2 (Dropout) (50000, 1000) 0
dense_1 (Dense) (50000, 512) 512512
batch_normalization_3 (Batc (50000, 512) 2048
hNormalization)
dropout_3 (Dropout) (50000, 512) 0
dense_2 (Dense) (50000, 10) 5130
=================================================================
Total params: 25,617,970
Trainable params: 25,614,850
Non-trainable params: 3,120
_________________________________________________________________
Axon model summary
+-----------------------------------------------------------------------------------------------------------------------------------------+
| Model |
+=======================================+======================+====================+=========================+===========================+
| Layer | Input Shape | Output Shape | Options | Parameters |
+=======================================+======================+====================+=========================+===========================+
| data ( input ) | [] | {5000, 32, 32, 3} | shape: {nil, 32, 32, 3} | |
| | | | optional: false | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| conv_0 ( conv["data"] ) | [{5000, 32, 32, 3}] | {5000, 30, 30, 16} | strides: 1 | kernel: f32[3][3][3][16] |
| | | | padding: :valid | bias: f32[16] |
| | | | input_dilation: 1 | |
| | | | kernel_dilation: 1 | |
| | | | feature_group_size: 1 | |
| | | | channels: :last | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_0 ( relu["conv_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_0 ( batch_norm["relu_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | epsilon: 1.0e-5 | gamma: f32[16] |
| | | | channel_index: -1 | beta: f32[16] |
| | | | momentum: 0.1 | mean: f32[16] |
| | | | | var: f32[16] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_0 ( dropout["batch_norm_0"] ) | [{5000, 30, 30, 16}] | {5000, 30, 30, 16} | rate: 0.5 | key: f32[2] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| conv_1 ( conv["dropout_0"] ) | [{5000, 30, 30, 16}] | {5000, 28, 28, 32} | strides: 1 | kernel: f32[3][3][16][32] |
| | | | padding: :valid | bias: f32[32] |
| | | | input_dilation: 1 | |
| | | | kernel_dilation: 1 | |
| | | | feature_group_size: 1 | |
| | | | channels: :last | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_1 ( relu["conv_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_1 ( batch_norm["relu_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | epsilon: 1.0e-5 | gamma: f32[32] |
| | | | channel_index: -1 | beta: f32[32] |
| | | | momentum: 0.1 | mean: f32[32] |
| | | | | var: f32[32] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_1 ( dropout["batch_norm_1"] ) | [{5000, 28, 28, 32}] | {5000, 28, 28, 32} | rate: 0.5 | key: f32[2] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| flatten_0 ( flatten["dropout_1"] ) | [{5000, 28, 28, 32}] | {5000, 25088} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_0 ( dense["flatten_0"] ) | [{5000, 25088}] | {5000, 1000} | | kernel: f32[25088][1000] |
| | | | | bias: f32[1000] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_2 ( relu["dense_0"] ) | [{5000, 1000}] | {5000, 1000} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_2 ( batch_norm["relu_2"] ) | [{5000, 1000}] | {5000, 1000} | epsilon: 1.0e-5 | gamma: f32[1000] |
| | | | channel_index: -1 | beta: f32[1000] |
| | | | momentum: 0.1 | mean: f32[1000] |
| | | | | var: f32[1000] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_2 ( dropout["batch_norm_2"] ) | [{5000, 1000}] | {5000, 1000} | rate: 0.5 | key: f32[2] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_1 ( dense["dropout_2"] ) | [{5000, 1000}] | {5000, 512} | | kernel: f32[1000][512] |
| | | | | bias: f32[512] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| relu_3 ( relu["dense_1"] ) | [{5000, 512}] | {5000, 512} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| batch_norm_3 ( batch_norm["relu_3"] ) | [{5000, 512}] | {5000, 512} | epsilon: 1.0e-5 | gamma: f32[512] |
| | | | channel_index: -1 | beta: f32[512] |
| | | | momentum: 0.1 | mean: f32[512] |
| | | | | var: f32[512] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dropout_3 ( dropout["batch_norm_3"] ) | [{5000, 512}] | {5000, 512} | rate: 0.5 | key: f32[2] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| dense_2 ( dense["dropout_3"] ) | [{5000, 512}] | {5000, 10} | | kernel: f32[512][10] |
| | | | | bias: f32[10] |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
| softmax_0 ( softmax["dense_2"] ) | [{5000, 10}] | {5000, 10} | | |
+---------------------------------------+----------------------+--------------------+-------------------------+---------------------------+
Total Parameters: 25617978
Total Parameters Memory: 102471912 bytes

Note that the Axon total is 8 parameters higher than the Keras one (25,617,970): Axon's summary also counts each dropout layer's PRNG key (key: f32[2], and the model has four dropout layers), which Keras does not report as parameters.