Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Untitled notebook

notebooks/vgg16_detection.livemd

Untitled notebook

Mix.install([
  {:download, "~> 0.0.4"},
  {:evision, "~> 0.1.0-dev", github: "cocoa-xu/evision", branch: "main"},
  {:kino, "~> 0.5.2"},
  {:nx, "~> 0.1", [env: :prod, repo: "hexpm", hex: "nx", optional: true]},
  {:exla, "~> 0.2.1"},
  {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}
])

Section

# Refer to the Evision bindings under the shorter OpenCV name used throughout this notebook.
alias Evision, as: OpenCV
# Require Axon up front so that any macros it defines are available to the code below.
require Axon
# Make EXLA the default for Nx, handing it the client names below.
# NOTE(review): presumably the list is an order of preference and the first
# available client is used; confirm against the EXLA version installed above.
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
defmodule Helper do
  @moduledoc """
  Display helpers for rendering images inside the notebook.
  """

  @doc """
  Encodes `mat` as PNG and wraps the encoded bytes in a `Kino.Image`.

  `mat` is whatever `OpenCV.imencode!/2` accepts as an image; the return value
  is the struct produced by `Kino.Image.new/2`.
  """
  def show_image(mat) do
    png_iodata = OpenCV.imencode!(".png", mat)
    png_binary = IO.iodata_to_binary(png_iodata)

    Kino.Image.new(png_binary, :png)
  end
end
# Remove any copy left over from a previous run so that Download.from/1 does
# not fail with an :eexist error when the cell is re-evaluated. The result is
# ignored on purpose: on the first run there is simply nothing to delete.
File.rm("Lenna_%28test_image%29.png")

# Match on {:ok, path} so that a failed download stops right here with the
# real reason, instead of binding the error reason to `lenna` and failing
# later inside imread!.
{:ok, lenna} =
  Download.from("https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png")

# Read the downloaded file back and display it.
lenna
|> OpenCV.imread!()
|> Helper.show_image()
# Open the DETS file and read the {model, params} tuple stored under key 1.
# The table handle gets its own name (it used to be bound to `params` and then
# rebound), and the table is closed again once the entry has been read so the
# file is not left open for the lifetime of the notebook.
{:ok, table} = :dets.open_file('/data/vgg16.dets')

{model, params} =
  try do
    # Exactly one entry is expected for key 1; anything else is a MatchError.
    [{1, stored}] = :dets.lookup(table, 1)
    stored
  after
    :dets.close(table)
  end
# Remove any copy left over from a previous run so that Download.from/1 does
# not fail with an :eexist error when the cell is re-evaluated. The result is
# ignored on purpose: on the first run there is simply nothing to delete.
File.rm("imagenet1000_clsidx_to_labels.txt")

# Match on {:ok, path} so that a failed download is reported here rather than
# leaving an error reason in `class_list` for File.read! to trip over later.
{:ok, class_list} =
  Download.from(
    "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
  )
# Build a map of class index (integer) => label (string) from the downloaded
# listing. Each non-blank line is assumed to look like
#   {0: 'tench, Tinca tinca',
# i.e. "<index>: '<label>'," with the whole listing wrapped in one pair of
# braces — TODO confirm against the downloaded file.
class_map =
  class_list
  |> File.read!()
  |> String.split("\n")
  # A trailing newline (or any blank line) would otherwise break the
  # two-element match below.
  |> Enum.reject(&(String.trim(&1) == ""))
  |> Map.new(fn line ->
    # Split on the first colon only, so a colon inside a label is kept.
    [raw_number, raw_name] = String.split(line, ":", parts: 2)

    {class_number, _rest} =
      raw_number
      |> String.replace("{", "")
      |> String.trim()
      |> Integer.parse()

    class_name =
      raw_name
      |> String.replace("}", "")
      |> String.trim()
      # Drop only the separator comma at the end of the line; commas inside
      # the label ("tench, Tinca tinca") stay untouched.
      |> String.trim_trailing(",")
      |> String.replace("'", "")

    {class_number, class_name}
  end)
# Load the test image and turn it into the batch-of-one tensor fed to the model.
mat = OpenCV.imread!(lenna)

