Training Loop with Bounding Box Augmentation
Mix.install([
{:nx, "~> 0.7"},
{:axon, "~> 0.6"},
{:exla, "~> 0.7"},
{:req, "~> 0.5"},
{:scidata, "~> 0.1"},
{:nx_image, "~> 0.1"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"},
{:longwingex, path: "/home/ml2/code/elixir/ai/libs/longwingex"}
])
Bounding Box Copy
Copies a random number of bounding boxes of pixels from one location in the image to another. The maximum number of boxes is set by max_boxes (an integer). The size of each box is set by box_percent, a fraction of the image size from 0.0 to 1.0. The probability that the augmentation is applied is set by probability (a float from 0.0 to 1.0).
Show tensor data
Simple image visualizer for development use.
defmodule Show do
  @moduledoc """
  Development-only helper for rendering image tensors inside the notebook.
  """

  @doc """
  Renders `img` as a `Kino.Image`, enlarged by `multiplier`.

  `img` must be a rank-3 tensor. Its first two dimensions are each scaled by
  `multiplier` with nearest-neighbour resizing; the third dimension is
  presumably the channel axis (TODO confirm against callers).

  Values are multiplied by 255 and cast to `:u8` before display, so the
  input is expected to be normalized to the 0..1 range.
  """
  def show_image(img, multiplier) do
    {rows, cols, _channels} = Nx.shape(img)
    target_size = {rows * multiplier, cols * multiplier}

    img
    |> Nx.multiply(255)
    |> Nx.as_type(:u8)
    |> NxImage.resize(target_size, method: :nearest)
    |> Kino.Image.new()
  end
end
Retrieve and Prepare Data
# Download the Fashion-MNIST training split. Images and labels each arrive
# as a {binary, type, shape} tuple.
{train_images, train_labels} = Scidata.FashionMNIST.download()

# The reported shape is ignored on purpose: the images are reshaped
# explicitly below, with the single channel as the last axis.
{trn_data, trn_type, _shape} = train_images

# Decode the raw pixels, lay them out as {batch, 28, 28, 1} and scale the
# byte values into the 0..1 range.
train_data =
  trn_data
  |> Nx.from_binary(trn_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)

# One-hot-encode the labels. Batching is done later, when the data streams
# are built.
{labels_binary, labels_type, _shape} = train_labels

# Comparing an {n, 1} column of class ids against the row 0..9 broadcasts
# to an {n, 10} tensor holding 1 at the position of the matching class.
label_data =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
Hyperparameters
# Settings shared by the data-stream, augmentation, training and evaluation
# cells of this notebook.
hyperparams = %{
  # Number of epochs handed to Axon.Loop.run/4.
  epochs: 5,
  # Number of examples per batch for both the image and label streams.
  batch_size: 16,
  # batch_size: 3,
  # Passed to Longwingex.Augment.add_random_key_to_stream/5; presumably the
  # seed for the random keys - TODO confirm against the library.
  key_init: 243,
  # Also passed to add_random_key_to_stream/5; presumably the number of
  # training examples - TODO confirm.
  dataset_size: 60_000,
  # Arguments for Longwingex.Augment.Vision.copy_bboxes/4.
  bbox_copy_hyperparams: %{
    probability: 0.6,
    max_boxes: 4,
    box_percent: 0.2
  }
}
Model
# Minimal classifier: flatten the image, one hidden ReLU layer of 128 units,
# then a 10-way softmax output.
#
# NOTE(review): the declared input shape is {nil, 1, 28, 28}, while the
# training images are reshaped to {n, 28, 28, 1}. Both flatten to 784
# features, but confirm the declared shape is intended.
input_layer = Axon.input("input", shape: {nil, 1, 28, 28})

model =
  input_layer
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
Data Stream
# Split images and one-hot labels into batches of the configured size.
batch_img_stream = Nx.to_batched(train_data, hyperparams.batch_size)
batch_label_stream = Nx.to_batched(label_data, hyperparams.batch_size)

# Pair every image batch with its label batch as an {images, labels} tuple.
batched_stream = Stream.zip(batch_img_stream, batch_label_stream)

# Peek at the first two batches to sanity-check the pairing.
Enum.take(batched_stream, 2)
Bounding Box Copy Augmented Image Stream
# Build the augmented training stream: combine the image and label streams
# through add_random_key_to_stream/5, run the bounding-box copy on every
# element, then pass the result through remove_random_key_from_stream/1.
# The destructuring further down shows each resulting element is an
# {images, labels} tuple.
augmented_img_stream =
  batch_img_stream
  |> Longwingex.Augment.add_random_key_to_stream(
    batch_label_stream,
    hyperparams.dataset_size,
    hyperparams.batch_size,
    hyperparams.key_init
  )
  |> Stream.map(fn keyed_batch ->
    bbox = hyperparams.bbox_copy_hyperparams

    keyed_batch
    |> Longwingex.Augment.Vision.copy_bboxes(
      bbox.box_percent,
      bbox.max_boxes,
      bbox.probability
    )
    |> Longwingex.Augment.remove_random_key_from_stream()
  end)

# Pull the first augmented batch for visual inspection.
first_row = Enum.take(augmented_img_stream, 1)
[{aug_img_tensor, label_tensor}] = first_row

# First four augmented images, enlarged 8x.
aug_img_tensor |> Nx.take(0) |> Show.show_image(8)
aug_img_tensor |> Nx.take(1) |> Show.show_image(8)
aug_img_tensor |> Nx.take(2) |> Show.show_image(8)
aug_img_tensor |> Nx.take(3) |> Show.show_image(8)

# First un-augmented batch, for comparison with the images above.
[first_orig_img_tensor] = Enum.take(batch_img_stream, 1)

first_orig_img_tensor |> Nx.take(0) |> Show.show_image(8)
first_orig_img_tensor |> Nx.take(1) |> Show.show_image(8)
first_orig_img_tensor |> Nx.take(2) |> Show.show_image(8)
Training Loop with Augmentation
# Train the model on the augmented stream with categorical cross-entropy
# loss and the Adam optimizer, reporting accuracy as a metric. The loop
# starts from an empty initial state (%{}) and is compiled with EXLA.
#
# `hyperparams.epochs` is read assertively so that a missing or misspelled
# key raises a KeyError instead of silently passing `epochs: nil`.
trained_model_params_with_aug =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(augmented_img_stream, %{},
    epochs: hyperparams.epochs,
    compiler: EXLA
  )
Comparison with the test data leaderboard
Now that we have the trained model parameters from the training effort, we can use them for calculating test data accuracy.
Let’s get the test data.
# Download the Fashion-MNIST test split. Images and labels each arrive as a
# {binary, type, shape} tuple.
{test_images, test_labels} = Scidata.FashionMNIST.download_test()
{test_images_binary, test_images_type, test_images_shape} = test_images

# Decode the raw pixels using the shape reported by Scidata, scale the byte
# values into the 0..1 range and split into batches.
test_batched_images =
  test_images_binary
  |> Nx.from_binary(test_images_type)
  |> Nx.reshape(test_images_shape)
  |> Nx.divide(255)
  |> Nx.to_batched(hyperparams.batch_size)

# One-hot-encode and batch labels.
# The binary is decoded with the type reported for the test labels
# themselves, not with a type carried over from the training data.
{test_labels_binary, test_labels_type, _shape} = test_labels

test_batched_labels =
  test_labels_binary
  |> Nx.from_binary(test_labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(hyperparams.batch_size)
Instead of Axon.predict, we’ll use Axon.Loop.evaluator with an accuracy metric.
ElixirFashionMLChallenge Leaderboard (Accuracy) on 7/30/2023
5 Epochs - 87.4%
20 Epochs - 87.7%
50 Epochs - 87.8%
# Evaluate the trained parameters on the test batches, reporting accuracy.
# Each stream element pairs an image batch with its one-hot label batch.
test_stream = Stream.zip(test_batched_images, test_batched_labels)

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(test_stream, trained_model_params_with_aug, compiler: EXLA)