
Segment Anything Model (SAM) using Ortex


Mix.install([
  {:ortex, "~> 0.1.9"},
  {:image, "~> 0.47"},
  {:nx_image, "~> 0.1.2"},
  {:exla, "~> 0.7.2"},
  {:kino, "~> 0.12.3"},
  {:req, "~> 0.4"}
])

Recap

This is an implementation by @herisson of the Segment Anything Model (SAM) from Facebook Research, using the Ortex library to run ONNX models.

It is, approximately, a port of this Jupyter notebook example to Livebook: https://colab.research.google.com/drive/1wmjHHcrZ_s8iFuVFh9iHo6GbUS_xH5xq

It uses the ONNX “mobile” models found here.

# Make EXLA the default Nx backend for faster tensor operations
Nx.global_default_backend(EXLA.Backend)
Nx.default_backend()

Loading the models

encoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vision_encoder_quantized.onnx"

encoder_path = "/tmp/vision_encoder_quantized.onnx"

%{body: body} = Req.get!(encoder_url)
File.write!(encoder_path, body)

model = Ortex.load(encoder_path)

decoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vit_b_decoder.onnx"

decoder_path = "/tmp/vit_b_decoder.onnx"

%{body: body} = Req.get!(decoder_url)
File.write!(decoder_path, body)

decoder = Ortex.load(decoder_path)
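
The model files are fairly large, so when rerunning the notebook it can be handy to skip downloads that are already on disk. A minimal sketch (the download_once helper is hypothetical, not part of the original notebook):

# Hypothetical helper: fetch a file only when it is not already cached.
download_once = fn url, path ->
  unless File.exists?(path) do
    %{body: body} = Req.get!(url)
    File.write!(path, body)
  end

  path
end

# Usage: decoder = Ortex.load(download_once.(decoder_url, decoder_path))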

Getting the image

image_input = Kino.Input.image("Uploaded Image")
# https://ibb.co/Hr8588Y the image I'm using
image =
  Kino.Input.read(image_input)
  |> Image.from_kino!()

# The demo image is already these dimensions
# but other images may not be.
resized =
  Image.thumbnail!(image, "1024x1024")
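
Note that Image.thumbnail!/2 preserves the aspect ratio, so a non-square upload ends up at most 1024 pixels on its longer side rather than exactly 1024x1024, while the encoder below is fed a {1, 3, 1024, 1024} tensor. One option is to pad the thumbnail onto a square canvas; a sketch (it assumes Image.embed!/4 and its :background_color option, so it is left commented out):

# Sketch: pad a non-square thumbnail onto a black 1024x1024 canvas
# so the encoder input shape always matches.
# padded = Image.embed!(resized, 1024, 1024, background_color: :black)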

original_label = Kino.Markdown.new("**Original image**")
resized_label = Kino.Markdown.new("**Resized image**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([image, original_label], boxed: true),
    Kino.Layout.grid([resized, resized_label], boxed: true)
  ],
  columns: 2
)
# Send the image to Nx. This is a zero-copy
# operation: it basically just hands Nx a
# pointer to the image's memory.
tensor =
  resized
  |> Image.to_nx!()
  |> Nx.as_type(:f32)

# Per-band mean and std values copied from Transformers.js
mean = Nx.tensor([123.675, 116.28, 103.53])
std = Nx.tensor([58.395, 57.12, 57.375])

normalized_tensor =
  tensor
  |> NxImage.normalize(mean, std)
  |> Nx.rename([:height, :width, :bands])
  |> Nx.transpose(axes: [:bands, :height, :width])
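
NxImage.normalize/3 computes (pixel - mean) / std for each band. As a quick illustrative check (not part of the original notebook), a pixel exactly equal to the per-band means should normalize to zeros:

# Illustrative check: a 1x1 "image" whose pixel equals the means
# normalizes to [0.0, 0.0, 0.0].
NxImage.normalize(Nx.tensor([[[123.675, 116.28, 103.53]]]), mean, std)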

# Run the image encoder. The expected input shape (and the
# outputs) depend on the ONNX model in use.
{image_embeddings, _} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1, 3, 1024, 1024}))
# Alternative for a model that takes channels-last input:
# {image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
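
It is worth confirming the embedding shape before moving on; the decoder call below assumes {1, 256, 64, 64}:

# Sanity check: the decoder below expects image embeddings
# of shape {1, 256, 64, 64}.
Nx.shape(image_embeddings)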

Prompt encoding & mask generation

# Prepare the prompt inputs.
# Box prompt: xy coordinates of the box around the object we
# want to cut out. Labels 2 and 3 mark the box's start / end points.
# input_point = Nx.tensor([[345, 272], [640, 760]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# input_label = Nx.tensor([2, 3]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)
# Point prompt: a single foreground point (label 1), padded
# with a {0, 0} point labelled -1
input_point = Nx.tensor([[514, 514], [0, 0]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
input_label = Nx.tensor([1, -1]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)
# Filled with zeros: no previous mask is fed back in
mask_input = Nx.broadcast(0, {1, 1, 256, 256}) |> Nx.as_type(:f32)
# Flag telling the decoder that mask_input is unused
has_mask = Nx.broadcast(0, {1}) |> Nx.as_type(:f32)
original_image_dim = Nx.tensor([1024, 1024]) |> Nx.as_type(:f32)
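
The labels follow SAM's prompt convention: 1 marks a foreground point, 0 a background point, 2 and 3 a box's corner points, and -1 pads unused slots. A small helper (hypothetical, not part of the original) makes building padded point prompts explicit:

# Hypothetical helper: build SAM point-prompt tensors from a list of
# {x, y} coordinates, appending the {0, 0} point labelled -1 as padding.
point_prompt = fn points ->
  padded = points ++ [{0, 0}]
  coords = Enum.map(padded, fn {x, y} -> [x, y] end)
  labels = List.duplicate(1, length(points)) ++ [-1]

  {Nx.tensor([coords], type: :f32), Nx.tensor([labels], type: :f32)}
end

# {input_point, input_label} = point_prompt.([{514, 514}])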

{mask, scores, low_res} =
  Ortex.run(decoder, {
    Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
    Nx.broadcast(input_point, {1, 2, 2}),
    Nx.broadcast(input_label, {1, 2}),
    Nx.broadcast(mask_input, {1, 1, 256, 256}),
    Nx.broadcast(has_mask, {1}),
    Nx.broadcast(original_image_dim, {2})
  })
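
The decoder returns several candidate masks plus a quality score for each. The next cell hard-codes mask index 2, but the scores output lets you pick automatically; a sketch, assuming scores has shape {1, num_masks}:

# Sketch: index of the highest-scoring candidate mask.
best_idx = scores |> Nx.flatten() |> Nx.argmax() |> Nx.to_number()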

Transform the output to a black and white image

# Enables arithmetic / comparison operators (like >) on images
use Image.Math

# Take mask index 2 of the first batch entry and reshape it
# into a single-band image
mask =
  mask[0][2]
  |> Nx.reshape({1024, 1024, 1})
  |> Image.from_nx!()

low_res =
  low_res[0][2]
  |> Nx.reshape({256, 256, 1})
  |> Image.from_nx!()

# Threshold the masks to produce black and white.
# This is one NIF call only. Then we take only the
# first band, which we'll use as the alpha band.
mask = Image.if_then_else!(mask > 0, 255, 0)[0]
low_res_mask = Image.if_then_else!(low_res > 0, 255, 0)[0]
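
The same thresholding can also be done on the Nx side before converting to an image: a comparison followed by a multiply. A toy-sized sketch:

# Equivalent Nx-side thresholding on a toy tensor:
# values > 0 become 255, everything else 0.
Nx.tensor([[-1.2, 0.5], [3.0, -0.1]])
|> Nx.greater(0)
|> Nx.multiply(255)
|> Nx.as_type(:u8)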

mask_label = Kino.Markdown.new("**Mask**")
low_res_label = Kino.Markdown.new("**Low Res**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([low_res_mask, low_res_label], boxed: true)
  ],
  columns: 2
)

Showing the masks

As you can see, the high-res mask is distorted / badly placed. Oddly enough, the low-res mask fits correctly.
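
If you hit the same distortion, one workaround (a sketch, not from the original notebook) is to upscale the 256x256 low-res mask by 4x to 1024x1024 and use that as the alpha band instead:

# Workaround sketch: scale the low-res mask up 4x (256 -> 1024)
# and cut the image out with it in place of the high-res mask.
upscaled_mask = Image.resize!(low_res_mask, 4.0)
low_res_cutout = Image.add_alpha!(image, upscaled_mask)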

new_image = Image.add_alpha!(image, mask)
mask_label = Kino.Markdown.new("**Image mask**")
new_image_label = Kino.Markdown.new("**New image**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([image, original_label], boxed: true),
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([new_image, new_image_label], boxed: true)
  ],
  columns: 3
)
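
Finally, the cut-out can be written to disk; PNG keeps the alpha channel (the path below is just an example):

# Save the segmented image; PNG preserves the alpha band.
Image.write!(new_image, "/tmp/segmented.png")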