Image classification with a pretrained VGG16 model (Evision + Axon)
Mix.install([
{:download, "~> 0.0.4"},
{:evision, "~> 0.1.0-dev", github: "cocoa-xu/evision", branch: "main"},
{:kino, "~> 0.5.2"},
{:nx, "~> 0.1", [env: :prod, repo: "hexpm", hex: "nx", optional: true]},
{:exla, "~> 0.2.1"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}
])
Setup and inference
# Refer to Evision under the shorter name OpenCV in the code that follows.
alias Evision, as: OpenCV
require Axon
# Hand EXLA a list of client names to use as the Nx default.
# NOTE(review): presumably the clients are tried in this order of preference
# (accelerators first, :host last) — confirm against the EXLA documentation.
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
defmodule Helper do
  @moduledoc """
  Display helpers for rendering images in the notebook.
  """

  @doc """
  Encodes `mat` as PNG and wraps the resulting bytes in a `Kino.Image`.
  """
  def show_image(mat) do
    encoded = OpenCV.imencode!(".png", mat)
    png = IO.iodata_to_binary(encoded)

    Kino.Image.new(png, :png)
  end
end
# Download.from/1 fails with an eexist error when the target file is already
# present, so remove any copy left over from a previous run of this cell.
File.rm("Lenna_%28test_image%29.png")

# Match on {:ok, path} so that a failed download stops right here instead of
# handing the error reason to imread! as if it were a file path.
{:ok, lenna} =
  Download.from("https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png")

# Read the downloaded file and render it in the notebook.
lenna
|> OpenCV.imread!()
|> Helper.show_image()
# Open the DETS table, read the {model, params} pair stored under key 1, and
# close the table again so the file handle is released and the table is left
# in a properly closed state.
{:ok, table} = :dets.open_file('/data/vgg16.dets')
[{1, {model, params}}] = :dets.lookup(table, 1)
:ok = :dets.close(table)
# Download.from/1 fails with an eexist error when the target file is already
# present, so remove any copy left over from a previous run of this cell.
File.rm("imagenet1000_clsidx_to_labels.txt")

# Match on {:ok, path} so that a failed download stops right here instead of
# handing the error reason to File.read! as if it were a file path.
{:ok, class_list} =
  Download.from(
    "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
  )
# Build a map of class index => label from the downloaded label file.
# NOTE(review): assumes each non-blank line looks like ` 12: 'label text',`
# with `{` on the first line and `}` on the last — confirm against the file.
class_map =
  class_list
  |> File.read!()
  |> String.split("\n")
  # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
  |> Enum.reject(&(String.trim(&1) == ""))
  |> Map.new(fn line ->
    # parts: 2 splits on the first ":" only, so a ":" inside a label is kept.
    [class_number, class_name] = String.split(line, ":", parts: 2)

    {index, _rest} =
      class_number
      |> String.replace("{", "")
      |> String.trim()
      |> Integer.parse()

    label =
      class_name
      |> String.trim()
      |> String.trim_trailing("}")
      # Drop only the trailing separator; commas inside the label stay as-is.
      |> String.trim_trailing(",")
      |> String.trim()
      # Strip the surrounding quotes only, so an apostrophe inside a
      # double-quoted label is preserved.
      |> String.trim("'")
      |> String.trim("\"")

    {index, label}
  end)
# Read the image and turn it into a normalized 1 x channels x 224 x 224 input tensor.
mat = OpenCV.imread!(lenna)

tensor =
  mat
  |> OpenCV.resize!([_width = 224, _height = 224])
  |> OpenCV.cvtColor!(OpenCV.cv_COLOR_BGR2RGB())
  |> OpenCV.Nx.to_nx()
  # Scale pixel values to 0..1, then normalize each channel.
  |> Nx.divide(255)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
  # Move the channel axis to the front while keeping the two spatial axes in
  # their original order. A bare Nx.transpose/1 reverses every axis, which
  # would also swap the spatial axes and feed the model a transposed image.
  # NOTE(review): assumes to_nx yields {height, width, channels} — confirm.
  |> Nx.transpose(axes: [2, 0, 1])
  # Add the leading batch axis.
  |> Nx.new_axis(0)
