# Image Accessibility
# Notebook dependencies. Entries are kept in their original order so the
# Mix.install cache for this notebook is reused.
Mix.install(
  [
    # Pulls in Bumblebee, used below to load the pretrained ResNet.
    {:kino_bumblebee, "~> 0.5.0"},
    # XLA compiler/backend: default Nx backend (see :config) and loop compiler.
    {:exla, ">= 0.0.0"},
    {:axon, "~> 0.7.0"},
    {:table_rex, "~> 3.1.1"},
    # Provides the SGD optimizer used by the training loop.
    {:polaris, "~> 0.1.0"},
    # OpenCV bindings used to decode and resize the images.
    {:evision, "~> 0.2.9"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
## Data, model, and training
# Run settings shared by the data download, image loading and training cells.
config =
  Map.new(
    epochs: 500,
    batch_size: 1,
    # For local runs use instead: data_root: "/Users/darren/dev/11785-project"
    data_root: "/home/livebook"
  )
defmodule ClassificationLabels do
  @moduledoc """
  The fixed set of image classes.

  A label's position in this list is its class index; the notebook uses that
  index both for the model's `id_to_label` mapping and for training targets.
  """

  @labels ~w(table other text code)

  @doc "Returns the class labels in class-index order."
  def get, do: @labels
end
defmodule ResNet do
  @moduledoc """
  Loads a pretrained ResNet from the Hugging Face Hub, configured for image
  classification over a caller-supplied label set.
  """

  @default_checkpoint "microsoft/resnet-50"

  @doc """
  Loads `checkpoint` (default `"microsoft/resnet-50"`) with a classification
  head sized for `labels`.

  Label `i` of `labels` becomes class id `i` in the spec's `id_to_label`.
  Returns the Bumblebee model info map (this notebook reads its `:model` and
  `:params`). Raises `MatchError` if the spec or the model cannot be loaded.
  """
  def model_for(labels, checkpoint \\ @default_checkpoint) do
    repo = {:hf, checkpoint}

    id_to_label =
      labels
      |> Enum.with_index()
      |> Map.new(fn {label, idx} -> {idx, label} end)

    {:ok, spec} = Bumblebee.load_spec(repo, architecture: :for_image_classification)
    spec = Bumblebee.configure(spec, num_labels: length(labels), id_to_label: id_to_label)

    {:ok, model_info} = Bumblebee.load_model(repo, spec: spec)
    model_info
  end
end
defmodule RawData do
  @moduledoc """
  Makes sure the image dataset is present under `config.data_root`,
  downloading and extracting the published tarball when it is missing.
  """

  @url "https://s3.amazonaws.com/votre-montreal.com/images.tar.gz"

  @doc """
  Ensures `<config.data_root>/images` exists.

  Returns `:exists` when the directory is already there, `:ok` after a
  successful download and extraction, and `:error` when `data_root` cannot be
  created or `wget`/`tar` exits with a non-zero status.
  """
  def ensure_available(config) do
    if File.exists?(Path.join(config.data_root, "images")) do
      :exists
    else
      config |> download() |> untar(config)
    end
  end

  # Fetches the archive into data_root with wget (an equivalent curl call is
  # `curl -o <destination> <url>`). data_root is created first because wget
  # cannot write its output file into a missing directory.
  defp download(config) do
    case File.mkdir_p(config.data_root) do
      :ok -> to_result(System.cmd("wget", [@url, "-O", archive_path(config)]))
      {:error, _reason} -> :error
    end
  end

  defp untar(:error, _config), do: :error

  # Extract into data_root explicitly (-C). Without it tar unpacks into the
  # current working directory, which need not be data_root, and the
  # `images` existence check above would never find the result.
  defp untar(:ok, config) do
    to_result(System.cmd("tar", ["-xzf", archive_path(config), "-C", config.data_root]))
  end

  defp archive_path(config), do: Path.join(config.data_root, "images.tar.gz")

  # Maps a System.cmd/2 `{output, exit_status}` result to :ok / :error.
  defp to_result({_output, 0}), do: :ok
  defp to_result({_output, _status}), do: :error
end
# Download and extract the dataset if needed. Stop here on failure instead of
# erroring later, less clearly, when the image directories are listed.
case RawData.ensure_available(config) do
  :error -> raise "failed to download or extract the image dataset into #{config.data_root}"
  status -> status
end
defmodule ImageLoader do
  @moduledoc """
  Loads the labelled image dataset from disk as normalized Nx tensors and
  splits it into training and validation sets.

  Expected layout: `<base_dir>/<label>/<image file>` for every label in
  `ClassificationLabels.get/0`.
  """

  # Target size, given as {width, height} (the order Evision.resize expects).
  @image_size {224, 224}
  # Widely used ImageNet per-channel statistics, listed in R, G, B order.
  @mean [0.485, 0.456, 0.406]
  @std [0.229, 0.224, 0.225]
  # Extensions (compared lower-cased) skipped when scanning a label directory.
  @skipped_extensions ~w(.svg .avif .gif .ds_store)

  @doc """
  Loads every image file under `base_dir/<label>` for each classification label.

  Returns `{images, labels}`. `images` is a list of `{224, 224, 3}` float
  tensors in RGB channel order, scaled to [0, 1] and then normalized with the
  per-channel mean/std. `labels` holds the matching class index (the label's
  position in `ClassificationLabels.get/0`) for each image.

  Hidden files and `.svg`/`.avif`/`.gif` files (any letter case) are skipped.
  Raises if a label directory cannot be listed.
  """
  def load_images_from_dir(base_dir) do
    ClassificationLabels.get()
    |> Enum.with_index()
    |> Enum.flat_map(fn {label, label_idx} ->
      base_dir
      |> Path.join(label)
      |> load_label_dir()
      |> Enum.map(fn image -> {image, label_idx} end)
    end)
    |> Enum.unzip()
  end

  @doc """
  Shuffles the `{images, labels}` pairs and splits them into
  `{{train_images, train_labels}, {dev_images, dev_labels}}`.

  The training split receives `trunc(count * train_pct)` examples (half by
  default); the remainder goes to the dev split. Image/label pairing is kept.
  """
  def train_val_split({images, labels}, train_pct \\ 0.5) do
    count_to_take = trunc(length(images) * train_pct)

    {train, dev} =
      images
      |> Enum.zip(labels)
      |> Enum.shuffle()
      |> Enum.split(count_to_take)

    {Enum.unzip(train), Enum.unzip(dev)}
  end

  # Loads every accepted image file in one label directory, in File.ls! order.
  defp load_label_dir(label_path) do
    label_path
    |> File.ls!()
    |> Enum.filter(&image_file?/1)
    |> Enum.map(fn file_name -> load_image(Path.join(label_path, file_name)) end)
  end

  # True unless the file is hidden (including .DS_Store) or has a skipped extension.
  defp image_file?(file_name) do
    not String.starts_with?(file_name, ".") and
      String.downcase(Path.extname(file_name)) not in @skipped_extensions
  end

  # Reads one image, resizes it and returns a normalized {height, width, 3} RGB tensor.
  defp load_image(image_path) do
    {width, height} = @image_size

    image_path
    |> Evision.imread()
    |> Evision.resize(@image_size)
    |> Evision.Mat.to_binary()
    |> Nx.from_binary({:u, 8})
    |> Nx.reshape({height, width, 3})
    # OpenCV decodes images in B, G, R order; flip to R, G, B so the channels line
    # up with @mean/@std and with the RGB inputs the pretrained ResNet expects.
    |> Nx.reverse(axes: [2])
    # Scale pixel values from [0, 255] to [0, 1].
    |> Nx.divide(255.0)
    |> normalize_image()
  end

  # Subtracts the per-channel mean and divides by the per-channel std; the
  # {3}-shaped statistics broadcast over the last (channel) axis.
  defp normalize_image(image_tensor) do
    mean_tensor = Nx.tensor(@mean, backend: Nx.BinaryBackend)
    std_tensor = Nx.tensor(@std, backend: Nx.BinaryBackend)

    image_tensor
    |> Nx.subtract(mean_tensor)
    |> Nx.divide(std_tensor)
  end
end
defmodule Trainer do
  @moduledoc """
  Fine-tunes the Bumblebee image classifier with an Axon training loop and
  reports validation accuracy after every epoch.
  """

  alias Axon.Loop

  # Layer left trainable when the backbone is frozen.
  @head_layer "image_classification_head.output"

  @doc """
  Trains `model_info.model` on `train_data`, evaluating on `dev_data` after each epoch.

  `train_data` and `dev_data` are `{images, labels}` tuples of equal-length
  lists: image tensors and integer class indices. `config` must provide
  `:epochs` and `:batch_size`. The optional `:freeze_backbone` (default `true`)
  freezes every parameter except the classification head.

  Returns the result of running the training loop.
  """
  def train_model(config, train_data, dev_data, model_info) do
    # The Bumblebee model outputs a map; the loss and metrics need the bare logits.
    model = Axon.nx(model_info.model, fn %{logits: logits} -> logits end)

    train_batches = create_batches(config, train_data)
    dev_batches = create_batches(config, dev_data)

    optimizer = Polaris.Optimizers.sgd(learning_rate: 1.0e-1)

    # Targets are integer class indices (sparse); predictions are raw logits.
    loss_fn = fn targets, preds ->
      Axon.Losses.categorical_cross_entropy(targets, preds,
        reduction: :sum,
        sparse: true,
        from_logits: true
      )
    end

    initial_state = initial_model_state(config, model_info.params)

    # Called after each epoch: evaluates the current weights on the dev batches.
    run_eval = fn state ->
      model
      |> Loop.evaluator()
      |> Loop.metric(&sparse_accuracy/2, "VAL accuracy")
      |> Loop.run(dev_batches, state.step_state.model_state, compiler: EXLA)
      |> IO.inspect()

      {:continue, state}
    end

    model
    |> Loop.trainer(loss_fn, optimizer)
    |> Loop.handle_event(:epoch_completed, run_eval)
    |> Loop.metric(&sparse_accuracy/2, "TRAIN accuracy")
    |> Loop.run(train_batches, initial_state, epochs: config.epochs, compiler: EXLA)
  end

  # Freezes everything but the classification head, unless config.freeze_backbone
  # is explicitly false. (Previously the frozen state was built but never used.)
  defp initial_model_state(config, params) do
    if Map.get(config, :freeze_backbone, true) do
      Axon.ModelState.freeze(params, fn [layer, _param] -> layer != @head_layer end)
    else
      params
    end
  end

  # Fraction of examples whose highest logit matches the integer target.
  # Axon's built-in :accuracy metric argmaxes y_true as if it were one-hot,
  # which is wrong for the sparse class indices produced by create_batches/2.
  defp sparse_accuracy(y_true, y_pred) do
    y_pred
    |> Nx.argmax(axis: -1)
    |> Nx.equal(Nx.flatten(y_true))
    |> Nx.mean()
  end

  # Zips images with labels into {image_batch, label_batch} tuples of exactly
  # config.batch_size examples; a trailing partial batch is dropped.
  defp create_batches(config, {images, labels}) do
    images
    |> Enum.zip(labels)
    |> Enum.chunk_every(config.batch_size, config.batch_size, :discard)
    |> Enum.map(fn chunk ->
      {batch_images, batch_labels} = Enum.unzip(chunk)
      {Nx.stack(batch_images), Nx.stack(batch_labels)}
    end)
  end
end
# Load the pretrained ResNet with a classification head sized for our labels.
model_info = ResNet.model_for(ClassificationLabels.get())

# To render the model graph, uncomment (template: a batch of 16 224x224x3 f32 images):
# t = Nx.template({16, 224, 224, 3}, :f32)
# Axon.Display.as_graph(model_info.model, t)

# Load every labelled image, then shuffle it into train/dev splits.
{train_data, dev_data} =
  ImageLoader.train_val_split(ImageLoader.load_images_from_dir("#{config.data_root}/images"))

# Print the number of examples in each split (train first, then dev).
{images, labels} = train_data
images |> Enum.count() |> IO.inspect()

{images, labels} = dev_data
images |> Enum.count() |> IO.inspect()

# Fine-tune on the training split, evaluating on the dev split after each epoch.
Trainer.train_model(config, train_data, dev_data, model_info)