Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Training Loop with Flip Augmentation

notebooks/dev/flip_image_training.livemd

Training Loop with Flip Augmentation

# Livebook dependencies: Nx/EXLA for tensors and compilation, Axon for the
# model and training loops, Scidata for FashionMNIST, NxImage/Kino for
# image preview, and the local Longwingex library for data augmentation.
Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.6"},
  {:exla, "~> 0.7"},
  {:req, "~> 0.5"},
  {:scidata, "~> 0.1"},
  {:nx_image, "~> 0.1"},
  {:kino, "~> 0.13"},
  {:kino_vega_lite, "~> 0.1"},
  # Local augmentation library under development — path is machine-specific,
  # so this notebook only runs on the author's machine as-is.
  {:longwingex, path: "/home/ml2/code/elixir/ai/libs/longwingex"}
])

Horizontal Flip

Horizontally flips the image.

Show tensor data

Simple image visualizer for development use.

defmodule Show do
  @moduledoc """
  Minimal image preview helper for development use in Livebook.
  """

  @doc """
  Renders a normalized image tensor (values in 0..1, shape
  `{height, width, channels}`) as a `Kino.Image`, scaled up by
  `multiplier` with nearest-neighbor resizing so pixels stay crisp.
  """
  def show_image(img, multiplier) do
    {height, width, _channels} = Nx.shape(img)
    target_size = {height * multiplier, width * multiplier}

    img
    |> Nx.multiply(255)
    |> Nx.as_type(:u8)
    |> NxImage.resize(target_size, method: :nearest)
    |> Kino.Image.new()
  end
end

Retrieve and Prepare Data

# Download FashionMNIST and unpack the raw binaries plus their Nx types.
{train_images, train_labels} = Scidata.FashionMNIST.download()
{images_binary, images_type, _train_shape} = train_images

# Decode pixels, lay them out channels-last, and normalize to 0..1.
train_data =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)

# One-hot-encode the labels by comparing each label against 0..9.
{labels_binary, labels_type, _shape} = train_labels

label_data =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({10}))

Hyperparameters

# Central knob map for the notebook; nested maps group per-augmentation
# settings. Only `flip` is exercised below.
hyperparams = %{
  epochs: 5,
  # batch_size: 16,
  # NOTE(review): 3 looks like a temporary dev value (16 is commented out
  # just above) — confirm before a full training run.
  batch_size: 3,
  # Seeds for the augmentation PRNG key stream.
  key_init: 372,
  epoch_key_init: 876,
  # FashionMNIST training-set size.
  dataset_size: 60_000,
  # Unused in this notebook — settings for a bounding-box copy augmentation.
  bbox_copy_hyperparams: %{
    box_percent: 0.2,
    max_boxes: 4,
    probability: 0.6
  },
  # Unused in this notebook — settings for a random-crop augmentation.
  crop: %{
    padding: 1,
    probability: 0.8
  },
  flip: %{
    # Axis 2 is the width axis of the {batch, 28, 28, 1} training tensors;
    # flip semantics are defined by Longwingex — confirm probability is
    # applied per image there.
    axis: 2,
    probability: 0.25
  }
}

Model

# Simple MLP classifier: flatten 28x28 grayscale images to 784 features,
# one hidden ReLU layer, softmax over the 10 FashionMNIST classes.
#
# Fix: the input template was {nil, 1, 28, 28} (channels-first), but the
# training data is reshaped to {batch, 28, 28, 1} (channels-last) in the
# data-prep cell. Both layouts flatten to the same 784 features, so the
# learned parameters are identical — the template is corrected to match
# the data actually fed in.
model =
  Axon.input("input", shape: {nil, 28, 28, 1})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

Data Stream

# Split images and labels into mini-batches, then pair them per batch.
batch_img_stream = Nx.to_batched(train_data, hyperparams.batch_size)
batch_label_stream = Nx.to_batched(label_data, hyperparams.batch_size)

batched_stream = Stream.zip([batch_img_stream, batch_label_stream])

