Chapter 18 - Hands on: the 10 Epochs Challenge

18_taming/ten_epochs_challenge.livemd

Mix.install(
  [
    {:exla, "~> 0.5"},
    {:nx, "~> 0.5"},
    {:axon, "~> 0.5"},
    {:kino, "~> 0.8.1"},
    {:kino_vega_lite, "~> 0.1.7"},
    {:vega_lite, "~> 0.1.6"},
    {:scidata, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
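Setting config: [nx: [default_backend: EXLA.Backend]] makes EXLA the default backend for every tensor operation, so both data preparation and training run compiled. As a quick, optional sanity check, Nx.default_backend/0 reports the active default:

Nx.default_backend()
# should return something like {EXLA.Backend, []}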

Prepare the data

defmodule Chapter18.MNIST do
  def load_data() do
    {raw_images, raw_labels} = Scidata.MNIST.download()
    {raw_test_images, raw_test_labels} = Scidata.MNIST.download_test()

    train_images = transform_images(raw_images)
    train_labels = transform_labels(raw_labels)
    all_test_images = transform_images(raw_test_images)
    all_test_labels = transform_labels(raw_test_labels)

    {validation_images, test_images} = split(all_test_images)
    {validation_labels, test_labels} = split(all_test_labels)

    %{
      train_images: train_images,
      train_labels: train_labels,
      validation_images: validation_images,
      validation_labels: validation_labels,
      test_images: test_images,
      test_labels: test_labels
    }
  end

  # Each image becomes a flattened row of 784 pixels, scaled to [0, 1].
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), :auto})
    |> Nx.divide(255.0)
  end

  # Comparing each label against 0..9 broadcasts into a one-hot row per label.
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  end

  defp split(tensor) do
    {x, _} = Nx.shape(tensor)
    len = div(x, 2)
    first_half = Nx.slice_along_axis(tensor, 0, len, axis: 0)
    # the second half must start at index len (not len + 1),
    # otherwise the row at index len is silently skipped
    second_half = Nx.slice_along_axis(tensor, len, len, axis: 0)
    {first_half, second_half}
  end
end

%{
  train_images: train_images,
  train_labels: train_labels,
  validation_images: validation_images,
  validation_labels: validation_labels,
  test_images: test_images,
  test_labels: test_labels
} = Chapter18.MNIST.load_data()

train_batches = Stream.zip(Nx.to_batched(train_images, 32), Nx.to_batched(train_labels, 32))
validation_data = [{validation_images, validation_labels}]
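Before training, it is worth confirming the tensor shapes. The values in the comments below assume the standard MNIST split of 60,000 training and 10,000 test images (28 × 28 pixels, flattened to 784 by transform_images/1, with the test set halved into validation and test):

Nx.shape(train_images) |> IO.inspect(label: "train images")
# => {60000, 784}
Nx.shape(train_labels) |> IO.inspect(label: "train labels (one-hot)")
# => {60000, 10}
Nx.shape(validation_images) |> IO.inspect(label: "validation images")
# => {5000, 784}
train_batches |> Enum.count() |> IO.inspect(label: "batches per epoch")
# => 1875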

Build and train the basic model

Initial model in Keras (for reference; the Axon translation follows below)

from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

model = Sequential()
model.add(Dense(1200, activation='sigmoid'))
model.add(Dense(500, activation='sigmoid'))
model.add(Dense(200, activation='sigmoid'))
model.add(Dense(10, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.1),  # lr was renamed to learning_rate in newer Keras
              metrics=['accuracy'])

history = model.fit(X_train, Y_train,
                    validation_data=(X_validation, Y_validation),
                    epochs=10, batch_size=32)
epochs = 10

model =
  Axon.input("data")
  |> Axon.dense(1200, activation: :sigmoid)
  |> Axon.dense(500, activation: :sigmoid)
  |> Axon.dense(200, activation: :sigmoid)
  |> Axon.dense(10, activation: :softmax)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.sgd(0.1))
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

👆 With the “basic” model:

  • accuracy: 0.9384413; loss: 0.3782525
  • val. accuracy: 0.9130000; val. loss: 0.2655237

It took ~450 seconds to train this model for 10 epochs.
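Axon.Loop.run/4 returns the trained parameters, so the held-out test half can be scored with an evaluation loop. The snippet below is a sketch, assuming the result of the training run above is bound to a variable named params (a name introduced here for illustration, not part of the original notebook):

# assumes: params = model |> Axon.Loop.trainer(...) |> ... |> Axon.Loop.run(...)
test_batches = Stream.zip(Nx.to_batched(test_images, 32), Nx.to_batched(test_labels, 32))

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_batches, params, compiler: EXLA)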

Build and train the optimized model

epochs = 5

model =
  Axon.input("data")
  |> Axon.dense(1200)
  |> Axon.leaky_relu(alpha: 0.2)
  |> Axon.batch_norm()
  |> Axon.dense(500)
  |> Axon.leaky_relu(alpha: 0.2)
  |> Axon.batch_norm()
  |> Axon.dense(200)
  |> Axon.leaky_relu(alpha: 0.2)
  |> Axon.batch_norm()
  |> Axon.dense(10, activation: :softmax)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam())
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.early_stop("validation_accuracy")
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)

👆 With the “optimized” model we get better results after only 5 epochs:

  • accuracy: 0.9847380; loss: 0.1109478
  • val. accuracy: 0.9386000; val. loss: 0.2847975
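
To eyeball individual predictions from the optimized network, Axon.predict/3 runs a forward pass with the trained parameters. Again a sketch, assuming the optimized run's result is bound to params (an illustrative name, as above):

# assumes: params holds the result of the optimized training run above
# take the first test image as a batch of one
image = Nx.slice_along_axis(test_images, 0, 1, axis: 0)

model
|> Axon.predict(params, image)
|> Nx.argmax(axis: 1)
# a rank-1 tensor holding the predicted digit for the batch of one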