Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapter 08

chatper08.livemd

Chapter 08

# Install the notebook's dependencies. All come from Hex except `:axon_onnx`,
# which is fetched from the elixir-nx/axon_onnx GitHub repository.
Mix.install([
  {:axon, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:stb_image, "~> 0.6"},
  {:kino, "~> 0.8"},
  {:axon_onnx, github: "elixir-nx/axon_onnx"}
])

# Use EXLA as the default Nx backend for the rest of the notebook.
Nx.global_default_backend(EXLA.Backend)

Section

defmodule CatsAndDogs do
  @moduledoc """
  Input pipelines that turn image file paths into batches of tensors.

  Every pipeline shuffles the paths, decodes the files concurrently, skips
  files that cannot be read, and lazily emits `{images, labels}` tuples.
  `images` stacks `batch_size` channels-first image tensors and `labels`
  stacks the matching one-element label tensors (`0` when the file name
  contains `"cat"`, `1` otherwise).
  """

  @doc """
  Builds a lazy stream of `{images, labels}` batches from `paths`.

  Each image is resized to `target_height` x `target_width`. A trailing group
  of fewer than `batch_size` examples is discarded, so an empty `paths` list
  yields an empty stream.
  """
  @spec pipeline([Path.t()], pos_integer(), pos_integer(), pos_integer()) :: Enumerable.t()
  def pipeline(paths, batch_size, target_height, target_width) do
    paths
    |> load_examples(target_height, target_width)
    |> batch(batch_size)
  end

  @doc """
  Same as `pipeline/4`, but each example is additionally flipped along the
  `:height` axis and along the `:width` axis, each with probability 0.5,
  before batching.
  """
  @spec pipeline_with_augmentations([Path.t()], pos_integer(), pos_integer(), pos_integer()) ::
          Enumerable.t()
  def pipeline_with_augmentations(paths, batch_size, target_height, target_width) do
    paths
    |> load_examples(target_height, target_width)
    |> Stream.map(&random_flip(&1, :height))
    |> Stream.map(&random_flip(&1, :width))
    |> batch(batch_size)
  end

  # Shuffles the paths, decodes them concurrently, and converts every readable
  # image into an `{image_tensor, label_tensor}` pair. Unreadable files are dropped.
  defp load_examples(paths, target_height, target_width) do
    paths
    |> Enum.shuffle()
    |> Task.async_stream(&parse_image/1)
    |> Stream.flat_map(fn
      # The outer `:ok` comes from Task.async_stream, the inner one from parse_image/1.
      {:ok, {:ok, example}} -> [to_tensors(example, target_height, target_width)]
      _unreadable -> []
    end)
  end

  # Groups examples into full batches and stacks images and labels separately.
  defp batch(examples, batch_size) do
    examples
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(fn chunk ->
      {images, labels} = Enum.unzip(chunk)
      {Nx.stack(images), Nx.stack(labels)}
    end)
  end

  # Reverses the image along `axis` half of the time; the label is never touched.
  defp random_flip({image, label}, axis) do
    if :rand.uniform() < 0.5 do
      {Nx.reverse(image, axes: [axis]), label}
    else
      {image, label}
    end
  end

  # Decodes one file. Three channels are requested explicitly so that every
  # decoded image has the same channel count and the batch can be stacked.
  defp parse_image(path) do
    case StbImage.read_file(path, channels: 3) do
      {:ok, img} -> {:ok, {img, label_for(path)}}
      _error -> :error
    end
  end

  # Only the file name decides the label; a directory that happens to contain
  # "cat" must not turn every image into a cat.
  defp label_for(path) do
    if path |> Path.basename() |> String.contains?("cat"), do: 0, else: 1
  end

  # Resizes the image, converts it to a tensor, divides the pixel values by 255,
  # and moves the channel axis to the front.
  defp to_tensors({img, label}, target_height, target_width) do
    img_tensor =
      img
      |> StbImage.resize(target_height, target_width)
      |> StbImage.to_nx()
      |> Nx.divide(255)
      |> Nx.transpose(axes: [:channels, :height, :width])

    {img_tensor, Nx.tensor([label])}
  end
end
# Shuffle every JPEG under train/ and hold out the first 1000 paths;
# the remaining paths are used for training.
{holdout_paths, train_paths} =
  "train/*.jpg"
  |> Path.wildcard()
  |> Enum.shuffle()
  |> Enum.split(1000)

