Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapt 7 notebook - Torchx fork

chapt7-torchx.livemd

Chapt 7 notebook - Torchx fork

# Dependencies for this notebook. This fork runs on Torchx, so the EXLA entry
# (and the unused Benchee / Explorer entries) stay commented out.
notebook_deps = [
  # {:benchee, "~> 1.3"},
  # {:explorer, "~> 0.9.1"},
  {:stb_image, "~> 0.6.9"},
  {:axon, "~> 0.6.1"},
  # {:exla, "~> 0.7.3"},
  {:torchx, "~> 0.7"},
  {:nx, "~> 0.7.3"},
  {:scidata, "~> 0.1.11"},
  {:kino, "~> 0.14.0"},
  {:table_rex, "~> 3.0"}
]

# Make Nx allocate tensors through Torchx on the CUDA device by default.
nx_config = [default_backend: {Torchx.Backend, device: :cuda}]

# Environment variables set for the install; they name the libtorch
# target/version pair Torchx should build against.
libtorch_env = %{"LIBTORCH_TARGET" => "cu121", "LIBTORCH_VERSION" => "2.4.1"}

Mix.install(notebook_deps, config: [nx: nx_config], system_env: libtorch_env)

Identifying Cats and Dogs

defmodule CatsAndDogs do
  @moduledoc """
  Lazy input pipelines that turn cat/dog image files into batches of tensors.

  Every pipeline yields `{images, labels}` tuples. `images` is a stack of
  `batch_size` image tensors and `labels` is a stack of `batch_size`
  one-element label tensors (`0` for cats, `1` for everything else).
  """

  @doc """
  Builds a lazy stream of `{images, labels}` batches from `paths`.

  The paths are shuffled once, when the pipeline is built. Files that cannot
  be decoded are skipped. Each image is resized to
  `target_height` x `target_width` and scaled by dividing by 255. A trailing
  batch with fewer than `batch_size` examples is discarded.
  """
  def pipeline(paths, batch_size, target_height, target_width) do
    paths
    |> load_examples(target_height, target_width)
    |> batch(batch_size)
  end

  @doc """
  Same as `pipeline/4`, but every image is additionally flipped along its
  height and along its width, each independently with probability 0.5.

  The flips are part of the lazy stream, so they are re-drawn every time the
  stream is enumerated.
  """
  def pipeline_with_aug(paths, batch_size, target_height, target_width) do
    paths
    |> load_examples(target_height, target_width)
    |> Stream.map(&random_flip(&1, :height))
    |> Stream.map(&random_flip(&1, :width))
    |> batch(batch_size)
  end

  # Shuffles the paths, decodes them concurrently and emits one
  # `{image_tensor, label_tensor}` pair per successfully decoded file.
  defp load_examples(paths, target_height, target_width) do
    paths
    |> Enum.shuffle()
    |> Task.async_stream(&parse_image/1)
    |> Stream.flat_map(fn
      {:ok, {%StbImage{} = img, label}} -> [to_tensors(img, label, target_height, target_width)]
      # Unreadable files come back as `{:ok, :error}`; drop them.
      _ -> []
    end)
  end

  # Groups examples into full batches and stacks images and labels separately.
  defp batch(examples, batch_size) do
    examples
    |> Stream.chunk_every(batch_size, batch_size, :discard)
    |> Stream.map(fn chunk ->
      {images, labels} = Enum.unzip(chunk)
      {Nx.stack(images), Nx.stack(labels)}
    end)
  end

  # Decodes one file and derives its label from the file name.
  defp parse_image(path) do
    # Look at the file name only: a parent directory that happens to contain
    # "cat" (e.g. ".../education/...") must not turn every image into a cat.
    label = if path |> Path.basename() |> String.contains?("cat"), do: 0, else: 1

    # Force three channels so grayscale or RGBA files still stack with the
    # rest of the batch and match the models' `{..., 3}` input shape.
    case StbImage.read_file(path, channels: 3) do
      {:ok, img} -> {img, label}
      _err -> :error
    end
  end

  # Resizes the decoded image, converts it to a tensor scaled by 1/255 and
  # wraps the label in a one-element tensor.
  defp to_tensors(img, label, target_height, target_width) do
    img_tensor =
      img
      |> StbImage.resize(target_height, target_width)
      |> StbImage.to_nx()
      |> Nx.divide(255)

    {img_tensor, Nx.tensor([label])}
  end

  # Reverses the image along the named `axis` half of the time; the label is
  # passed through untouched.
  defp random_flip({image, label}, axis) do
    if :rand.uniform() < 0.5 do
      {Nx.reverse(image, axes: [axis]), label}
    else
      {image, label}
    end
  end
end

Data Setup

# Training images are looked up relative to the current working directory.
{:ok, cwd} = File.cwd()
train_path = "#{cwd}/files/train/*.jpg"

# Shuffle all matching files once and hold the first 1000 out for testing.
shuffled_paths = Enum.shuffle(Path.wildcard(train_path))
{test_paths, train_paths} = Enum.split(shuffled_paths, 1000)

# Batches of 128 square 96x96 images.
batch_size = 128
target_width = 96
target_height = target_width

train_pipeline =
  CatsAndDogs.pipeline(train_paths, batch_size, target_height, target_width)

test_pipeline =
  CatsAndDogs.pipeline(test_paths, batch_size, target_height, target_width)

