
Image Classification

transfer-learning.livemd

Mix.install(
  [
    {:kino_bumblebee, "~> 0.5.0"},
    {:exla, ">= 0.0.0"},
    {:axon, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"},
    {:polaris, "~> 0.1.0"},
    {:evision, "~> 0.2.9"},
    {:ex_aws, "~> 2.2"},
    {:ex_aws_s3, "~> 2.3"},
    {:hackney, "~> 1.17"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Configuration

Here we set up some familiar configuration values to drive our implementation. The data_root value of /home/livebook is required to map correctly to the default volume of a remote Fly.io machine running Livebook.

config = %{
  epochs: 50,
  batch_size: 8,
  # data_root: "/home/livebook",
  data_root: "/Users/darren/dev/11785-project",
  bucket: "siegel-xapi-dev" # A personal S3 bucket to save the model to
}

The Task At Hand

We will be fine-tuning a pre-trained vision model so that it can differentiate between three classes of images:

  • “table”: images which contain screenshots of tabular information
  • “code”: images which contain screenshots of source code
  • “other”: images which are not in either of the first two classes

We will optimize a categorical cross entropy loss:

$$\mathcal{L}_{\text{CCE}} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{C} y_{i,k} \log(\hat{y}_{i,k})$$

Where $N$ is the number of samples, $C$ is the number of classes, $y_{i,k}$ is the true label for sample $i$ and class $k$, and $\hat{y}_{i,k}$ is the predicted probability for sample $i$ and class $k$.
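As a quick sanity check of this formula (a toy example, separate from the training pipeline below), Axon's built-in loss computes the same quantity when given integer class indices and raw logits:

# Toy example: N = 2 samples, C = 3 classes, integer class targets
targets = Nx.tensor([0, 2])
logits = Nx.tensor([[2.0, 0.1, 0.3], [0.2, 0.4, 1.9]])

# sparse: true accepts class indices instead of one-hot vectors, and
# from_logits: true applies log-softmax internally; reduction: :mean
# matches the 1/N in the formula above
Axon.Losses.categorical_cross_entropy(targets, logits,
  reduction: :mean,
  sparse: true,
  from_logits: true
)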

This model will then be incorporated into Torus and fronted by a course-author-facing feature: when a user uploads an image for use in their course learning material, model inference determines whether the image is likely a screenshot of tabular data or source code. If so, the system warns the author that this poses a web accessibility problem and that they should instead embed the content as source code or as an HTML table.
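As a rough sketch of what that check might look like (Torus.MediaUpload, check_accessibility/3, and the threshold are all invented for illustration; this is not Torus's actual API), leaning on the ClassificationLabels and ImageLoader modules defined below:

# Hypothetical sketch -- module, function, and threshold are invented for
# illustration and are not Torus's actual API
defmodule Torus.MediaUpload do
  @threshold 0.8

  def check_accessibility(image_path, model, params) do
    # Run the fine-tuned classifier on the uploaded image
    logits =
      image_path
      |> ImageLoader.image_from_path()
      |> Nx.new_axis(0)
      |> then(&Axon.predict(model, params, &1))
      |> Map.fetch!(:logits)

    # Softmax over the class dimension, as in the interactive cell below
    exp = Nx.exp(logits)
    probs = Nx.divide(exp, Nx.sum(exp, axes: [-1], keep_axes: true)) |> Nx.to_flat_list()

    {label, p} =
      ClassificationLabels.get()
      |> Enum.zip(probs)
      |> Enum.max_by(fn {_label, p} -> p end)

    if label in ["table", "code"] and p >= @threshold do
      {:warn,
       "This image looks like a screenshot of #{label}. For accessibility, " <>
         "consider embedding the content as an HTML table or as source code instead."}
    else
      :ok
    end
  end
end

The labels themselves are defined next.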

defmodule ClassificationLabels do
  @labels [
    "table",
    "other",
    "code"
  ]
  def get(), do: @labels
end

Here we have a module to fetch the pretrained ResNet model from Hugging Face and reconfigure it for a three-class classification task.

defmodule ResNet do
  def model_for(labels) do
    # Build the index => label map that Bumblebee expects
    id_to_label =
      labels
      |> Enum.with_index()
      |> Map.new(fn {label, idx} -> {idx, label} end)

    num_labels = Enum.count(labels)

    {:ok, spec} =
      Bumblebee.load_spec({:hf, "microsoft/resnet-50"},
        architecture: :for_image_classification
      )

    # Swap the 1000-class ImageNet head for one sized to our labels
    spec = Bumblebee.configure(spec, num_labels: num_labels, id_to_label: id_to_label)
    {:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)
    IO.inspect(model_info)
    model_info
  end
end

Raw Data

I handcrafted an image dataset consisting of around 130 examples for each of the three classes and archived it in an S3 bucket. This module handles downloading and untarring the image dataset.

defmodule RawData do
  def ensure_available(config) do
    if File.exists?("#{config.data_root}/images") do
      :exists
    else
      download(config) |> untar(config)
    end
  end

  defp download(config) do
    url = "https://s3.amazonaws.com/votre-montreal.com/images.tar.gz"
    destination = "#{config.data_root}/images.tar.gz"

    # {_, exit_status} = System.cmd("curl", ["-o", destination, url])

    {_, exit_status} = System.cmd("wget", [url, "-O", destination])

    if exit_status == 0 do
      :ok
    else
      IO.inspect(exit_status)
      :error
    end
  end

  defp untar(:error, _), do: :error

  defp untar(:ok, config) do
    # Extract into data_root so ensure_available/1 can find the images directory
    {_, exit_status} =
      System.cmd("tar", ["-xzf", "#{config.data_root}/images.tar.gz", "-C", config.data_root])

    if exit_status == 0 do
      :ok
    else
      IO.inspect(exit_status)
      :error
    end
  end
end

RawData.ensure_available(config)

Image Loader

The ImageLoader module reads all labeled image files from disk and can split them into train and validation data sets.

The Evision library is used to load and resize the images, but basic Nx commands are used to rescale and normalize the images in the data sets.

defmodule ImageLoader do
  @image_size {224, 224}
  @mean [0.485, 0.456, 0.406]
  @std [0.229, 0.224, 0.225]

  defstruct [
    :images,
    :labels,
    :paths
  ]

  def new(base_dir) do
    {images, labels, paths} = load_images_from_dir(base_dir)

    %ImageLoader{
      images: images,
      labels: labels,
      paths: paths
    }
  end

  def load_images_from_dir(base_dir) do
    ClassificationLabels.get()
    |> Enum.reduce({[], [], []}, fn label, {images_acc, labels_acc, paths_acc} ->
      # Get the path to the subdirectory
      label_path = Path.join(base_dir, label)
      # Get the index of the label
      label_idx = ClassificationLabels.get() |> Enum.find_index(fn l -> l == label end)

      # Find all images in the subdirectory
      {images, paths} =
        label_path
        # List all files in the directory
        |> File.ls!()
        |> Enum.filter(fn f ->
          !String.starts_with?(f, ".") and
            !String.ends_with?(f, ".svg") and
            !String.ends_with?(f, ".DS_Store") and
            !String.ends_with?(f, ".avif") and
            !String.ends_with?(f, ".gif")
        end)
        |> Enum.map(fn image_file ->
          image_path = Path.join(label_path, image_file)

          # Load the image and resize to 224x224
          {image_from_path(image_path), image_path}
        end)
        |> Enum.unzip()

      # Append to the accumulators
      {
        images_acc ++ images, 
        labels_acc ++ List.duplicate(label_idx, length(images)),
        paths_acc ++ paths
      }
    end)
  end

  def image_from_path(path) do
    # Evision (OpenCV) loads images as BGR; convert to RGB so the ImageNet
    # channel statistics used in normalize_image/1 line up
    image =
      path
      |> Evision.imread()
      |> Evision.cvtColor(Evision.Constant.cv_COLOR_BGR2RGB())

    resized_image = Evision.resize(image, @image_size)

    # Convert the image to an Nx tensor
    Nx.from_binary(Evision.Mat.to_binary(resized_image), {:u, 8})
    |> Nx.reshape({elem(@image_size, 0), elem(@image_size, 1), 3})
    # Scale pixel values from [0, 255] to [0, 1]
    |> Nx.divide(255.0)
    |> normalize_image()
  end

  def train_val_split(%ImageLoader{images: images, labels: labels}, train_pct \\ 0.9) do

    count_to_take = (Enum.count(images) * train_pct) |> trunc

    {train, dev} =
      Enum.zip(images, labels)
      |> Enum.shuffle()
      |> Enum.split(count_to_take)

    {Enum.unzip(train), Enum.unzip(dev)}
  end

  # Function to normalize image based on mean and std for each channel
  def normalize_image(image_tensor) do    
    mean_tensor = Nx.tensor(@mean, backend: Nx.BinaryBackend)
    std_tensor = Nx.tensor(@std, backend: Nx.BinaryBackend)

    image_tensor
    |> Nx.subtract(mean_tensor)
    |> Nx.divide(std_tensor)
  end
end

image_loader = ImageLoader.new("#{config.data_root}/images")

render_image = fn path -> 
  content = File.read!(path)
  Kino.Image.new(content, "image/jpeg")
  |> Kino.render()
end

Kino.Text.new("Table Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 0))
Kino.Text.new("Other Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 150))
Kino.Text.new("Source Code Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 270))

:ok

defmodule Trainer do
  alias Axon.Loop

  # Training loop function
  def train_model(config, train_data, dev_data, model_info) do
    # Ensure both images and labels are Nx tensors

    model =
      model_info.model
      |> Axon.nx(fn %{logits: logits} -> logits end)

    # Split data into batches
    train_batches = create_batches(config, train_data)
    dev_batches = create_batches(config, dev_data)

    # The model is already loaded, we just configure the optimizer, loss, and metrics
    optimizer = Polaris.Optimizers.sgd(learning_rate: 1.0e-2)

    loss_fn = fn targets, preds ->
      Axon.Losses.categorical_cross_entropy(
        targets,
        preds,
        reduction: :sum,
        sparse: true,
        from_logits: true
      )
    end

    # Model state with everything but the classification head frozen. The
    # Loop.run below trains all parameters; pass frozen_params instead of
    # model_info.params to train only the head.
    frozen_params =
      Axon.ModelState.freeze(model_info.params, fn [path, _a] ->
        path != "image_classification_head.output"
      end)

    # An eval pass, run after the completion of each epoch
    run_eval = fn state ->
      model_state = state.step_state.model_state

      r =
        Axon.Loop.evaluator(model)
        |> Axon.Loop.metric(:accuracy, "VAL accuracy")
        |> Axon.Loop.run(dev_batches, model_state)

      IO.inspect(r)

      {:continue, state}
    end

    # Create the training loop 
    Loop.trainer(model, loss_fn, optimizer)
    |> Axon.Loop.handle_event(:epoch_completed, run_eval)
    |> Axon.Loop.metric(:accuracy, "TRAIN accuracy")
    |> Loop.run(train_batches, model_info.params, epochs: config.epochs, compiler: EXLA)
  end

  # Function to create batches from images and labels
  defp create_batches(config, {images, labels}) do
    Enum.zip(images, labels)
    |> Enum.chunk_every(config.batch_size, config.batch_size, :discard)
    |> Enum.map(fn chunk ->
      {images, labels} = Enum.unzip(chunk)
      {Nx.stack(images), Nx.stack(labels)}
    end)
  end
end

# Load the pretrained ResNet, reconfigured for our classification labels
model_info = ClassificationLabels.get() |> ResNet.model_for()
t = Nx.template({16, 224, 224, 3}, :f32)
# Axon.Display.as_graph(model_info.model, t)
# Load our image data sets
{train_data, dev_data} = ImageLoader.train_val_split(image_loader)

{train_images, _} = train_data
IO.inspect(Enum.count(train_images), label: "Number of training images")

{dev_images, _} = dev_data
IO.inspect(Enum.count(dev_images), label: "Number of validation images")

:ok

Training Loop

This initiates the training loop.

I ran this numerous times with different learning rates, different numbers of epochs, and with and without freezing the weights of the underlying ResNet.

The best training run (without freezing the params, SGD at 1e-3) yielded a training accuracy of 74% and a validation accuracy of 76%. I stopped it after 11 epochs as I feared it was about to start overfitting. I saved this model and pushed it to S3 with the ModelPersistence module below.
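For reference, the optimizer setting for that best run would look like this; the Trainer module above currently hardcodes 1.0e-2, so reproducing the run means changing that line:

# Optimizer for the best reported run; Trainer above currently uses 1.0e-2
optimizer = Polaris.Optimizers.sgd(learning_rate: 1.0e-3)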

# Train the model
Trainer.train_model(config, train_data, dev_data, model_info)
Epoch: 11, Batch: 0, TRAIN accuracy: 0.74030461932262534 
Batch: 4, VAL accuracy: 0.76150209234934002
# Save the model to disk, but most importantly to an S3 bucket
defmodule ModelPersistence do
  def save(model_info, bucket) do
    params = model_info.params

    serialized_params = Nx.serialize(params, [])
    File.write!("model.axon", serialized_params)

    # Read the serialized bytes back from disk
    read_file = File.read!("model.axon")

    # Fetch the AWS key and secret
    aws_access_key_id = System.fetch_env!("LB_AWS_ACCESS_KEY_ID")
    aws_secret_access_key = System.fetch_env!("LB_AWS_SECRET_ACCESS_KEY")

    # Configure ExAws
    aws_config = %{
      access_key_id: aws_access_key_id,
      secret_access_key: aws_secret_access_key,
      # Change the region to your preference
      region: "us-east-1"
    }

    # Set the configuration for ExAws
    Application.put_env(:ex_aws, :access_key_id, aws_config[:access_key_id])
    Application.put_env(:ex_aws, :secret_access_key, aws_config[:secret_access_key])
    Application.put_env(:ex_aws, :region, aws_config[:region])

    ExAws.S3.put_object(bucket, "model.axon", read_file)
    |> ExAws.request()
  end
end

ModelPersistence.save(model_info, config.bucket)
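For completeness, here is a minimal sketch (not part of the original notebook) of pulling the serialized state back out of S3 for later inference. It assumes the same ExAws configuration set up in ModelPersistence.save/2:

# Fetch the serialized model state from S3 and deserialize it
{:ok, %{body: body}} =
  ExAws.S3.get_object(config.bucket, "model.axon")
  |> ExAws.request()

params = Nx.deserialize(body)
# `params` can now be passed to Axon.predict/4 alongside model_info.model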

Interactive Usage

Here we can easily upload an image and test out our model’s classification.

image_input = Kino.Input.image("Image", size: {224, 224}, format: :jpg)
form = Kino.Control.form([image: image_input], submit: "Run")
frame = Kino.Frame.new()

params = model_info.params
model = model_info.model

Kino.listen(form, fn %{data: %{image: image}} ->
  if image do
    Kino.Frame.render(frame, Kino.Text.new("Running..."))

    batched_image =
      image.file_ref
      |> Kino.Input.file_path()
      |> ImageLoader.image_from_path()
      |> Nx.new_axis(0)

    logits = Axon.predict(model, params, batched_image).logits

    softmax = fn t ->
      exp_tensor = Nx.exp(t)
      sum_exp = Nx.sum(exp_tensor, axes: [-1], keep_axes: true)
      Nx.divide(exp_tensor, sum_exp)
    end

    # Apply softmax to logits
    probabilities = softmax.(logits) |> Nx.to_flat_list()

    items = Enum.zip(probabilities, ClassificationLabels.get())

    IO.inspect(items)

    # Show each label with its predicted probability in the frame
    results =
      items
      |> Enum.map(fn {prob, label} -> "#{label}: #{Float.round(prob, 3)}" end)
      |> Enum.join("\n")

    Kino.Frame.render(frame, Kino.Text.new(results))
  end
end)

Kino.Layout.grid([form, frame], boxed: true, gap: 16)

Results / Commentary

Pros:

  1. Bumblebee is fantastic. It provides an easy (and interactive) way to access, use and explore pre-trained vision, text and other models.

Cons:

  1. There is very little documentation or example code online showing how to use transfer learning to repurpose an existing model like ResNet for a different task. I spent a significant amount of time trying to get the loss function working before finally realizing that sparse: true and from_logits: true have to be specified:
loss_fn = fn targets, preds ->
  Axon.Losses.categorical_cross_entropy(
    targets,
    preds,
    reduction: :sum,
    sparse: true,
    from_logits: true
  )
end
  2. I couldn’t find a way to simply mount my Google Drive, so instead I had to rely on uploading the fine-tuned weights to an S3 bucket. This was only a minor inconvenience.

Takeaways

While this work proved successful and I was indeed able to fine-tune ResNet for a transfer learning task, I cannot imagine doing this again in Elixir. There is such a vastly broader variety of examples and documentation available for PyTorch for this type of task that it feels ridiculous to spend the time figuring things out in Elixir.