# Of the held-out paths, the first 750 are for testing and the rest for validation.
{test_paths, val_paths} = Enum.split(holdout_paths, 750)

batch_size = 32
target_height = 160
target_width = 160

# Only the training data goes through the augmenting pipeline.
train_pipeline =
  CatsAndDogs.pipeline_with_augmentations(train_paths, batch_size, target_height, target_width)

val_pipeline = CatsAndDogs.pipeline(val_paths, batch_size, target_height, target_width)
test_pipeline = CatsAndDogs.pipeline(test_paths, batch_size, target_height, target_width)

# Pull a single batch from the training pipeline.
Enum.take(train_pipeline, 1)
# Import the ONNX model as an Axon graph together with its parameters.
{cnn_base, cnn_base_params} =
  AxonOnnx.import("mobilenetv2-7.onnx", batch_size: batch_size)

# One template is enough for both graph renderings below.
input_template = Nx.template({1, 3, target_height, target_width}, :f32)
Axon.Display.as_graph(cnn_base, input_template)

# Pop two nodes off the imported graph
# (presumably the original classification head - confirm against the rendered graph).
{_popped, cnn_base} = Axon.pop_node(cnn_base)
{_popped, cnn_base} = Axon.pop_node(cnn_base)

Axon.Display.as_graph(cnn_base, input_template)

cnn_base = Axon.namespace(cnn_base, "feature_extractor")
cnn_base = Axon.freeze(cnn_base)
# Start again from a fresh import and pop the same two nodes off the graph.
{imported_graph, cnn_base_params} = AxonOnnx.import("mobilenetv2-7.onnx")
{_popped, shortened_graph} = Axon.pop_node(imported_graph)
{_popped, cnn_base} = Axon.pop_node(shortened_graph)

# The imported layers live under the "feature_extractor" namespace and are frozen.
feature_extractor =
  cnn_base
  |> Axon.namespace("feature_extractor")
  |> Axon.freeze()

# New head on top of the frozen base: pooling, dropout, and a single-unit dense layer.
model =
  feature_extractor
  |> Axon.global_avg_pool(channels: :first)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(1)

# Binary cross-entropy on raw logits (`from_logits: true`), averaged over the batch.
loss = &Axon.Losses.binary_cross_entropy(&1, &2, reduction: :mean, from_logits: true)

optimizer = Axon.Optimizers.adam(1.0e-4)

# Train with validation after each epoch and early stopping on "validation_loss"
# (`mode: :min`, `patience: 5`). The initial state places the imported
# parameters under the "feature_extractor" key.
trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, val_pipeline)
  |> Axon.Loop.early_stop("validation_loss", mode: :min, patience: 5)
  |> Axon.Loop.run(
    train_pipeline,
    %{"feature_extractor" => cnn_base_params},
    epochs: 100,
    compiler: EXLA
  )
# Append a sigmoid to the model for evaluation.
eval_model = Axon.sigmoid(model)

evaluation_loop =
  eval_model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

# Measure accuracy on the test batches using the trained state.
Axon.Loop.run(evaluation_loop, test_pipeline, trained_model_state, compiler: EXLA)
# Unfreeze part of the previously frozen model (`up: 50`) so it can be fine-tuned.
model = Axon.unfreeze(model, up: 50)

# Binary cross-entropy on raw logits (`from_logits: true`), averaged over the batch.
loss = &Axon.Losses.binary_cross_entropy(&1, &2, reduction: :mean, from_logits: true)

optimizer = Axon.Optimizers.rmsprop(1.0e-5)

# Continue training from the state produced by the previous run, again with
# per-epoch validation and early stopping on "validation_loss".
trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(model, val_pipeline)
  |> Axon.Loop.early_stop("validation_loss", mode: :min, patience: 5)
  |> Axon.Loop.run(
    train_pipeline,
    trained_model_state,
    epochs: 100,
    compiler: EXLA
  )
# Append a sigmoid to the fine-tuned model for evaluation.
eval_model = Axon.sigmoid(model)

# Measure accuracy on the test batches. The option key must be `compiler:`;
# a misspelled key would not select EXLA for this run.
eval_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_pipeline, trained_model_state, compiler: EXLA)