Image classification with TensorFlow Lite
Mix.install([
# will download and install precompiled version
{:tflite_elixir, "~> 0.3.4"},
{:req, "~> 0.3.0"},
{:progress_bar, "~> 2.0.0"},
{:kino, "~> 0.9.0"}
])
Introduction
In this notebook, we will perform image classification with the pre-trained `mobilenet_v2_1.0_224_inat_bird_quant.tflite` model.
https://www.tensorflow.org/lite/examples/image_classification/overview
Prepare helper functions
defmodule Utils do
  @moduledoc """
  Download helper that renders a progress bar while the response body streams in.
  """

  @doc """
  Performs a GET request to `source_url` and returns the response body.

  `req_options` are appended to the options given to `Req.get!/2`, so callers
  can pass any extra Req option (for example `output: path`).

  A progress bar is rendered while data arrives, provided the server announced
  a `content-length`. Without that header the download still proceeds, just
  without progress output.
  """
  def download!(source_url, req_options \\ []) do
    Req.get!(source_url, [finch_request: &finch_request/4] ++ req_options).body
  end

  # Custom `:finch_request` callback: streams the response through Finch so that
  # `handle_message/2` sees every chunk and can keep track of the progress.
  defp finch_request(req_request, finch_request, finch_name, finch_options) do
    acc = Req.Response.new()

    case Finch.stream(finch_request, finch_name, acc, &handle_message/2, finch_options) do
      {:ok, response} -> {req_request, response}
      {:error, exception} -> {req_request, exception}
    end
  end

  defp handle_message({:status, status}, response), do: %{response | status: status}

  defp handle_message({:headers, headers}, response) do
    # `total_size` is nil when the server did not send a usable content-length
    # (e.g. chunked transfer encoding); the progress bar is then skipped.
    response
    |> Map.put(:headers, headers)
    |> Map.put(:private, %{total_size: content_length(headers), downloaded_size: 0})
  end

  defp handle_message({:data, data}, response) do
    new_downloaded_size = response.private.downloaded_size + byte_size(data)
    render_progress(new_downloaded_size, response.private.total_size)

    response
    |> Map.update!(:body, &(&1 <> data))
    |> Map.update!(:private, &%{&1 | downloaded_size: new_downloaded_size})
  end

  # Some Finch versions may also deliver trailing headers to the stream
  # callback; they carry no body data, so the response is left untouched.
  defp handle_message({:trailers, _trailers}, response), do: response

  # Returns the announced body size in bytes, or nil when the header is missing
  # or is not a plain integer.
  defp content_length(headers) do
    with {_, value} <- List.keyfind(headers, "content-length", 0),
         {size, ""} <- Integer.parse(value) do
      size
    else
      _ -> nil
    end
  end

  # Only render when the total is known and positive; otherwise there is
  # nothing meaningful to show.
  defp render_progress(downloaded_size, total_size)
       when is_integer(total_size) and total_size > 0 do
    ProgressBar.render(downloaded_size, total_size, suffix: :bytes)
  end

  defp render_progress(_downloaded_size, _total_size), do: :ok
end
Decide on where downloaded files are saved
# /data is the writable portion of a Nerves system; everywhere else the
# current working directory is used.
downloads_dir =
  if Code.ensure_loaded?(Nerves.Runtime),
    do: "/data/tmp",
    else: File.cwd!()

# Make sure the target directory exists before any file is written into it.
# This is a no-op when the directory is already there.
File.mkdir_p!(downloads_dir)
Download pre-trained model
# Remote location of the pre-trained model
model_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite"

# Local path the model is stored at
model_file = Path.join(downloads_dir, "mobilenet_v2_1.0_224_inat_bird_quant.tflite")

# Download only when the file is not on disk yet
if not File.exists?(model_file) do
  Utils.download!(model_url, output: model_file)
end

IO.puts("Model saved to #{model_file}")
Download labels
# Each line corresponds to a class name. First line is ID 0.
labels_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt"

labels_file = Path.join(downloads_dir, "inat_bird_labels.txt")

# Download only when the file is not on disk yet
if not File.exists?(labels_file) do
  Utils.download!(labels_url, output: labels_file)
end

IO.puts("Labels saved to #{labels_file}")

# One list entry per non-empty line of the labels file
labels =
  labels_file
  |> File.read!()
  |> String.split("\n", trim: true)
Choose image to be classified
An input image can be uploaded here; otherwise the default parrot image will be used.
image_input = Kino.Input.image("Image", size: {224, 224})

default_input_image_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/parrot.jpg"

input_image_nx =
  case Kino.Input.read(image_input) do
    nil ->
      # Nothing uploaded: fetch and decode the default image instead
      IO.puts("Loading default image from #{default_input_image_url}")
      default_image_binary = Utils.download!(default_input_image_url)
      StbImage.to_nx(StbImage.read_binary!(default_image_binary))

    uploaded_image ->
      # Interpret the uploaded bytes as unsigned 8-bit values and arrange them
      # as a {height, width, 3} tensor
      pixels = Nx.from_binary(uploaded_image.data, :u8)
      Nx.reshape(pixels, {uploaded_image.height, uploaded_image.width, 3})
  end

Kino.Image.new(input_image_nx)
Classify image
alias TFLiteElixir.Interpreter

how_many_results = 3
interpreter = Interpreter.new!(model_file)

# Resize the input to 224x224 by going through StbImage and back to Nx
resized_stb_image =
  input_image_nx
  |> StbImage.from_nx()
  |> StbImage.resize(224, 224)

input_image_resized = StbImage.to_nx(resized_stb_image)

# The prediction result is matched as a list holding a single output tensor
[output_tensor_0] = Interpreter.predict(interpreter, input_image_resized)
indices_nx = Nx.flatten(output_tensor_0)

# Tuple form of the labels so a class id can be looked up with elem/2
label_lookup = List.to_tuple(labels)

# Positions of the flattened output sorted in descending order of value,
# limited to the first `how_many_results` entries
top_class_ids =
  indices_nx
  |> Nx.argsort(direction: :desc)
  |> Nx.take(Nx.iota({how_many_results}))
  |> Nx.to_flat_list()

inference_rows =
  for class_id <- top_class_ids do
    %{class_id: class_id, class_name: elem(label_lookup, class_id)}
  end

Kino.DataTable.new(inference_rows, name: "Inference results")
ImageClassification module (experimental)
There is an experimental `ImageClassification` module that does everything for
you. It supports both CPU and TPU, and it shows more information, including the
scores (confidence) and the class names of the predicted results. It is also more
flexible: you can adjust parameters such as `top_k` and `threshold` (for
confidence), among others.
alias TFLiteElixir.ImageClassification

# Start the classifier for the downloaded model and hand it the label list
{:ok, pid} = ImageClassification.start(model_file)
ImageClassification.set_label(pid, labels)

# Ask for the three best matches and show them as a table
predict_options = [top_k: 3]
results = ImageClassification.predict(pid, input_image_nx, predict_options)

Kino.DataTable.new(results, name: "Inference results")
Some models have labels embedded as associated files.
If that is the case, we can load the labels with
`ImageClassification.set_label_from_associated_file/2`.