# Pull a single batch to check that the pipeline produces tensors.
Enum.take(train_pipeline, 1)

MLP Model

# Baseline multi-layer perceptron: flatten the image, two ReLU hidden layers,
# and a single sigmoid unit for the binary cat/dog decision.
mlp_input = Axon.input("images", shape: {nil, target_height, target_width, 3})

mlp_model =
  mlp_input
  |> Axon.flatten()
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

# Supervised training loop with binary cross-entropy loss, Adam, and accuracy
# reported alongside the loss.
mlp_training_loop =
  mlp_model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)

# Start from an empty model state and train for five epochs.
mlp_trained_model_stat = Axon.Loop.run(mlp_training_loop, train_pipeline, %{}, epochs: 5)

CNN Model

# Pick one concrete training image by substituting the wildcard in the glob.
path = String.replace(train_path, "*", "dog.5")

# Load the image, move the channels axis to the front and add a leading axis
# of size 1 so the tensor is laid out the way Nx.conv/2 is given it below.
img =
  path
  |> StbImage.read_file!()
  |> StbImage.to_nx()
  |> Nx.transpose(axes: [:channels, :height, :width])
  |> Nx.new_axis(0)

# edge detector kernel
edge_rows = [
  [-1, 0, 1],
  [-1, 0, 1],
  [-1, 0, 1]
]

# Repeat the 3x3 filter across a {3, 3, 3, 3} kernel tensor.
kernel =
  edge_rows
  |> Nx.tensor()
  |> Nx.reshape({1, 1, 3, 3})
  |> Nx.broadcast({3, 3, 3, 3})

convolved = Nx.conv(img, kernel)

# Cast back to bytes, drop the leading axis, restore the height/width/channels
# order and render the result in the notebook.
convolved
|> Nx.as_type({:u, 8})
|> Nx.squeeze(axes: [0])
|> Nx.transpose(axes: [:height, :width, :channels])
|> Kino.Image.new()
# Convolutional classifier: three conv + max-pool stages (32, 64, 128 filters)
# followed by a dense head with a single sigmoid output.
#
# The input shape is taken from target_height / target_width, the same values
# the data pipelines resize to, so the model and the data cannot drift apart.
cnn_model =
  "images"
  |> Axon.input(shape: {nil, target_height, target_width, 3})
  |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(64, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(128, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

# A single-example f32 template is enough to render the model as a graph.
template = Nx.template({1, target_height, target_width, 3}, :f32)
Axon.Display.as_graph(cnn_model, template)
# Train the CNN for five epochs with binary cross-entropy and Adam, tracking
# accuracy, starting from an empty model state.
cnn_training_loop =
  cnn_model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)

cnn_trained_model_state = Axon.Loop.run(cnn_training_loop, train_pipeline, %{}, epochs: 5)

# Measure accuracy of the trained state on the held-out test batches.
cnn_evaluation_loop =
  cnn_model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

Axon.Loop.run(cnn_evaluation_loop, test_pipeline, cnn_trained_model_state)

Augmenting Data


# Rebuild the file lists from scratch for the augmentation experiment.
{:ok, cwd} = File.cwd()
train_path = "#{cwd}/files/train/*.jpg"

# Shuffle, hold out the first 1000 files, then divide that hold-out into
# 750 test files and the remaining validation files.
shuffled_paths = Enum.shuffle(Path.wildcard(train_path))
{test_paths, train_paths} = Enum.split(shuffled_paths, 1000)
{test_paths, val_paths} = Enum.split(test_paths, 750)

batch_size = 128
target_width = 96
target_height = target_width

# Only the training stream is augmented; validation and test stay untouched.
train_pipeline =
  CatsAndDogs.pipeline_with_aug(train_paths, batch_size, target_height, target_width)

val_pipeline =
  CatsAndDogs.pipeline(val_paths, batch_size, target_height, target_width)

test_pipeline =
  CatsAndDogs.pipeline(test_paths, batch_size, target_height, target_width)

# Pull a single augmented batch as a sanity check.
Enum.take(train_pipeline, 1)
# Regularised CNN: the same three conv + max-pool stages as before, with batch
# normalisation after the first two convolutions and dropout (rate 0.5) in
# front of the sigmoid output.
#
# The input shape is taken from target_height / target_width, the same values
# the data pipelines resize to, so the model and the data cannot drift apart.
cnn_model =
  "images"
  |> Axon.input(shape: {nil, target_height, target_width, 3})
  |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(64, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(128, kernel_size: {3, 3}, padding: :same, activation: :relu)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(1, activation: :sigmoid)

# Print a layer-by-layer summary for a single-example f32 input.
template = Nx.template({1, target_height, target_width, 3}, :f32)
Axon.Display.as_table(cnn_model, template) |> IO.puts()
# Train for up to 100 epochs. The validation handler is attached before the
# early-stop handler, which watches "validation_loss" and looks for a minimum.
cnn_training_loop =
  cnn_model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.validate(cnn_model, val_pipeline)
  |> Axon.Loop.early_stop("validation_loss", mode: :min)

cnn_trained_model_state = Axon.Loop.run(cnn_training_loop, train_pipeline, %{}, epochs: 100)

# Report accuracy of the trained state on the held-out test batches.
cnn_evaluation_loop =
  cnn_model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

Axon.Loop.run(cnn_evaluation_loop, test_pipeline, cnn_trained_model_state)