Image Classification
Mix.install(
[
{:kino_bumblebee, "~> 0.5.0"},
{:exla, ">= 0.0.0"},
{:axon, "~> 0.7.0"},
{:table_rex, "~> 3.1.1"},
{:polaris, "~> 0.1.0"},
{:evision, "~> 0.2.9"},
{:ex_aws, "~> 2.2"},
{:ex_aws_s3, "~> 2.3"},
{:hackney, "~> 1.17"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Configuration
Here we set up some familiar configuration values to drive our implementation. The data_root directory value of /home/livebook is required to map correctly to the default volume of a remote Fly.io machine running Livebook.
config = %{
epochs: 50,
batch_size: 8,
# data_root: "/home/livebook",
data_root: "/Users/darren/dev/11785-project",
bucket: "siegel-xapi-dev" # A personal S3 bucket to save the model to
}
The Task At Hand
We will fine-tune a pre-trained vision model so that it can differentiate between three classes of images:
- “table”: images which contain screenshots of tabular information
- “code”: images which contain screenshots of source code
- “other”: images which are not in either of the first two classes
We will optimize a categorical cross-entropy loss:
$$\mathcal{L}_{\text{CCE}} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{C} y_{i,k} \log(\hat{y}_{i,k})$$
where $N$ is the number of samples, $C$ is the number of classes, $y_{i,k}$ is the true label for sample $i$ and class $k$, and $\hat{y}_{i,k}$ is the predicted probability for sample $i$ and class $k$.
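As a sanity check on the formula, here is a minimal sketch computing this loss on a toy batch with Axon.Losses; the tensor values are illustrative only.
# A toy batch with N = 2 samples and C = 3 classes.
# Dense (one-hot) targets and already-softmaxed predictions, so the
# defaults (sparse: false, from_logits: false) apply.
targets = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
predictions = Nx.tensor([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
Axon.Losses.categorical_cross_entropy(targets, predictions, reduction: :mean)
# => -(log(0.7) + log(0.6)) / 2, roughly 0.434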
This model will then be incorporated into Torus and fronted by a course-author-facing feature: when a user uploads an image to use in their course learning material, model inference determines whether the image is likely a screenshot of tabular data or source code. If so, the system will warn the author that this poses a web accessibility problem and that they should instead embed the content as source code or as an HTML table.
defmodule ClassificationLabels do
@labels [
"table",
"other",
"code"
]
def get(), do: @labels
end
Here we have a module to fetch the pretrained ResNet model from Hugging Face and reconfigure it for a three-class classification task.
defmodule ResNet do
def model_for(labels) do
id_to_label =
labels
|> Enum.with_index()
|> Map.new(fn {label, idx} -> {idx, label} end)
num_labels = Enum.count(labels)
{:ok, spec} =
Bumblebee.load_spec({:hf, "microsoft/resnet-50"},
architecture: :for_image_classification
)
spec = Bumblebee.configure(spec, num_labels: num_labels, id_to_label: id_to_label)
{:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)
IO.inspect(model_info)
model_info
end
end
Raw Data
I handcrafted an image dataset consisting of around 130 examples for each of the three classes and archived it in an S3 bucket. This module handles downloading and untarring the image dataset.
defmodule RawData do
def ensure_available(config) do
if File.exists?("#{config.data_root}/images") do
:exists
else
download(config) |> untar(config)
end
end
defp download(config) do
url = "https://s3.amazonaws.com/votre-montreal.com/images.tar.gz"
destination = "#{config.data_root}/images.tar.gz"
# {_, exit_status} = System.cmd("curl", ["-o", destination, url])
{_, exit_status} = System.cmd("wget", [url, "-O", destination])
if exit_status == 0 do
:ok
else
IO.inspect(exit_status)
:error
end
end
defp untar(:error, _), do: :error
defp untar(:ok, config) do
{_, exit_status} = System.cmd("tar", ["-xzvf", "#{config.data_root}/images.tar.gz"])
if exit_status == 0 do
:ok
else
IO.inspect(exit_status)
:error
end
end
end
RawData.ensure_available(config)
Image Loader
The ImageLoader module reads all labeled image files from disk and can split them into train and validation data sets. The Evision library is used to load and resize the images, while basic Nx operations are used to rescale and normalize the images in the data sets.
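To make the rescaling and normalization concrete, here is a toy check of what the module below does to a pure-white pixel, using the same ImageNet mean and std constants; this is illustrative only.
# A white pixel: 255 / 255.0 = 1.0 per channel, then (x - mean) / std
white = Nx.tensor([1.0, 1.0, 1.0])
mean = Nx.tensor([0.485, 0.456, 0.406])
std = Nx.tensor([0.229, 0.224, 0.225])
Nx.divide(Nx.subtract(white, mean), std)
# => approximately [2.249, 2.429, 2.640]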
defmodule ImageLoader do
@image_size {224, 224}
@mean [0.485, 0.456, 0.406]
@std [0.229, 0.224, 0.225]
defstruct [
:images,
:labels,
:paths
]
def new(base_dir) do
{images, labels, paths} = load_images_from_dir(base_dir)
%ImageLoader{
images: images,
labels: labels,
paths: paths
}
end
def load_images_from_dir(base_dir) do
ClassificationLabels.get()
|> Enum.reduce({[], [], []}, fn label, {images_acc, labels_acc, paths_acc} ->
# Get the path to the subdirectory
label_path = Path.join(base_dir, label)
# Get the index of the label
label_idx = ClassificationLabels.get() |> Enum.find_index(fn l -> l == label end)
# Find all images in the subdirectory
{images, paths} =
label_path
# List all files in the directory
|> File.ls!()
|> Enum.filter(fn f ->
!String.starts_with?(f, ".") and
!String.ends_with?(f, ".svg") and
!String.ends_with?(f, ".DS_Store") and
!String.ends_with?(f, ".avif") and
!String.ends_with?(f, ".gif")
end)
|> Enum.map(fn image_file ->
image_path = Path.join(label_path, image_file)
# Load the image and resize to 224x224
{image_from_path(image_path), image_path}
end)
|> Enum.unzip()
# Append to the accumulators
{
images_acc ++ images,
labels_acc ++ List.duplicate(label_idx, length(images)),
paths_acc ++ paths
}
end)
end
def image_from_path(path) do
image = Evision.imread(path)
resized_image = Evision.resize(image, @image_size)
# Convert the image to Nx tensor and normalize it
Nx.from_binary(Evision.Mat.to_binary(resized_image), {:u, 8})
|> Nx.reshape({@image_size |> elem(0), @image_size |> elem(1), 3})
# Scale pixel values from [0, 255] to [0, 1]
|> Nx.divide(255.0)
|> normalize_image()
end
def train_val_split(%ImageLoader{images: images, labels: labels}, train_pct \\ 0.9) do
count_to_take = (Enum.count(images) * train_pct) |> trunc
{train, dev} =
Enum.zip(images, labels)
|> Enum.shuffle()
|> Enum.split(count_to_take)
{Enum.unzip(train), Enum.unzip(dev)}
end
# Function to normalize image based on mean and std for each channel
def normalize_image(image_tensor) do
mean_tensor = Nx.tensor(@mean, backend: Nx.BinaryBackend)
std_tensor = Nx.tensor(@std, backend: Nx.BinaryBackend)
image_tensor
|> Nx.subtract(mean_tensor)
|> Nx.divide(std_tensor)
end
end
image_loader = ImageLoader.new("#{config.data_root}/images")
render_image = fn path ->
content = File.read!(path)
Kino.Image.new(content, "image/jpeg")
|> Kino.render()
end
Kino.Text.new("Table Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 0))
Kino.Text.new("Other Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 150))
Kino.Text.new("Source Code Image") |> Kino.render()
render_image.(Enum.at(image_loader.paths, 270))
:ok
defmodule Trainer do
alias Axon.Loop
# Training loop function
def train_model(config, train_data, dev_data, model_info) do
# Ensure both images and labels are Nx tensors
model =
model_info.model
|> Axon.nx(fn %{logits: logits} -> logits end)
# Split data into batches
train_batches = create_batches(config, train_data)
dev_batches = create_batches(config, dev_data)
# The model is already loaded, we just configure the optimizer, loss, and metrics
optimizer = Polaris.Optimizers.sgd(learning_rate: 1.0e-2)
loss_fn = fn targets, preds ->
Axon.Losses.categorical_cross_entropy(
targets,
preds,
reduction: :sum,
sparse: true,
from_logits: true
)
end
# Optionally freeze everything except the classification head; pass
# frozen_params to Loop.run below (in place of model_info.params) to
# train only the head. The best run reported here left all params unfrozen.
frozen_params =
Axon.ModelState.freeze(model_info.params, fn [path, _a] ->
path != "image_classification_head.output"
end)
# An eval execution, will be called after the completion of each epoch
run_eval = fn state ->
model_state = state.step_state.model_state
r =
Axon.Loop.evaluator(model)
|> Axon.Loop.metric(:accuracy, "VAL accuracy")
|> Axon.Loop.run(dev_batches, model_state)
IO.inspect(r)
{:continue, state}
end
# Create the training loop
Loop.trainer(model, loss_fn, optimizer)
|> Loop.handle_event(:epoch_completed, run_eval)
|> Loop.metric(:accuracy, "TRAIN accuracy")
|> Loop.run(train_batches, model_info.params, epochs: config.epochs, compiler: EXLA)
end
# Function to create batches from images and labels
defp create_batches(config, {images, labels}) do
Enum.zip(images, labels)
|> Enum.chunk_every(config.batch_size, config.batch_size, :discard)
|> Enum.map(fn chunk ->
{images, labels} = Enum.unzip(chunk)
{Nx.stack(images), Nx.stack(labels)}
end)
end
end
# Load the pretrained ResNet, but for classification for our labels
model_info = ClassificationLabels.get() |> ResNet.model_for()
# t = Nx.template({16, 224, 224, 3}, :f32)
# Axon.Display.as_graph(model_info.model, t)
# Load our image data sets
{train_data, dev_data} = ImageLoader.train_val_split(image_loader)
{train_images, _train_labels} = train_data
IO.inspect(Enum.count(train_images), label: "Number of training images")
{dev_images, _dev_labels} = dev_data
IO.inspect(Enum.count(dev_images), label: "Number of validation images")
:ok
Training Loop
This initiates the training loop.
I ran this numerous times with different learning rates, different numbers of epochs, and with and without freezing the weights of the underlying ResNet.
The best training run (without freezing the params, SGD at 1e-3) yielded a training accuracy of 74% and a validation accuracy of 76%. I stopped it after 11 epochs as I feared it was about to start overfitting. I saved this model and pushed it to S3 with the ModelPersistence module below.
# Train the model
Trainer.train_model(config, train_data, dev_data, model_info)
Epoch: 11, Batch: 0, TRAIN accuracy: 0.74030461932262534
Batch: 4, VAL accuracy: 0.76150209234934002
# Save the model to disk, but most importantly to an S3 bucket
defmodule ModelPersistence do
def save(model_info, bucket) do
params = model_info.params
serialized_params = Nx.serialize(params, [])
File.write!("model.axon", serialized_params)
# Reload it from the file
read_file = File.read!("model.axon")
# Fetch the AWS key and secret
aws_access_key_id = System.fetch_env!("LB_AWS_ACCESS_KEY_ID")
aws_secret_access_key = System.fetch_env!("LB_AWS_SECRET_ACCESS_KEY")
# Configure ExAws
aws_config = %{
access_key_id: aws_access_key_id,
secret_access_key: aws_secret_access_key,
# Change the region to your preference
region: "us-east-1"
}
# Set the configuration for ExAws
Application.put_env(:ex_aws, :access_key_id, aws_config[:access_key_id])
Application.put_env(:ex_aws, :secret_access_key, aws_config[:secret_access_key])
Application.put_env(:ex_aws, :region, aws_config[:region])
ExAws.S3.put_object(bucket, "model.axon", read_file)
|> ExAws.request()
end
end
ModelPersistence.save(model_info, config.bucket)
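For completeness, here is a minimal sketch of the reverse path: fetching model.axon back from the bucket and deserializing the parameters. It assumes ExAws has already been configured with credentials, as in ModelPersistence.save/2 above.
# Hypothetical loader; assumes the ExAws application env is already set.
load_params = fn bucket ->
{:ok, %{body: body}} =
ExAws.S3.get_object(bucket, "model.axon")
|> ExAws.request()
# Restore the serialized parameters into their original Nx structure
Nx.deserialize(body)
end
# Usage: restored_params = load_params.(config.bucket)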
Interactive Usage
Here we can easily upload an image and test out our model’s classification.
image_input = Kino.Input.image("Image", size: {224, 224}, format: :jpg)
form = Kino.Control.form([image: image_input], submit: "Run")
frame = Kino.Frame.new()
params = model_info.params
model = model_info.model
Kino.listen(form, fn %{data: %{image: image}} ->
if image do
Kino.Frame.render(frame, Kino.Text.new("Running..."))
batched_image =
image.file_ref
|> Kino.Input.file_path()
|> ImageLoader.image_from_path()
|> Nx.new_axis(0)
logits = Axon.predict(model, params, batched_image).logits
softmax = fn t ->
exp_tensor = Nx.exp(t)
sum_exp = Nx.sum(exp_tensor, axes: [-1], keep_axes: true)
Nx.divide(exp_tensor, sum_exp)
end
# Apply softmax to logits
probabilities = softmax.(logits) |> Nx.to_flat_list()
items = Enum.zip(probabilities, ClassificationLabels.get())
IO.inspect(items)
# Render the class probabilities into the frame
Kino.Frame.render(frame, Kino.Text.new(inspect(items)))
end
end)
Kino.Layout.grid([form, frame], boxed: true, gap: 16)
Results / Commentary
Pros:
- Bumblebee is fantastic. It provides an easy (and interactive) way to access, use and explore pre-trained vision, text and other models.
Cons:
- There was scant documentation and few examples online showing how to use transfer learning to repurpose an existing model like ResNet for a different task. I spent a significant amount of time trying to get the loss function working, finally realizing that sparse: true and from_logits: true have to be specified (see the sketch after this list):
loss_fn = fn targets, preds ->
Axon.Losses.categorical_cross_entropy(
targets,
preds,
reduction: :sum,
sparse: true,
from_logits: true
)
end
- I couldn’t find a way to simply mount my Google Drive, so instead I had to rely on uploading the fine-tuned weights to an S3 bucket. This was only a minor inconvenience.
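To make those two options concrete, here is a minimal sketch on toy values, mirroring the batch shapes produced by ImageLoader and Trainer: with sparse: true the targets are integer class indices rather than one-hot vectors, and with from_logits: true the raw model outputs are accepted without a prior softmax.
# Toy batch of 2 samples and 3 classes; values are illustrative only.
targets = Nx.tensor([0, 2])
logits = Nx.tensor([[2.0, 0.1, -1.0], [0.0, 0.5, 3.0]])
Axon.Losses.categorical_cross_entropy(targets, logits,
reduction: :sum,
sparse: true,
from_logits: true
)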
Takeaways
While this work proved to be successful and I was indeed able to fine-tune ResNet for a transfer learning task, I cannot imagine doing this again in Elixir. There is such a massively broader variety of examples and documentation available for PyTorch for this type of task that it feels ridiculous to spend time trying to figure things out in Elixir.