Inference on TPU

notebooks/tpu.livemd

Mix.install([
  {:tflite_elixir, "~> 0.3.0"},
  {:req, "~> 0.3.0"},
  {:kino, "~> 0.9.0"}
])

Introduction

The Coral USB Accelerator adds a Coral Edge TPU to your Linux, Mac, or Windows computer so you can accelerate your machine learning models.

It’s useful to alias modules to shorter names when we make extensive use of their functions.

alias TFLiteElixir.Coral
alias TFLiteElixir.FlatBufferModel
alias TFLiteElixir.Interpreter
alias TFLiteElixir.InterpreterBuilder
alias TFLiteElixir.Ops
alias TFLiteElixir.TFLiteTensor

Detect edge TPU devices

Coral.edge_tpu_devices()
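
If a Coral accelerator is attached, this should return a non-empty list of detected Edge TPU devices; a USB accelerator typically appears with a `usb`-prefixed name such as `usb:0`. An empty list means no Edge TPU was found.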

Download data files

# /data is the writable portion of a Nerves system
downloads_dir =
  if Code.ensure_loaded?(Nerves.Runtime), do: "/data/livebook", else: System.tmp_dir!()

download = fn url ->
  save_as = Path.join(downloads_dir, URI.encode_www_form(url))
  unless File.exists?(save_as), do: Req.get!(url, output: save_as)
  save_as
end

data_files =
  [
    class_labels:
      "https://raw.githubusercontent.com/cocoa-xu/tflite_elixir/main/test/test_data/inat_bird_labels.txt",
    cpu_model:
      "https://raw.githubusercontent.com/cocoa-xu/tflite_elixir/main/test/test_data/mobilenet_v2_1.0_224_inat_bird_quant.tflite",
    tpu_model:
      "https://raw.githubusercontent.com/cocoa-xu/tflite_elixir/main/test/test_data/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite",
    test_image:
      "https://raw.githubusercontent.com/cocoa-xu/tflite_elixir/main/test/test_data/parrot.jpeg"
  ]
  |> Enum.map(fn {key, url} -> {key, download.(url)} end)
  |> Map.new()

data_files
|> Enum.map(fn {k, v} -> [name: k, location: v] end)
|> Kino.DataTable.new(name: "Data files")

Classify image

defmodule ClassifyImage do
  @moduledoc """
  Image classification, adapted from the `classify_image` mix task
  (`mix help classify_image`).

  Options (named after the original command-line arguments):

  - `-m`, `--model`: *Required*. File path of .tflite file.
  - `-i`, `--input`: *Required*. Image to be classified.
  - `-l`, `--labels`: File path of labels file.
  - `-k`, `--top`: Defaults to `1`. Max number of classification results.
  - `-t`, `--threshold`: Defaults to `0.0`. Classification score threshold.
  - `-c`, `--count`: Defaults to `1`. Number of times to run inference.
  - `-a`, `--mean`: Defaults to `128.0`. Mean value for input normalization.
  - `-s`, `--std`: Defaults to `128.0`. Standard deviation value for input normalization.
  - `-j`, `--jobs`: Number of threads for the interpreter (only valid for CPU).
  - `--use-tpu`: Defaults to `false`. Add this option to use the Coral device.
  - `--tpu`: Defaults to `""`. Coral device name.
    - `""`      -- any TPU device
    - `"usb"`   -- any TPU device on USB bus
    - `"pci"`   -- any TPU device on PCIe bus
    - `":N"`    -- N-th TPU device, e.g. `":0"`
    - `"usb:N"` -- N-th TPU device on USB bus, e.g. `"usb:0"`
    - `"pci:N"` -- N-th TPU device on PCIe bus, e.g. `"pci:0"`

  Code based on [classify_image.py](https://github.com/google-coral/pycoral/blob/master/examples/classify_image.py)
  """

  def run(args) do
    default_values = [
      top: 1,
      threshold: 0.0,
      count: 1,
      mean: 128.0,
      std: 128.0,
      use_tpu: false,
      tpu: "",
      jobs: System.schedulers_online()
    ]

    args =
      Keyword.merge(args, default_values, fn _k, user, default ->
        if user == nil do
          default
        else
          user
        end
      end)

    model = load_model(args[:model])
    input_image = load_input(args[:input])
    labels = load_labels(args[:labels])

    tpu_context =
      if args[:use_tpu] do
        Coral.get_edge_tpu_context!(device: args[:tpu])
      else
        nil
      end

    interpreter = make_interpreter(model, args[:jobs], args[:use_tpu], tpu_context)
    :ok = Interpreter.allocate_tensors(interpreter)

    [input_tensor_number | _] = Interpreter.inputs!(interpreter)
    [output_tensor_number | _] = Interpreter.outputs!(interpreter)
    input_tensor = Interpreter.tensor(interpreter, input_tensor_number)

    if input_tensor.type != {:u, 8} do
      raise ArgumentError, "only uint8 input type is supported"
    end

    {h, w} =
      case input_tensor.shape do
        {_n, h, w, _c} ->
          {h, w}

        {_n, h, w} ->
          {h, w}

        shape ->
          raise RuntimeError, "unexpected input shape: #{inspect(shape)}"
      end

    input_image = StbImage.resize(input_image, h, w)

    [scale] = input_tensor.quantization_params.scale
    [zero_point] = input_tensor.quantization_params.zero_point
    mean = args[:mean]
    std = args[:std]

    if abs(scale * std - 1) < 0.00001 and abs(mean - zero_point) < 0.00001 do
      # Input data does not require preprocessing: the quantization parameters
      # already match the requested normalization (scale * std ≈ 1 and
      # mean ≈ zero_point), so the raw image bytes can be set on the tensor as-is.
      %StbImage{data: input_data} = input_image
      input_data
    else
      # Input data requires preprocessing
      StbImage.to_nx(input_image)
      |> Nx.subtract(mean)
      |> Nx.divide(std * scale)
      |> Nx.add(zero_point)
      |> Nx.clip(0, 255)
      |> Nx.as_type(:u8)
      |> Nx.to_binary()
    end
    |> then(&TFLiteTensor.set_data(input_tensor, &1))

    IO.puts("----INFERENCE TIME----")

    for _ <- 1..args[:count] do
      start_time = :os.system_time(:microsecond)
      Interpreter.invoke!(interpreter)
      end_time = :os.system_time(:microsecond)
      inference_time = (end_time - start_time) / 1000.0
      IO.puts("#{Float.round(inference_time, 1)}ms")
    end

    output_data = Interpreter.output_tensor!(interpreter, 0)
    output_tensor = Interpreter.tensor(interpreter, output_tensor_number)
    scores = get_scores(output_data, output_tensor)
    sorted_indices = Nx.argsort(scores, direction: :desc)
    top_k = Nx.take(sorted_indices, Nx.iota({args[:top]}))
    scores = Nx.to_flat_list(Nx.take(scores, top_k))
    top_k = Nx.to_flat_list(top_k)

    IO.puts("-------RESULTS--------")

    if labels != nil do
      Enum.zip(top_k, scores)
      |> Enum.each(fn {class_id, score} ->
        IO.puts("#{Enum.at(labels, class_id)}: #{Float.round(score, 5)}")
      end)
    else
      Enum.zip(top_k, scores)
      |> Enum.each(fn {class_id, score} ->
        IO.puts("#{class_id}: #{Float.round(score, 5)}")
      end)
    end

    interpreter
  end

  defp load_model(nil) do
    raise ArgumentError, "empty value for argument '--model'"
  end

  defp load_model(model_path) do
    FlatBufferModel.build_from_buffer(File.read!(model_path))
  end

  defp load_input(nil) do
    raise ArgumentError, "empty value for argument '--input'"
  end

  defp load_input(input_path) do
    with {:ok, input_image} <- StbImage.read_file(input_path) do
      input_image
    else
      {:error, error} ->
        raise RuntimeError, error
    end
  end

  defp load_labels(nil), do: nil

  defp load_labels(label_file_path) do
    File.read!(label_file_path)
    |> String.split("\n")
  end

  defp make_interpreter(model, num_jobs, false, _tpu_context) do
    resolver = Ops.Builtin.BuiltinResolver.new!()
    builder = InterpreterBuilder.new!(model, resolver)
    interpreter = Interpreter.new!()
    InterpreterBuilder.set_num_threads!(builder, num_jobs)
    :ok = InterpreterBuilder.build!(builder, interpreter)
    Interpreter.set_num_threads!(interpreter, num_jobs)
    interpreter
  end

  defp make_interpreter(model, _num_jobs, true, tpu_context) do
    Coral.make_edge_tpu_interpreter!(model, tpu_context)
  end

  defp get_scores(output_data, %TFLiteTensor{type: dtype = {:u, _}} = output_tensor) do
    scale = Nx.tensor(output_tensor.quantization_params.scale)
    zero_point = Nx.tensor(output_tensor.quantization_params.zero_point)

    Nx.from_binary(output_data, dtype)
    |> Nx.as_type({:s, 64})
    |> Nx.subtract(zero_point)
    |> Nx.multiply(scale)
  end

  defp get_scores(output_data, %TFLiteTensor{type: dtype = {:s, _}} = output_tensor) do
    [scale] = output_tensor.quantization_params.scale
    [zero_point] = output_tensor.quantization_params.zero_point

    Nx.from_binary(output_data, dtype)
    |> Nx.as_type({:s, 64})
    |> Nx.subtract(zero_point)
    |> Nx.multiply(scale)
  end

  defp get_scores(output_data, %TFLiteTensor{type: dtype}) do
    Nx.from_binary(output_data, dtype)
  end
end
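
Load the test image to preview it:
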
StbImage.read_file!(data_files.test_image)

Without TPU
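
This runs the quantized MobileNet v2 model on the CPU, repeating inference five times and printing each inference time along with the top three labels and their scores.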

ClassifyImage.run(
  model: data_files.cpu_model,
  input: data_files.test_image,
  labels: data_files.class_labels,
  top: 3,
  threshold: 0.3,
  count: 5,
  mean: 128.0,
  std: 128.0,
  use_tpu: false,
  tpu: ""
)

With TPU
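
If an Edge TPU device is detected, the same classification is run on the accelerator using the Edge TPU-compiled model; otherwise the cell only prints a message.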

if length(TFLiteElixir.Coral.edge_tpu_devices()) > 0 do
  ClassifyImage.run(
    model: data_files.tpu_model,
    input: data_files.test_image,
    labels: data_files.class_labels,
    top: 3,
    threshold: 0.3,
    count: 5,
    mean: 128.0,
    std: 128.0,
    use_tpu: true,
    tpu: "usb"
  )
else
  IO.puts("No edge TPU device was found")
end
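
If more than one accelerator is attached, a specific device can be targeted with the indexed names from the `--tpu` option above. A minimal sketch, assuming at least one USB Edge TPU is present (otherwise creating the Edge TPU context will fail):

# Same classification as above, pinned to the first Edge TPU on the USB bus.
# Options omitted here (mean, std, threshold, jobs) fall back to the defaults in run/1.
ClassifyImage.run(
  model: data_files.tpu_model,
  input: data_files.test_image,
  labels: data_files.class_labels,
  top: 3,
  count: 5,
  use_tpu: true,
  tpu: "usb:0"
)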