Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us
Notesclub

morelli

03_ai_images/nbs/train.livemd

morelli

# Notebook dependencies: Bumblebee (pretrained HF models), Axon (training),
# Nx + EXLA (tensors / compiled backends), Explorer (dataframes), Image (decoding).
Mix.install([
  {:bumblebee, "~> 0.3.0"},
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.1"},
  {:exla, "~> 0.5.1"},
  {:explorer, "~> 0.5.0"},
  {:image, "~> 0.28.0"}
])

Section

# defmodule DataLoader do
#   @cpus System.schedulers_online()

#   def data(%Dataset{batch_size: batch_size} = loader) do
#     loader.batches
#     |> Task.async_stream(
#       fn batch ->
#         {inp, labels} =
#           batch
#           |> Enum.map(fn img ->
#             load_image(img, loader)
#           end)
#           |> Enum.unzip()

#         {inp, labels}
#       end,
#       timeout: :infinity,
#       ordered: not loader.shuffle,
#       max_concurrency: @cpus - 1
#     )
#     |> Stream.map(fn {:ok, {inp, labels}} ->
#       inp =
#         inp
#         |> Enum.map(&image_to_nx/1)
#         |> Nx.stack()
#         |> Nx.backend_transfer({EXLA.Backend, client: :cuda})

#       labels =
#         labels
#         |> one_hot(loader.class_count)
#         |> Nx.backend_transfer({EXLA.Backend, client: :cuda})

#       {inp, labels}
#     end)
#     |> Stream.cycle()
#   end

#   defp load_image({img_path, class}, loader) do
#     {:ok, {img, _}} = Vix.Vips.Operation.jpegload(Path.join(loader.dir, img_path))

#     img = Preprocess.apply_transforms(img, loader.transforms)

#     {img, class}
#   end

#   defp image_to_nx(img) do
#     img
#     |> Image.write_to_binary()
#     |> elem(1)
#     |> Nx.from_binary({:f, 32})
#     |> Nx.reshape({224, 224, 3})
#     |> Nx.transpose(axes: [2, 0, 1])
#   end

#   def one_hot(value, count) do
#     Nx.tensor(value)
#     |> Nx.new_axis(-1)
#     |> Nx.equal(Nx.tensor(Enum.to_list(0..(count - 1))))
#   end
# end
alias Explorer.DataFrame, as: DF
alias Explorer.Series
# Hugging Face id of the pretrained DINO ViT-S/8 backbone.
id = "facebook/dino-vits8"

# Load the spec with an image-classification head, then reconfigure it
# for two output labels (real vs. generated).
{:ok, spec} =
  Bumblebee.load_spec({:hf, id},
    architecture: :for_image_classification
  )

spec = Bumblebee.configure(spec, num_labels: 2)

