ResNet18 image classification - libtorch
# Work from the notebook's own directory so that the relative paths used below
# ("./model/...", "./imagenet1000.label", "lion.jpg") resolve next to it.
File.cd!(__DIR__)

# Windows (Japanese locale) only: switch the console code page to UTF-8.
# `chcp` does not exist on other platforms, so skip it there instead of
# spawning a shell command that is guaranteed to fail.
if match?({:win32, _}, :os.type()) do
  System.shell("chcp 65001")
end

Mix.install(
  [
    {:nn_interp, "~> 0.1.0"},
    {:cimg, "~> 0.1.20"},
    {:nx, "~> 0.2.1"},
    {:kino, "~> 0.6.2"}
  ],
  # Exported to the dependency build environment; presumably selects the
  # libtorch backend of nn_interp (see the notebook title) - confirm.
  system_env: [{"NNINTERP", "libtorch"}]
)
0.Original work
torchvision.models.resnet18 - the pre-trained model included in PyTorch.
Thanks a lot!!!
Convert the model to TorchScript with torch.jit.script().
import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet18_Weights

# Load the ImageNet-pretrained ResNet18.
r18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Switch to inference mode BEFORE scripting and saving. A freshly constructed
# module is in training mode; in that mode the BatchNorm layers normalise with
# per-batch statistics and update their running statistics on every forward
# pass, so the dummy passes below would otherwise alter the pretrained buffers
# and the exported model would still be in training mode.
r18.eval()

r18_scripted = torch.jit.script(r18)

# Sanity check: the eager and the scripted model must agree on a random input.
dummy_input = torch.rand(1, 3, 224, 224)

# No gradients are needed for a comparison run.
with torch.no_grad():
    unscripted_output = r18(dummy_input)
    scripted_output = r18_scripted(dummy_input)

unscripted_top5 = F.softmax(unscripted_output, dim=1).topk(5).indices
scripted_top5 = F.softmax(scripted_output, dim=1).topk(5).indices

print('Python model top5 results:\n {}'.format(unscripted_top5))
print('TorchScript model top 5 results:\n {}'.format(scripted_top5))

# Serialise the TorchScript module so it can be loaded outside of Python.
r18_scripted.save('r18_scripted.pt')
Implementation with NNInterp in Elixir
1.Defining the inference module: ResNet18
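- Model
  r18_scripted.pt: downloaded from “https://github.com/shoz-f/nn-interp/releases/download/0.1.0/r18_scripted.pt” if it does not exist locally. (Note: the module below is configured with "./model/resnet18.pt" and the matching resnet18.pt release URL; confirm which file name is current.)
- Pre-processing
  Resize the input image to the size {@width, @height} and apply Gaussian normalization.
- Post-processing
  Sort the outputs in descending order and take the first `top` items (1 by default).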
defmodule ResNet18 do
  @moduledoc """
  ResNet18 ImageNet classifier running on NNInterp.

  The model file and its download URL are given in the `use NNInterp` options
  below. The class labels are read from `./imagenet1000.label` when the module
  is compiled, one label per line, keyed by zero-based line index.
  """

  @width 224
  @height 224

  use NNInterp,
    model: "./model/resnet18.pt",
    url: "https://github.com/shoz-f/nn-interp/releases/download/0.1.0/resnet18.pt",
    inputs: [f32: {1, 3, @height, @width}],
    outputs: [f32: {1, 1000}]

  # %{class_index => label}, built once at compile time.
  @imagenet1000 "./imagenet1000.label"
                |> File.stream!()
                |> Stream.map(&String.trim_trailing/1)
                |> Stream.with_index()
                |> Map.new(fn {label, index} -> {index, label} end)

  @doc """
  Classifies `img` and returns the labels of the `top` highest-scoring classes,
  best first.

  ## Parameters

    * `img` - the image to classify; it is passed to `CImg.builder/1`, so it is
      presumably a `%CImg{}` such as the result of `CImg.load/1`.
    * `top` - number of labels to return, an integer in `1..1000`
      (default: `1`). Other values raise `FunctionClauseError`.

  ## Returns

  A list of `top` labels. An entry is `nil` if the label file has no line for
  that class index.
  """
  @spec apply(img :: term(), top :: 1..1000) :: [String.t() | nil]
  def apply(img, top \\ 1) when is_integer(top) and top in 1..1000 do
    # Pre-processing: resize to the network input size, normalise each channel
    # and lay the pixels out as NCHW float32. The {mean, std} pairs match the
    # usual ImageNet statistics scaled to 0..255 (e.g. 0.485 * 255 = 123.7).
    input =
      img
      |> CImg.builder()
      |> CImg.resize({@width, @height})
      |> CImg.to_binary([{:gauss, {{123.7, 58.4}, {116.3, 57.1}, {103.5, 57.4}}}, :nchw])

    # Prediction: one raw score (logit) per class.
    logits =
      session()
      |> NNInterp.set_input_tensor(0, input)
      |> NNInterp.invoke()
      |> NNInterp.get_output_tensor(0)
      |> Nx.from_binary({:f, 32})
      |> Nx.reshape({1000})

    # Post-processing: rank the classes by score. Softmax is strictly
    # increasing, so ranking the logits directly yields the same order as
    # ranking the probabilities, without the risk of `exp` overflowing to
    # Inf/NaN for large logits.
    logits
    |> Nx.argsort(direction: :desc)
    |> Nx.slice([0], [top])
    |> Nx.to_flat_list()
    |> Enum.map(&@imagenet1000[&1])
  end
end
Launch ResNet18.
# Launch the ResNet18 module with an empty option list. start_link/1 is not
# defined in this notebook, so it presumably comes from `use NNInterp` - confirm.
ResNet18.start_link([])
Display the properties of the ResNet18 model.
# Ask NNInterp for the properties of the model behind the ResNet18 module.
NNInterp.info(ResNet18)
3.Let’s try it
Load a photo and apply ResNet18 to it.
# Fetch the sample photo only when it is not already next to the notebook.
if not File.exists?("lion.jpg") do
  NNInterp.URL.download("https://github.com/shoz-f/nn-interp/releases/download/0.1.0/lion.jpg")
end

img = CImg.load("lion.jpg")

# Show the photo in the notebook as a JPEG.
img
|> CImg.display_kino(:jpeg)
|> Kino.render()

# Classify the photo and return the three best-scoring labels.
ResNet18.apply(img, 3)
4.TIL ;-)
□