
Image Accessibility

Mix.install(
  [
    {:kino_bumblebee, "~> 0.5.0"},
    {:exla, ">= 0.0.0"},
    {:axon, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"},
    {:polaris, "~> 0.1.0"},
    {:evision, "~> 0.2.9"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Section

config = %{
  epochs: 500,
  batch_size: 1,
  # Root directory where the image archive is downloaded and extracted
  data_root: "/home/livebook"
}
defmodule ClassificationLabels do 
  @labels [
    "table",
    "other",
    "text",
    "code"
  ]
  def get(), do: @labels
end
defmodule ResNet do 
  def model_for(labels) do 
    
    # Map each class index to its label, e.g. %{0 => "table", 1 => "other", ...}
    id_to_label =
      labels
      |> Enum.with_index()
      |> Map.new(fn {label, idx} -> {idx, label} end)

    num_labels = Enum.count(labels)
    
    # Load the ResNet-50 spec for image classification, resize its head to our
    # label set, then load the pretrained weights against that spec
    {:ok, spec} =
      Bumblebee.load_spec({:hf, "microsoft/resnet-50"},
        architecture: :for_image_classification
      )

    spec = Bumblebee.configure(spec, num_labels: num_labels, id_to_label: id_to_label)
    {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)
    
    model_info
  end
end
defmodule RawData do 

  def ensure_available(config) do 
    if File.exists?("#{config.data_root}/images") do 
      :exists
    else
      download(config) |> untar(config)
    end
  end

  defp download(config) do 
    url = "https://s3.amazonaws.com/votre-montreal.com/images.tar.gz"
    destination = "#{config.data_root}/images.tar.gz"

    #{_, exit_status} = System.cmd("curl", ["-o", destination, url])
    {_, exit_status} = System.cmd("wget", [url, "-O", destination])
    
    if exit_status == 0 do
      :ok
    else
      :error
    end
  end

  defp untar(:error, _), do: :error
  defp untar(:ok, config) do
    # Extract next to the archive so that "#{config.data_root}/images" exists afterwards
    {_, exit_status} =
      System.cmd("tar", ["-xzvf", "#{config.data_root}/images.tar.gz", "-C", config.data_root])

    if exit_status == 0 do
      :ok
    else
      :error
    end
  end

end

RawData.ensure_available(config)
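
If the download is in doubt, a small optional check (assuming the archive extracted cleanly) is to list what now sits under the images directory; the entries should match the labels defined above.

# Optional check: the extracted directory should contain one folder per label
"#{config.data_root}/images"
|> File.ls!()
|> Enum.sort()
|> IO.inspect(label: "extracted label directories")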

defmodule ImageLoader do
  @image_size {224, 224}
  @mean [0.485, 0.456, 0.406]
  @std [0.229, 0.224, 0.225]

  def load_images_from_dir(base_dir) do
    ClassificationLabels.get()
    |> Enum.reduce({[], []}, fn label, {images_acc, labels_acc} ->
      # Get the path to the subdirectory
      label_path = Path.join(base_dir, label)
      # Get the index of the label
      label_idx = ClassificationLabels.get() |> Enum.find_index(fn l -> l == label end)

      # Find all images in the subdirectory
      images =
        label_path
        |> File.ls!()  # List all files in the directory
        |> Enum.filter(fn f -> 
          !String.starts_with?(f, ".") and
          !String.ends_with?(f, ".svg") and
          !String.ends_with?(f, ".DS_Store") and 
          !String.ends_with?(f, ".avif") and 
          !String.ends_with?(f, ".gif")
        end)
        |> Enum.map(fn image_file ->
          image_path = Path.join(label_path, image_file)
          
          # OpenCV decodes images as BGR; convert to RGB so the ImageNet
          # statistics below line up, then resize to 224x224
          image = Evision.imread(image_path)
          image = Evision.cvtColor(image, Evision.Constant.cv_COLOR_BGR2RGB())
          resized_image = Evision.resize(image, @image_size)

          # Convert the image to an Nx tensor, scale to [0, 1], and normalize
          Nx.from_binary(Evision.Mat.to_binary(resized_image), {:u, 8})
          |> Nx.reshape({elem(@image_size, 0), elem(@image_size, 1), 3})
          |> Nx.divide(255.0)
          |> normalize_image()
        end)


      # Append to the accumulators
      {images_acc ++ images, labels_acc ++ List.duplicate(label_idx, length(images))}
    end)
  end

  def train_val_split({images, labels}, train_pct \\ 0.5) do
    count_to_take = (Enum.count(images) * train_pct) |> trunc()

    # Shuffle the examples, then split them into train and dev sets
    {train, dev} =
      Enum.zip(images, labels)
      |> Enum.shuffle()
      |> Enum.split(count_to_take)

    {Enum.unzip(train), Enum.unzip(dev)}
  end

  # Function to normalize image based on mean and std for each channel
  defp normalize_image(image_tensor) do
    mean_tensor = Nx.tensor(@mean, backend: Nx.BinaryBackend)
    std_tensor = Nx.tensor(@std, backend: Nx.BinaryBackend)

    image_tensor
    |> Nx.subtract(mean_tensor)
    |> Nx.divide(std_tensor)
  end
  
end
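
The per-channel normalization above relies on Nx broadcasting a `{3}` mean/std tensor across a `{224, 224, 3}` image. A small illustrative sketch on synthetic data only, repeating the private `normalize_image/1` arithmetic inline:

# Illustrative: normalize a synthetic mid-gray image the same way the loader does
fake_image =
  Nx.broadcast(Nx.tensor(128, type: :u8), {224, 224, 3})
  |> Nx.divide(255.0)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))

IO.inspect(Nx.shape(fake_image), label: "normalized image shape")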



defmodule Trainer do
  alias Axon.Loop

  # Training loop function
  def train_model(config, train_data, dev_data, model_info) do
    # The Bumblebee model returns a map of outputs; unwrap the logits for training
    model =
      model_info.model
      |> Axon.nx(fn %{logits: logits} -> logits end)

    # Split data into batches
    train_batches = create_batches(config, train_data)
    dev_batches = create_batches(config, dev_data)

    # The model is already loaded, we just configure the optimizer, loss, and metrics
    optimizer = Polaris.Optimizers.sgd(learning_rate: 1.0e-1)
    
    loss_fn = fn targets, preds ->
      Axon.Losses.categorical_cross_entropy(
        targets, preds, reduction: :sum, sparse: true, from_logits: true)
    end

    # Freeze everything except the new classification head so only the head is trained
    frozen_params =
      Axon.ModelState.freeze(model_info.params, fn [path, _a] ->
        path != "image_classification_head.output"
      end)

    # Evaluation hook, called after each epoch completes
    run_eval = fn state ->
      model_state = state.step_state.model_state

      Loop.evaluator(model)
      |> Loop.metric(:accuracy, "VAL accuracy")
      |> Loop.run(dev_batches, model_state)
      |> IO.inspect()

      {:continue, state}
    end
    
    # Create the training loop and run it from the partially frozen parameters
    Loop.trainer(model, loss_fn, optimizer)
    |> Loop.handle_event(:epoch_completed, run_eval)
    |> Loop.metric(:accuracy, "TRAIN accuracy")
    |> Loop.run(train_batches, frozen_params, epochs: config.epochs, compiler: EXLA)
  end


  # Function to create batches from images and labels
  defp create_batches(config, {images, labels}) do
    Enum.zip(images, labels) 
    |> Enum.chunk_every(config.batch_size, config.batch_size, :discard)
    |> Enum.map(fn chunk ->
      {images, labels} = Enum.unzip(chunk)
      {Nx.stack(images), Nx.stack(labels)}
    end)
  end
end

# Load the pretrained ResNet, reconfigured to classify our labels
model_info = ClassificationLabels.get() |> ResNet.model_for()

# Optionally visualize the model graph:
# t = Nx.template({16, 224, 224, 3}, :f32)
# Axon.Display.as_graph(model_info.model, t)
# Load our image data sets and split them into train/dev
{train_data, dev_data} =
  ImageLoader.load_images_from_dir("#{config.data_root}/images")
  |> ImageLoader.train_val_split()

{train_images, _train_labels} = train_data
IO.inspect(Enum.count(train_images), label: "train examples")

{dev_images, _dev_labels} = dev_data
IO.inspect(Enum.count(dev_images), label: "dev examples")

# Train the model
trained_state = Trainer.train_model(config, train_data, dev_data, model_info)
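
With training done, the returned model state can be used for a quick prediction. This is only a sketch: it assumes `Loop.run/4` hands back the trained `Axon.ModelState` (as `Axon.Loop.trainer/3` is set up to do) and that the Bumblebee ResNet exposes a single `"pixel_values"` input; it reuses the first training image purely as a smoke test.

# Smoke test (illustrative): classify one training image with the fine-tuned weights
image = hd(train_images)

output =
  Axon.predict(model_info.model, trained_state, %{"pixel_values" => Nx.new_axis(image, 0)})

predicted_label =
  output.logits
  |> Nx.argmax(axis: -1)
  |> Nx.to_flat_list()
  |> hd()
  |> then(&Enum.at(ClassificationLabels.get(), &1))

IO.inspect(predicted_label, label: "predicted label")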