# Indices of the five highest-scoring entries of the flattened model output,
# best first (ascending argsort, reversed, first five kept).
preds =
  model
  |> Axon.predict(params, tensor)
  |> Nx.flatten()
  |> Nx.argsort()
  |> Nx.reverse()
  |> Nx.slice([0], [5])
  |> Nx.to_flat_list()

# Look up the label for each predicted index.
Enum.map(preds, &Map.get(class_map, &1))
defmodule Detector do
  @moduledoc """
  Downloads an image, runs it through an Axon model and prints the labels of
  the highest-scoring classes.
  """

  # Side length in pixels the image is resized to before inference.
  @input_size 224
  # Per-channel constants subtracted from / divided into the 0..1 scaled pixels.
  @mean [0.485, 0.456, 0.406]
  @std [0.229, 0.224, 0.225]
  # Number of best-scoring classes to print.
  @top_k 5

  @doc """
  Classifies the image found at `image_url`.

  The image is downloaded into the current directory, preprocessed and passed
  to `Axon.predict/3` together with `model` and `params`. The labels of the
  #{@top_k} highest-scoring classes are looked up in `class_map` (class index
  => label) and printed, best first.

  Returns the result of `Helper.show_image/1` for the downloaded image.
  Raises a `MatchError` if the download fails.
  """
  def detect(image_url, model, params, class_map) do
    mat =
      image_url
      |> download!()
      |> OpenCV.imread!()

    model
    |> Axon.predict(params, preprocess(mat))
    |> top_indices()
    # Enum.each rather than Enum.map: printing is done purely for its side effect.
    |> Enum.each(fn index ->
      class_map
      |> Map.get(index)
      |> IO.puts()
    end)

    Helper.show_image(mat)
  end

  # Downloads `image_url` and returns the local file path, raising on failure.
  defp download!(image_url) do
    # Remove a copy left over from a previous run first; the original notebook
    # notes that Download.from/1 otherwise fails with an eexist error.
    image_url
    |> URI.parse()
    |> Map.fetch!(:path)
    |> Path.basename()
    |> File.rm()

    # Assert success so an error reason is never mistaken for a file path.
    {:ok, path} = Download.from(image_url)
    path
  end

  # Converts an image matrix into a normalized 1 x channels x size x size tensor.
  defp preprocess(mat) do
    mat
    |> OpenCV.resize!([@input_size, @input_size])
    |> OpenCV.cvtColor!(OpenCV.cv_COLOR_BGR2RGB())
    |> OpenCV.Nx.to_nx()
    |> Nx.divide(255)
    |> Nx.subtract(Nx.tensor(@mean))
    |> Nx.divide(Nx.tensor(@std))
    # Channels first, spatial axes kept in order. A bare Nx.transpose/1 would
    # reverse every axis and thereby also swap the two spatial axes.
    # NOTE(review): assumes to_nx yields {height, width, channels} — confirm.
    |> Nx.transpose(axes: [2, 0, 1])
    |> Nx.new_axis(0)
  end

  # Returns the indices of the @top_k largest scores, best first.
  defp top_indices(scores) do
    scores
    |> Nx.flatten()
    |> Nx.argsort()
    |> Nx.reverse()
    |> Nx.slice([0], [@top_k])
    |> Nx.to_flat_list()
  end
end
"https://c7.staticflickr.com/4/3348/3637369393_6428a81f05_o.jpg"
|> Detector.detect(model, params, class_map)
"https://c4.staticflickr.com/9/8518/8472820388_ddbf1e2d3a_o.jpg"
|> Detector.detect(model, params, class_map)
"https://farm4.staticflickr.com/8479/8249921544_9b51c8a1db_o.jpg"
|> Detector.detect(model, params, class_map)
"https://farm8.staticflickr.com/8232/8593290034_7251ed76b5_o.jpg"
|> Detector.detect(model, params, class_map)