# Model weights and the matching image featurizer (resize/normalize).
{:ok, model} = Bumblebee.load_model({:hf, id}, spec: spec)
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, id})
defmodule Data do
  @moduledoc """
  Builds a labeled dataframe of image paths and streams featurized
  batches for training.

  Label `1` = the image comes from one of the known real-photo datasets,
  label `0` = generated image.
  """

  # Extensions accepted as images. NOTE(review): ".svg" is a vector format —
  # confirm the image loader can actually decode it (original had "svg??").
  @image_extensions [".jpg", ".jpeg", ".png", ".svg"]

  @doc "Recursively collects image file paths under `path`."
  def get_files(path) do
    if File.dir?(path) do
      path
      |> File.ls!()
      |> Enum.map(&Path.join(path, &1))
      |> Enum.flat_map(&get_files/1)
    else
      if String.ends_with?(path, @image_extensions), do: [path], else: []
    end
  end

  @doc """
  Returns 1 when the dataset directory of `x` — the third `/`-separated
  segment, e.g. `data/artifact-dataset/<dataset>/...` — is in
  `real_images_map`, otherwise 0.
  """
  def label_fn(real_images_map, x) do
    dataset = x |> String.split("/") |> Enum.at(2)
    if MapSet.member?(real_images_map, dataset), do: 1, else: 0
  end

  @doc "Builds a shuffled DataFrame with `paths` and binary `target` columns."
  def build_df(path, real_datasets) do
    real_images = MapSet.new(real_datasets)
    paths = get_files(path)

    df = Explorer.DataFrame.new(%{paths: Series.from_list(paths)})
    target = Enum.map(paths, &label_fn(real_images, &1))

    df
    |> DF.put("target", target)
    |> DF.shuffle()
  end

  # NOTE(review): `Image.new!/1` is given a file path here; the `image`
  # library's path-opening function is `Image.open!/2` — confirm this
  # actually loads from disk under image ~> 0.28.
  @doc "Loads one image from disk as an Nx tensor."
  def load_image(path) do
    path
    |> Image.new!()
    |> Image.to_nx!()
  end

  @doc """
  Streams `{featurized_inputs, targets}` batches of size `batch_size`.

  `path_df` and the `:force` option are currently unused — the dataframe
  is always rebuilt from disk (the CSV-caching branch is left commented
  out below). They are kept to preserve the call signature.
  """
  def load(path_data, _path_df, featurizer, batch_size, real_datasets, opts \\ []) do
    _force = Keyword.get(opts, :force, false)

    # df =
    #   if force || File.exists?(path_df) do
    #     DF.from_csv!(path_df)
    #   else
    #     build_df(path_data, real_datasets)
    #   end

    df = build_df(path_data, real_datasets)

    df["paths"]
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(df["target"]))
    |> Stream.chunk_every(batch_size)
    |> Task.async_stream(fn batch ->
      {paths, targets} = Enum.unzip(batch)

      # Decode the batch's images concurrently.
      loaded =
        paths
        |> Enum.map(fn p -> Task.async(fn -> load_image(p) end) end)
        |> Enum.map(&Task.await/1)

      # Can put featurizer in tasks if that could help, might need too much mem
      {Bumblebee.apply_featurizer(featurizer, loaded), Nx.stack(targets)}
    end)
    # FIX: Task.async_stream emits {:ok, result} tuples; unwrap them so
    # consumers receive plain {inputs, targets} batches (the commented-out
    # DataLoader reference implementation did the same).
    |> Stream.map(fn {:ok, batch} -> batch end)
  end
end
# Dataset directories containing real photographs; any other directory is
# treated as generated imagery (label 0) by Data.label_fn/2.
real_dataset = ["afhq", "celebahq", "coco", "ffhq", "imagenet", "landscape", "lsun", "metfaces"]
batch_size = 16

data = Data.load("data/artifact-dataset", "train.csv", featurizer, batch_size, real_dataset)

# NOTE(review): train/test split not wired up yet. Both commented lines
# bind `train_data` (the second likely meant `test_data`), and build_df
# never writes those CSVs — verify before enabling.
# train_data = Data.load("data/artifact-dataset", "df_train.csv", featurizer, batch_size, real_dataset)
# train_data = Data.load("data/artifact-dataset", "df_test.csv", featurizer, batch_size, real_dataset)
# Pull a single batch to sanity-check the pipeline end to end.
Enum.take(data, 1)
# TAKE ONE IMAGE AND VISUALIZE
# Unpack the Bumblebee result: `model` becomes the bare Axon graph,
# `params` the pretrained weights.
%{model: model, params: params} = model

model
# FIX: was `Enum.take(train_data, 1)`, but `train_data` is never bound
# (its definitions above are commented out) — use the `data` stream.
[{input, _}] = Enum.take(data, 1)
Axon.get_output_shape(model, input)
# Wrap the model so its output is the raw logits tensor.
logits_model = Axon.nx(model, & &1.logits)
# Binary cross-entropy computed directly on logits, averaged over the batch.
loss =
  &Axon.Losses.binary_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true
  )

optimizer = Axon.Optimizers.adam(5.0e-5)

# NOTE(review): the `loop` built in the next four lines is dead code —
# the pipeline starting at `trained_model_state` below rebuilds the same
# trainer/metric/checkpoint chain from scratch and never uses this binding.
loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true)

loop = Axon.Loop.metric(loop, accuracy, "accuracy")
loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
# Train for one epoch: BCE-on-logits loss, accuracy metric, a checkpoint
# after each epoch, compiled with EXLA. Starts from the pretrained `params`.
trained_model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  # Test strict? true/false
  # NOTE(review): strict?: false lets the run proceed even if the initial
  # params don't exactly match the Axon graph — confirm this is intended
  # rather than masking a key mismatch.
  |> Axon.Loop.run(data, params, epochs: 1, compiler: EXLA, strict?: false)
# Did not split into train/test yet
# logits_model
# |> Axon.Loop.evaluator()
# |> Axon.Loop.metric(accuracy, "accuracy")
# |> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)