# Ch8: Transfer Learning

# Install dependencies: Axon (neural networks), AxonOnnx (ONNX model import),
# Nx (tensors), EXLA (XLA compiler backend), StbImage (JPEG decoding/resizing),
# and Kino/TableRex (Livebook rendering helpers).
Mix.install([
  {:axon_onnx, github: "elixir-nx/axon_onnx"},
  {:axon, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:stb_image, "~> 0.6"},
  {:kino, "~> 0.8"},
  {:table_rex, "~> 3.1.1"}
])

# Run all Nx tensor operations on the EXLA (XLA-compiled) backend by default.
Nx.global_default_backend(EXLA.Backend)

## Setup: Reuse the pipeline from Ch7

There’s a minor alteration from the Ch7 pipeline: we’ll use a channels-first model, so we need to transpose the image matrices.

There are also some minor adjustments to batch size and image dimensions. The latter is because the model will need larger image sizes.

defmodule CatsAndDogs do
  @moduledoc """
  Lazy image-loading pipeline for the cats-vs-dogs dataset.

  Emits a stream of `{image_batch, label_batch}` tensor pairs. Images are
  resized, scaled to [0, 1], and transposed to channels-first layout;
  labels are 0 for cats and 1 for anything else.
  """

  def pipeline(paths, batch_size, target_height, target_width, augment \\ false) do
    paths
    |> Enum.shuffle()
    |> Task.async_stream(&parse_image/1)
    |> Stream.filter(&parsed_ok?/1)
    |> Stream.map(&to_tensors(&1, target_height, target_width))
    |> augment_data(augment)
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(&stack_batch/1)
  end

  # Keeps only successfully parsed images; unreadable files are dropped.
  defp parsed_ok?({:ok, {%StbImage{}, _label}}), do: true
  defp parsed_ok?(_other), do: false

  # Zips a chunk of {image, label} pairs into one stacked batch per side.
  defp stack_batch(pairs) do
    {images, labels} = Enum.unzip(pairs)
    {Nx.stack(images), Nx.stack(labels)}
  end

  # When enabled, randomly flips each example along both spatial axes.
  defp augment_data(stream, augment?) do
    if augment? do
      stream
      |> Stream.map(&random_flip(&1, :height))
      |> Stream.map(&random_flip(&1, :width))
    else
      stream
    end
  end

  # Resizes the decoded image, scales pixels to [0, 1], and moves channels
  # to the leading axis (channels-first) as the pretrained model expects.
  defp to_tensors({:ok, {img, label}}, target_height, target_width) do
    resized = StbImage.resize(img, target_height, target_width)

    img_tensor =
      resized
      |> StbImage.to_nx()
      |> Nx.divide(255)
      |> Nx.transpose(axes: [:channels, :height, :width])

    {img_tensor, Nx.tensor([label])}
  end

  # Reads one JPEG from disk; the label comes from the filename
  # ("cat" -> 0, anything else -> 1). Returns :error when decoding fails.
  defp parse_image(path) do
    name = Path.basename(path, ".jpg")
    label = if name =~ "cat", do: 0, else: 1

    case StbImage.read_file(path) do
      {:ok, img} -> {img, label}
      _error -> :error
    end
  end

  # Flips the image along `axis` with probability 0.5 (data augmentation).
  defp random_flip({image, label} = example, axis) do
    if :rand.uniform() < 0.5 do
      {Nx.reverse(image, axes: [axis]), label}
    else
      example
    end
  end
end
# Root of the local ML workspace (relative to the Livebook working directory).
base_path = "Dev/Education/Elixir/ml/"

# Gather every training JPEG, shuffle, and carve off the first 1000 files
# for test/validation; everything else becomes the training set.
all_paths = Path.wildcard(base_path <> "Datasets/dogs-vs-cats/train/*.jpg")

{test_paths, train_paths} =
  all_paths
  |> Enum.shuffle()
  |> Enum.split(1000)

# Of the held-out 1000: 750 for test, the remaining 250 for validation.
{test_paths, val_paths} = Enum.split(test_paths, 750)

batch_size = 32
target_height = 160
target_width = 160

# All three splits share the batch size and target dimensions; only the
# path list and the augmentation flag differ.
build_pipeline = fn paths, augment? ->
  CatsAndDogs.pipeline(paths, batch_size, target_height, target_width, augment?)
end

# Augmentation is applied to the training split only. You don't want to
# make classification artificially harder on the validation/test data.
train_pipeline = build_pipeline.(train_paths, true)
val_pipeline = build_pipeline.(val_paths, false)
test_pipeline = build_pipeline.(test_paths, false)

# Enum.take(train_pipeline, 1)
# Enum.take(test_pipeline, 1)

## Transfer Learning

### Import the model

# Import the pretrained MobileNetV2 (ONNX opset 7) model. Returns the Axon
# graph plus a map of pretrained parameters keyed by layer name.
{mn_2_7_model, mn_2_7_params} =
  AxonOnnx.import(
    base_path <> "Models/mobilenetv2-7.onnx",
    batch_size: batch_size
  )

