
Ortex ResNet

livebooks/ortex/ortex_resnet.livemd

Mix.install(
  [
    {:exla, "~> 0.8"},
    {:stb_image, "~> 0.6"},
    {:req, "~> 0.5"},
    {:kino, "~> 0.14"},
    {:ortex, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Download models

# ImageNet class index: maps each class id (as a string) to [wordnet_id, label].
# Req decodes the JSON body into a map.
classes =
  "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
  |> Req.get!()
  |> Map.get(:body)

model_path = "/tmp/resnet.onnx"

# Download the ONNX ResNet18 model once and cache it at model_path.
unless File.exists?(model_path) do
  "https://media.githubusercontent.com/media/onnx/models/main/validated/vision/classification/resnet/model/resnet18-v1-7.onnx?download=true"
  |> Req.get!(connect_options: [timeout: 300_000], into: File.stream!(model_path))
end

# Load the ONNX model and wrap it in an Nx.Serving for inference.
model = Ortex.load(model_path)
serving = Nx.Serving.new(Ortex.Serving, model)
img_path = "/tmp/shark.jpg"
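
Instead of calling Nx.Serving.run/2 inline (as done further down), the serving could also run as a supervised, named process. A minimal sketch, where ResNetServing is an arbitrary name chosen for this example:

# Sketch (not part of the main flow): start the serving under Kino's supervisor.
Kino.start_child({Nx.Serving, serving: serving, name: ResNetServing})

# Requests then go through the named process, e.g. with the batch built below:
# Nx.Serving.batched_run(ResNetServing, batch)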

# Fetch a sample image and resize it to the 224x224 input size the model expects.
img_tensor =
  "https://www.collinsdictionary.com/images/full/greatwhiteshark_157273892.jpg"
  |> Req.get!()
  |> then(&StbImage.read_binary!(&1.body))
  |> StbImage.resize(224, 224)
  |> StbImage.to_nx()
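
The img_path binding above isn't used by this flow, which streams the image straight from the URL into memory. A minimal sketch for caching the download on disk instead, reusing img_path (and assuming StbImage.read_file!/1 for decoding the cached file):

# Download the sample image once, then decode it from disk on later runs.
unless File.exists?(img_path) do
  Req.get!(
    "https://www.collinsdictionary.com/images/full/greatwhiteshark_157273892.jpg",
    into: File.stream!(img_path)
  )
end

img_tensor =
  img_path
  |> StbImage.read_file!()
  |> StbImage.resize(224, 224)
  |> StbImage.to_nx()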

Kino.Image.new(img_tensor)

# Preprocess for ResNet: drop the alpha channel if present, scale to [0, 1],
# normalize with the ImageNet mean/std, and reorder from HWC to CHW.
nx_channels = Nx.axis_size(img_tensor, 2)

img_tensor =
  case nx_channels do
    3 -> img_tensor
    4 -> Nx.slice(img_tensor, [0, 0, 0], [224, 224, 3])
  end
  |> Nx.divide(255)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
  |> Nx.transpose(axes: [2, 0, 1])
  |> dbg()
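
Before going through the serving, the preprocessed tensor can be fed to the model directly as a sanity check. A minimal sketch, assuming the ResNet18 graph takes a single {batch, 3, 224, 224} f32 input and returns a single {batch, 1000} logits tensor:

# Ortex.run/2 returns a tuple of output tensors; expect a {1, 1000} shape here.
{logits} = Ortex.run(model, Nx.new_axis(img_tensor, 0))
Nx.shape(logits)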
batch = Nx.Batch.stack([img_tensor])

# Run inference, then map the five highest-scoring class ids to their labels.
serving
|> Nx.Serving.run(batch)
|> Nx.backend_transfer()
|> elem(0)
|> Nx.flatten()
|> Nx.argsort()
|> Nx.reverse()
|> Nx.slice([0], [5])
|> Nx.to_flat_list()
|> Enum.with_index()
|> Enum.map(fn {class_id, rank} -> {rank, Map.get(classes, to_string(class_id))} end)
|> dbg()
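
For a friendlier display, the top-5 list can be rendered as a table. A minimal sketch, assuming the result of the pipeline above is bound to a variable top5 (each entry being {rank, [wordnet_id, label]}):

# Hypothetical: top5 holds the list produced by the pipeline above.
top5
|> Enum.map(fn {rank, [id, label]} -> %{rank: rank + 1, id: id, label: label} end)
|> Kino.DataTable.new()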