# Keep the first {images, labels} batch around for visual comparison
# against the augmented versions later.
orig_img_expect_output_batch = Enum.take(batched_stream, 1)
[{orig_img_batch, _labels}] = orig_img_expect_output_batch

Dev

# Peek at the first zipped batch (development inspection only).
first_row = Enum.take(batched_stream, 1)

Flip Augmented Image Stream

# Attach a PRNG key to each {image, label} batch, randomly horizontal-flip
# images, then strip the key so the stream yields plain batches again.
# NOTE(review): the exact key/batch plumbing is defined by Longwingex —
# confirm its contract there.
augmented_img_stream =
  Longwingex.Augment.add_random_key_to_stream(batch_img_stream, batch_label_stream, 
    hyperparams.dataset_size, hyperparams.batch_size, hyperparams.key_init)
  |> Stream.map(fn(batch) ->
    # Flip along axis 2 with the configured probability (see hyperparams).
    Longwingex.Augment.Vision.horizontal_flip(batch, 
      hyperparams.flip.axis,
      hyperparams.flip.probability
    )
    |> Longwingex.Augment.remove_random_key_from_stream()
  end)
# Materialize the first augmented batch for inspection.
[{aug_img_batch, _}] =
  augmented_img_stream
  |> Enum.take(1)
aug_img_batch
# [{aug_img_tensor, label_tensor}]= first_row
# Dev check: first two augmented images vs. the same two originals.
Show.show_image(Nx.take(aug_img_batch, 0), 8)
Show.show_image(Nx.take(aug_img_batch, 1), 8)
# [first_orig_img_tensor] =
#   batch_img_stream
#   |> Enum.take(1)
Show.show_image(Nx.take(orig_img_batch, 0), 8)
Show.show_image(Nx.take(orig_img_batch, 1), 8)

Training Loop with Augmentation

# Supervised training over the flip-augmented stream; returns the trained
# parameter map used for evaluation below.
trained_model_params_with_aug =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(augmented_img_stream, %{},
    # epochs: hyperparams[:epochs],
    # NOTE(review): epochs pinned to 1 for development; restore
    # hyperparams[:epochs] (5) for a real training run.
    epochs: 1,
    compiler: EXLA
  )

Comparison with the test data leaderboard

Now that we have the trained model parameters from the training effort, we can use them for calculating test data accuracy.

Let’s get the test data.

# Download the held-out FashionMNIST test split.
{test_images, test_labels} = Scidata.FashionMNIST.download_test()
{test_images_binary, test_images_type, _test_images_shape} = test_images

# Decode, reshape, normalize, and batch the test images. The reshape is
# made channels-last ({:auto, 28, 28, 1}) to mirror the training data
# layout; the original used the raw Scidata shape, which differs only in
# axis order and flattens to the same 784 features per image.
test_batched_images =
  test_images_binary
  |> Nx.from_binary(test_images_type)
  |> Nx.reshape({:auto, 28, 28, 1})
  |> Nx.divide(255)
  |> Nx.to_batched(hyperparams[:batch_size])

# One-hot-encode and batch labels.
{test_labels_binary, test_labels_type, _shape} = test_labels

test_batched_labels =
  test_labels_binary
  # Bug fix: the original decoded the test labels with `labels_type`, the
  # type reported for the *training* labels, and discarded the test type.
  # Use the type that belongs to this binary.
  |> Nx.from_binary(test_labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(hyperparams[:batch_size])

Instead of `Axon.predict/4`, we'll use `Axon.Loop.evaluator/1` with an accuracy metric.

ElixirFashionMLChallenge Leaderboard (Accuracy) on 7/30/2023

5 Epochs - 87.4%

20 Epochs - 87.7%

50 Epochs - 87.8%

# Evaluate the augmentation-trained parameters on the held-out test set,
# reporting accuracy for comparison with the leaderboard numbers above.
test_stream = Stream.zip(test_batched_images, test_batched_labels)

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(test_stream, trained_model_params_with_aug, compiler: EXLA)