Training Loop with Flip Augmentation
Mix.install([
{:nx, "~> 0.7"},
{:axon, "~> 0.6"},
{:exla, "~> 0.7"},
{:req, "~> 0.5"},
{:scidata, "~> 0.1"},
{:nx_image, "~> 0.1"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"},
{:longwingex, path: "/home/ml2/code/elixir/ai/libs/longwingex"}
])
Horizontal Flip
Horizontally flips the image.
Show tensor data
Simple image visualizer for development use.
defmodule Show do
  @moduledoc """
  Minimal image viewer for inspecting tensors while developing this notebook.
  """

  @doc """
  Renders one image tensor as a `Kino.Image`, enlarged for visibility.

  `img` must be rank 3, `{height, width, channels}` (the shape match below
  fails otherwise); channels-last is assumed, as `NxImage.resize/3` expects by
  default. Values are multiplied by 255 and cast to `:u8`, so they are
  presumably in `[0, 1]`. `multiplier` scales both height and width, using
  nearest-neighbour resizing so individual pixels stay crisp.
  """
  def show_image(img, multiplier) do
    {height, width, _channels} = Nx.shape(img)
    target_size = {height * multiplier, width * multiplier}

    pixels =
      img
      |> Nx.multiply(255)
      |> Nx.as_type(:u8)

    pixels
    |> NxImage.resize(target_size, method: :nearest)
    |> Kino.Image.new()
  end
end
Retrieve and Prepare Data
{train_images, train_labels} = Scidata.FashionMNIST.download()

# Raw image bytes, their Nx type, and the shape tuple reported by Scidata.
{trn_data, trn_type, shape} = train_images

# Decode into a channels-last tensor {n, 28, 28, 1} and divide by 255
# (presumably turning u8 pixel values into floats in [0, 1]).
train_data =
  Nx.divide(
    Nx.reshape(Nx.from_binary(trn_data, trn_type), {:auto, 28, 28, 1}),
    255
  )
# One-hot-encode the training labels into a {n, 10} tensor.
# Batching is not done here; it happens later when the data stream is built.
{labels_binary, labels_type, _shape} = train_labels

label_data =
  labels_binary
  |> Nx.from_binary(labels_type)
  # {n} -> {n, 1} so the comparison against the 10 class ids broadcasts to {n, 10}.
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({10}))
Hyperparameters
# Experiment configuration shared by the cells below.
hyperparams = %{
  # Training schedule. batch_size is a small development value
  # (16 was used previously).
  epochs: 5,
  batch_size: 3,
  dataset_size: 60_000,
  # Integer seeds; key_init is passed to the augmentation stream below.
  # epoch_key_init is not referenced in this notebook.
  key_init: 372,
  epoch_key_init: 876,
  # Horizontal flip. Batches are {batch, 28, 28, 1} after reshaping and
  # batching, so axis 2 is the image width.
  flip: %{probability: 0.25, axis: 2},
  # Crop and box-copy settings; not referenced in this notebook.
  crop: %{probability: 0.8, padding: 1},
  bbox_copy_hyperparams: %{probability: 0.6, max_boxes: 4, box_percent: 0.2}
}
Model
# Simple MLP classifier: flatten -> 128 ReLU units -> 10-way softmax.
# The input layout must match the data fed to it. train_data is reshaped to
# channels-last {n, 28, 28, 1} and batched along axis 0, and the flip
# augmentation relies on axis 2 being the width. So the input is
# {batch, 28, 28, 1}, not channels-first {batch, 1, 28, 28}.
model =
  Axon.input("input", shape: {nil, 28, 28, 1})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
Data Stream
# Split images and one-hot labels into aligned batches of hyperparams.batch_size.
batch_img_stream = Nx.to_batched(train_data, hyperparams.batch_size)
batch_label_stream = Nx.to_batched(label_data, hyperparams.batch_size)