# Channels-first f32 template matching the pipeline output ({1, 3, H, W});
# used only to render the model graph below.
input_template = Nx.template({1, 3, target_height, target_width}, :f32)
Axon.Display.as_graph(mn_2_7_model, input_template)

# Pop the final output node off the imported graph.
{_popped, [cnn_base]} =
  mn_2_7_model |> Axon.pop_node()

# Debugging match on pop_node
# match?(%Axon{output: id, nodes: nodes}, cnn_base)
# Map.keys(cnn_base)

# Pop the next node as well, leaving only the convolutional base
# (no flatten/prediction head).
{_popped, [cnn_base]} =
  cnn_base |> Axon.pop_node()

# Drop the parameters that belonged to the two removed output layers so the
# parameter map lines up with the truncated graph.
cnn_params =
  Map.drop(
    mn_2_7_params,
    [
      "mobilenetv20_output_flatten0_reshape0",
      "mobilenetv20_output_pred_fwd"
    ]
  )

# freeze deprecated in favor of ModelState.freeze. This block is an attempt to
# use the newer route. Unclear if what is below works — given no frozen params
# in the struct afterwards, probably not. After experimentation I gave up on
# this route; this block is a remnant.

# NOTE(review): Map.put/3 on the ModelState struct bypasses its own update
# functions — presumably why the freeze below appears to have no effect.
# TODO confirm against Axon.ModelState docs before relying on this path.
mn_2_7_model_state =
  Axon.ModelState.new(%{})
  |> Map.put(:parameters, mn_2_7_params)

model_state =
  mn_2_7_model_state
  |> Axon.ModelState.freeze(fn
    ["mobilenetv20_output" <> _, _] ->
      false

    _ ->
      true
  end)

# Inspect which parameters (if any) were actually frozen — reportedly none.
model_state.frozen_parameters

# The route actually used: freeze every layer of the truncated CNN base so
# only the new classification head trains in the first phase.
cnn_base = cnn_base |> Axon.freeze()

The output shape of the model at this point is {batch_size, 1280, 1, 1}, so you need to flatten the features before passing them to a classification head. You can flatten the features using Axon.flatten/2, or you can use a global pooling layer. Because the number of output features in this model is relatively large, a global pooling layer works better: it reduces the number of input features to the classification head. Additionally, you’ll want to add some regularization by placing a dropout layer between your global average pooling and the classification head.

# Classification head: global average pooling collapses the {batch, 1280, 1, 1}
# feature maps to {batch, 1280}, dropout regularizes, and a single dense unit
# emits one logit per image (binary cat/dog classification).
model =
  cnn_base
  |> Axon.global_avg_pool(channels: :first)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(1)

# Binary cross-entropy on raw logits. `from_logits: true` fuses the sigmoid
# into the loss for numerical stability, so the model itself stays sigmoid-free
# during training. (Fixed: the capture operator had been HTML-escaped as
# `&amp;`, which is invalid Elixir syntax.)
loss =
  &Axon.Losses.binary_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true
  )

optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4)

# Phase 1: train only the new head (the CNN base is frozen), early-stopping
# on validation loss with a patience of 5 epochs.
trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, val_pipeline)
  |> Axon.Loop.early_stop("validation_loss", mode: :min, patience: 5)
  |> Axon.Loop.run(
    train_pipeline,
    cnn_params,
    epochs: 100,
    compiler: EXLA
  )

# For evaluation, append a sigmoid so outputs are probabilities in [0, 1].
eval_model =
  model
  |> Axon.sigmoid()

eval_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_pipeline, trained_model_state, compiler: EXLA)

## Fine-Tuning

# Sanity check: how many parameter groups the imported model has in total.
Map.keys(mn_2_7_params) |> length()

# Unfreeze the top 50 layers (roughly half of the model) so they can adapt to
# the new task; the lower layers stay frozen. (The old comment said "freeze"
# — `Axon.unfreeze(up: 50)` does the opposite.)
model = model |> Axon.unfreeze(up: 50)

# Same logits-based binary cross-entropy as phase 1. (Fixed: `&amp;` HTML
# entities restored to the `&` capture operator.)
loss =
  &Axon.Losses.binary_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true
  )

# A much smaller learning rate for fine-tuning, to avoid destroying the
# pretrained features in the newly unfrozen layers.
optimizer = Polaris.Optimizers.rmsprop(learning_rate: 1.0e-5)

trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, val_pipeline)
  # Fixed: the metric name was split across lines, producing
  # "\n    validation_loss" — a key that never matches the logged
  # "validation_loss" metric, so early stopping would silently never trigger.
  |> Axon.Loop.early_stop("validation_loss", mode: :min, patience: 5)
  |> Axon.Loop.run(
    train_pipeline,
    trained_model_state,
    epochs: 100,
    compiler: EXLA
  )

# Evaluate the fine-tuned model with a sigmoid appended for probabilities.
eval_model =
  model
  |> Axon.sigmoid()

eval_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_pipeline, trained_model_state, compiler: EXLA)