tensor =
  mat
  |> OpenCV.resize!([_width = 224, _height = 224])
  |> OpenCV.cvtColor!(OpenCV.cv_COLOR_BGR2RGB())
  |> OpenCV.Nx.to_nx()
  # Scale the pixel values by 1/255, then normalize each channel with the
  # per-channel mean and standard deviation below.
  |> Nx.divide(255)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
  # Move the channel axis to the front while keeping rows before columns.
  # A bare Nx.transpose/1 reverses every axis, which also swaps height and
  # width and so feeds the model a spatially transposed image.
  # NOTE(review): assumes to_nx/1 yields {height, width, channels} — confirm.
  |> Nx.transpose(axes: [2, 0, 1])
  # Add the leading batch axis.
  |> Nx.new_axis(0)
# Run the model on the prepared tensor and keep the indices of the five
# largest outputs: ascending argsort, reversed, first five entries.
preds =
  model
  |> Axon.predict(params, tensor)
  |> Nx.flatten()
  |> Nx.argsort()
  |> Nx.reverse()
  |> Nx.slice([0], [5])
  |> Nx.to_flat_list()

# Look up the label for each predicted index (nil when the index is unknown).
for class_index <- preds, do: Map.get(class_map, class_index)
defmodule Detector do
  @moduledoc """
  Downloads an image, classifies it with the supplied model and shows it.
  """

  # Per-channel mean and standard deviation applied to the 0..1-scaled RGB input.
  @channel_mean [0.485, 0.456, 0.406]
  @channel_std [0.229, 0.224, 0.225]

  # Number of highest-scoring classes that are printed.
  @top_k 5

  @doc """
  Classifies the image found at `image_url` and displays it.

  The image is downloaded into the current directory (replacing a copy left
  by an earlier run), preprocessed and passed through `model` with `params`.
  The labels of the #{@top_k} highest-scoring classes are looked up in
  `class_map` and printed, one per line, best first.

  Returns the `Kino.Image` produced by `Helper.show_image/1`.

  Raises `MatchError` when the download does not return `{:ok, path}`.
  """
  def detect(image_url, model, params, class_map) do
    mat =
      image_url
      |> download!()
      |> OpenCV.imread!()

    mat
    |> preprocess()
    |> top_classes(model, params)
    # Printing is a pure side effect, so each/2 is used instead of map/2.
    |> Enum.each(fn class_index ->
      class_map
      |> Map.get(class_index)
      |> IO.puts()
    end)

    Helper.show_image(mat)
  end

  # Downloads `image_url` and returns the local path of the stored file.
  defp download!(image_url) do
    basename =
      image_url
      |> URI.parse()
      |> Map.fetch!(:path)
      |> Path.basename()

    # Best effort: clear a copy from a previous run so the download does not
    # fail with :eexist. There is nothing to delete on the first run.
    File.rm(basename)

    # Fail here with the real reason rather than passing an error atom on.
    {:ok, path} = Download.from(image_url)
    path
  end

  # Converts an image matrix into the normalized batch-of-one tensor for the model.
  defp preprocess(mat) do
    mat
    |> OpenCV.resize!([_width = 224, _height = 224])
    |> OpenCV.cvtColor!(OpenCV.cv_COLOR_BGR2RGB())
    |> OpenCV.Nx.to_nx()
    |> Nx.divide(255)
    |> Nx.subtract(Nx.tensor(@channel_mean))
    |> Nx.divide(Nx.tensor(@channel_std))
    # Move channels to the front while keeping rows before columns. A bare
    # Nx.transpose/1 reverses every axis and would also swap height and width.
    # NOTE(review): assumes to_nx/1 yields {height, width, channels} — confirm.
    |> Nx.transpose(axes: [2, 0, 1])
    |> Nx.new_axis(0)
  end

  # Returns the indices of the @top_k largest model outputs, best first.
  defp top_classes(tensor, model, params) do
    model
    |> Axon.predict(params, tensor)
    |> Nx.flatten()
    |> Nx.argsort()
    |> Nx.reverse()
    |> Nx.slice([0], [@top_k])
    |> Nx.to_flat_list()
  end
end

# Run the detector on a handful of sample photos.
"https://c7.staticflickr.com/4/3348/3637369393_6428a81f05_o.jpg"
|> Detector.detect(model, params, class_map)
"https://c4.staticflickr.com/9/8518/8472820388_ddbf1e2d3a_o.jpg"
|> Detector.detect(model, params, class_map)
"https://farm4.staticflickr.com/8479/8249921544_9b51c8a1db_o.jpg"
|> Detector.detect(model, params, class_map)
"https://farm8.staticflickr.com/8232/8593290034_7251ed76b5_o.jpg"
|> Detector.detect(model, params, class_map)