# Pair each image batch with its label batch as {images, labels}.
batched_stream = Stream.zip(batch_img_stream, batch_label_stream)

# Keep the first un-augmented batch for the visual comparison further down.
[{orig_img_batch, _}] = orig_img_expect_output_batch = Enum.take(batched_stream, 1)
Dev
# Development helper: a one-element list holding the first {images, labels} pair.
first_row = Enum.take(batched_stream, 1)
Flip Augmented Image Stream
# Lazily build the flip-augmented training stream.
# NOTE(review): the Longwingex helpers are opaque from here. Presumably
# add_random_key_to_stream pairs the image and label batches and attaches a
# per-batch random key derived from key_init, which horizontal_flip consumes
# before the key is stripped again. Confirm against the Longwingex source.
# Elements are later matched as {images, labels} tuples, as the Axon loop expects.
augmented_img_stream =
  batch_img_stream
  |> Longwingex.Augment.add_random_key_to_stream(
    batch_label_stream,
    hyperparams.dataset_size,
    hyperparams.batch_size,
    hyperparams.key_init
  )
  |> Stream.map(fn keyed_batch ->
    keyed_batch
    |> Longwingex.Augment.Vision.horizontal_flip(
      hyperparams.flip.axis,
      hyperparams.flip.probability
    )
    |> Longwingex.Augment.remove_random_key_from_stream()
  end)
# Pull the first augmented batch (images only) for a visual spot check.
[{aug_img_batch, _}] = Enum.take(augmented_img_stream, 1)
aug_img_batch

# First two augmented images, enlarged 8x.
aug_img_batch |> Nx.take(0) |> Show.show_image(8)
aug_img_batch |> Nx.take(1) |> Show.show_image(8)

# The same two positions from the un-augmented batch, for comparison.
orig_img_batch |> Nx.take(0) |> Show.show_image(8)
orig_img_batch |> Nx.take(1) |> Show.show_image(8)
Training Loop with Augmentation
# Train the model on the flip-augmented stream of {images, one_hot_labels}
# batches, tracking accuracy. The final trained state returned by the loop is
# used below for test-set evaluation.
trained_model_params_with_aug =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(augmented_img_stream, %{},
    # Use the configured epoch count rather than a hard-coded value, so the
    # result is comparable with the leaderboard, which reports accuracy per
    # number of epochs.
    epochs: hyperparams.epochs,
    compiler: EXLA
  )
Comparison with the test data leaderboard
Now that training has produced the model parameters, we can use them to compute accuracy on the test data.
Let’s get the test data.
{test_images, test_labels} = Scidata.FashionMNIST.download_test()

# Decode the test images exactly like the training images: channels-last
# {n, 28, 28, 1}, divided by 255, then batched. This matches the model's
# {batch, 28, 28, 1} layout instead of reusing the shape tuple reported by
# Scidata.
{test_images_binary, test_images_type, _test_images_shape} = test_images

test_batched_images =
  test_images_binary
  |> Nx.from_binary(test_images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
  |> Nx.to_batched(hyperparams.batch_size)

# One-hot-encode and batch the test labels, decoding them with the test
# labels' own binary type (not the training labels' type).
{test_labels_binary, test_labels_type, _shape} = test_labels

test_batched_labels =
  test_labels_binary
  |> Nx.from_binary(test_labels_type)
  # {n} -> {n, 1} so the comparison against the 10 class ids broadcasts to {n, 10}.
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(hyperparams.batch_size)
Instead of `Axon.predict`, we’ll use `Axon.Loop.evaluator` with an accuracy metric.
ElixirFashionMLChallenge Leaderboard (Accuracy) on 7/30/2023
5 Epochs - 87.4%
20 Epochs - 87.7%
50 Epochs - 87.8%
# Score the parameters trained with augmentation on the held-out test batches,
# pairing each image batch with its label batch.
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(
  Stream.zip([test_batched_images, test_batched_labels]),
  trained_model_params_with_aug,
  compiler: EXLA
)