
Image classification with TensorFlow Lite

Mix.install([
  # will download and install precompiled version
  {:tflite_elixir, "~> 0.3.4"},
  # StbImage is used directly below for decoding and resizing;
  # listed explicitly in case it is not pulled in transitively
  {:stb_image, "~> 0.6"},
  {:req, "~> 0.3.0"},
  {:progress_bar, "~> 2.0.0"},
  {:kino, "~> 0.9.0"}
])

Introduction

In this notebook, we will perform image classification with the pre-trained mobilenet_v2_1.0_224_inat_bird_quant.tflite model.

https://www.tensorflow.org/lite/examples/image_classification/overview

Prepare helper functions

defmodule Utils do
  # Downloads `source_url`, streaming the body so a progress bar can be rendered.
  # Any extra Req options (e.g. `output: path` to save to a file) are passed through.
  def download!(source_url, req_options \\ []) do
    Req.get!(source_url, [finch_request: &finch_request/4] ++ req_options).body
  end

  defp finch_request(req_request, finch_request, finch_name, finch_options) do
    acc = Req.Response.new()

    case Finch.stream(finch_request, finch_name, acc, &handle_message/2, finch_options) do
      {:ok, response} -> {req_request, response}
      {:error, exception} -> {req_request, exception}
    end
  end

  defp handle_message({:status, status}, response), do: %{response | status: status}

  defp handle_message({:headers, headers}, response) do
    {_, total_size} = Enum.find(headers, &match?({"content-length", _}, &1))

    response
    |> Map.put(:headers, headers)
    |> Map.put(:private, %{total_size: String.to_integer(total_size), downloaded_size: 0})
  end

  defp handle_message({:data, data}, response) do
    new_downloaded_size = response.private.downloaded_size + byte_size(data)
    ProgressBar.render(new_downloaded_size, response.private.total_size, suffix: :bytes)

    response
    |> Map.update!(:body, &(&1 <> data))
    |> Map.update!(:private, &%{&1 | downloaded_size: new_downloaded_size})
  end
end

Decide where downloaded files are saved

# /data is the writable portion of a Nerves system
downloads_dir =
  if Code.ensure_loaded?(Nerves.Runtime),
    do: "/data/tmp",
    else: File.cwd!()

Download pre-trained model

model_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/mobilenet_v2_1.0_224_inat_bird_quant.tflite"

model_file = Path.join(downloads_dir, "mobilenet_v2_1.0_224_inat_bird_quant.tflite")
unless File.exists?(model_file), do: Utils.download!(model_url, output: model_file)
IO.puts("Model saved to #{model_file}")

Download labels

# Each line corresponds to a class name. The first line is class ID 0.
labels_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/inat_bird_labels.txt"

labels_file = Path.join(downloads_dir, "inat_bird_labels.txt")
unless File.exists?(labels_file), do: Utils.download!(labels_url, output: labels_file)
IO.puts("Labels saved to #{labels_file}")

labels = File.read!(labels_file) |> String.split("\n", trim: true)

Choose image to be classified

An input image can be uploaded here; otherwise, a default parrot image will be used.

image_input = Kino.Input.image("Image", size: {224, 224})
default_input_image_url =
  "https://raw.githubusercontent.com/google-coral/test_data/master/parrot.jpg"

input_image_nx =
  if uploaded_image = Kino.Input.read(image_input) do
    uploaded_image.data
    |> Nx.from_binary(:u8)
    |> Nx.reshape({uploaded_image.height, uploaded_image.width, 3})
  else
    IO.puts("Loading default image from #{default_input_image_url}")

    Utils.download!(default_input_image_url)
    |> StbImage.read_binary!()
    |> StbImage.to_nx()
  end

Kino.Image.new(input_image_nx)

Classify image

alias TFLiteElixir.Interpreter

how_many_results = 3

interpreter = Interpreter.new!(model_file)

input_image_resized =
  input_image_nx
  |> StbImage.from_nx()
  |> StbImage.resize(224, 224)
  |> StbImage.to_nx()

[output_tensor_0] = Interpreter.predict(interpreter, input_image_resized)
# The output tensor holds one score per class; flatten it to a 1-D tensor
scores_nx = Nx.flatten(output_tensor_0)

label_lookup = List.to_tuple(labels)

# Sort class indices by score (descending) and keep the top results
scores_nx
|> Nx.argsort(direction: :desc)
|> Nx.take(Nx.iota({how_many_results}))
|> Nx.to_flat_list()
|> Enum.map(&%{class_id: &1, class_name: elem(label_lookup, &1)})
|> Kino.DataTable.new(name: "Inference results")
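
The table above only shows class IDs and names. As a minimal sketch reusing the variables defined above, the raw scores can be included as well (for a quantized model these are raw integer outputs rather than normalized probabilities, so they are mainly useful for ranking):

# Sketch: same top-k lookup as above, but also reading each class's raw score
top_class_ids =
  scores_nx
  |> Nx.argsort(direction: :desc)
  |> Nx.take(Nx.iota({how_many_results}))
  |> Nx.to_flat_list()

top_class_ids
|> Enum.map(fn class_id ->
  %{
    class_id: class_id,
    class_name: elem(label_lookup, class_id),
    # raw model output for this class (not a normalized probability)
    score: Nx.to_number(scores_nx[class_id])
  }
end)
|> Kino.DataTable.new(name: "Inference results with raw scores")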

ImageClassification module (experimental)

There is an experimental ImageClassification module that does everything for you. It supports both CPU and TPU, and it shows more information, including the scores (confidence) and class names of the predicted results. It is also more flexible: you can adjust parameters such as top_k and threshold (for confidence).

alias TFLiteElixir.ImageClassification

{:ok, pid} = ImageClassification.start(model_file)
ImageClassification.set_label(pid, labels)
results = ImageClassification.predict(pid, input_image_nx, top_k: 3)
Kino.DataTable.new(results, name: "Inference results")
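
As a sketch of the extra flexibility mentioned above, a confidence threshold can be passed alongside top_k (assuming predict/3 accepts a :threshold option as described; the appropriate threshold value depends on the scale of the model's output scores):

# Hypothetical values for illustration; tune threshold to the model's score scale
filtered_results = ImageClassification.predict(pid, input_image_nx, top_k: 5, threshold: 0.4)
Kino.DataTable.new(filtered_results, name: "Inference results above threshold")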

Some models have labels embedded as associated files. If that is the case, we can load the labels with ImageClassification.set_label_from_associated_file/2.
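
A minimal sketch, assuming the model packages its labels under the name inat_bird_labels.txt (the actual associated file name depends on how the model was exported):

# Only works when the .tflite file actually ships an associated label file;
# the file name here is an assumption for illustration.
ImageClassification.set_label_from_associated_file(pid, "inat_bird_labels.txt")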