Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Object detection with TensorFlow Lite

notebooks/object_detection.livemd

Object detection with TensorFlow Lite

Mix.install([
  {:tflite_elixir, "~> 0.3.4"},
  {:evision, "0.1.31"},
  {:req, "~> 0.3.0"},
  {:kino, "~> 0.9.0"}
])

Introduction

Given an image or a video stream, an object detection model can identify which of a known set of objects might be present and provide information about their positions within the image.

https://www.tensorflow.org/lite/examples/object_detection/overview

When a notebook calls functions from the same few modules over and over, it helps to alias those modules to shorter names.

alias Evision, as: Cv
alias TFLiteElixir, as: TFLite
alias TFLiteElixir.TFLiteTensor

Download data files

# /data is the writable portion of a Nerves system
downloads_dir =
  if Code.ensure_loaded?(Nerves.Runtime), do: "/data/livebook", else: File.cwd!()

# Fetches `url` into `downloads_dir` (file name = URL-encoded `url`) unless a
# file with that name is already there, and returns the local path.
#
# Raises if the server answers with anything other than HTTP 200, so that an
# error page is never kept around and mistaken for a cached download later.
download = fn url ->
  save_as = Path.join(downloads_dir, URI.encode_www_form(url))

  if not File.exists?(save_as) do
    response = Req.get!(url, output: save_as)

    # The `output:` option writes the response body to disk without looking at
    # the HTTP status, so a failed request would otherwise poison the cache.
    if response.status != 200 do
      File.rm(save_as)
      raise "failed to download #{url}: HTTP status #{response.status}"
    end
  end

  save_as
end

# Map of data-file name => local path; each URL is fetched (or found already
# on disk) through `download`.
data_files =
  Map.new(
    [
      model:
        "https://tfhub.dev/tensorflow/lite-model/efficientdet/lite4/detection/metadata/2?lite-format=tflite",
      test_img: "https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg"
    ],
    fn {name, url} -> {name, download.(url)} end
  )

# Show where each data file ended up, one table row per file.
Kino.DataTable.new(
  for({name, path} <- data_files, do: [name: name, location: path]),
  name: "Data files"
)

Load labels

# Raw bytes of the downloaded model file; reused later to build the interpreter.
model_buffer = File.read!(data_files.model)

# The label list is read from the "labelmap.txt" file associated with the
# model and split into one entry per line.
class_names =
  String.split(
    TFLite.FlatBufferModel.get_associated_file(model_buffer, "labelmap.txt"),
    "\n"
  )

Load input image

test_image_input = Kino.Input.image(" ")

# Use the image supplied through the input above when there is one; otherwise
# fall back to the downloaded sample picture.
test_image_mat =
  test_image_input
  |> Kino.Input.read()
  |> case do
    nil ->
      Cv.imread(data_files.test_img)

    %{data: pixels, height: rows, width: cols} ->
      # 8-bit, 3-channel image built from the raw bytes.
      # NOTE(review): the constant is named BGR2RGB, but the swap is symmetric;
      # presumably it turns an RGB upload into the BGR order `Cv.imread`
      # produces, so both branches agree - confirm.
      pixels
      |> Cv.Mat.from_binary({:u, 8}, rows, cols, 3)
      |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  end

:ok

Preprocess image

  • Preprocess the input image to feed to the TFLite model
# The model takes image data as a ByteBuffer of HEIGHT x WIDTH x 3 with
# HEIGHT = WIDTH = 640 and values in [0, 255], so scale the picture to 640x640.
# See https://tfhub.dev/tensorflow/lite-model/efficientdet/lite4/detection/default
test_image_resized_mat = Cv.resize(test_image_mat, {640, 640})

# Build the model input: a u8 tensor with a leading batch axis.
#
# `test_image_mat` follows OpenCV's BGR channel convention (that is what
# `Cv.imread` yields, and an uploaded image is channel-swapped to match),
# whereas the model's image input is RGB - so swap the channels first.
# Only the tensor is converted; the mats used for display stay in BGR.
input_image_tensor =
  test_image_resized_mat
  |> Cv.cvtColor(Cv.Constant.cv_COLOR_BGR2RGB())
  |> Cv.Mat.to_nx(Nx.BinaryBackend)
  |> Nx.new_axis(0)
  |> Nx.as_type({:u, 8})

# Side-by-side preview: each picture sits in a boxed grid above its bold caption.
Kino.Layout.grid(
  for(
    {caption, image} <- [
      {"Input image", test_image_mat},
      {"Preprocessed", test_image_resized_mat}
    ],
    do: Kino.Layout.grid([image, Kino.Markdown.new("**#{caption}**")], boxed: true)
  ),
  columns: 2
)

Detect objects using TensorFlow Lite

# Copies the tensor's raw bytes into the interpreter's input at index 0.
set_input_tensor = fn interpreter, tensor ->
  tensor_bytes = Nx.to_binary(tensor)
  TFLite.Interpreter.input_tensor(interpreter, 0, tensor_bytes)
end

# Reads the interpreter's output number `index` and returns it as an Nx tensor.
# The tensor's dims must start with 1; that leading dimension is dropped, so
# the result has the remaining shape.
get_output_tensor_at_index = fn interpreter, index ->
  {:ok, raw_output} = TFLite.Interpreter.output_tensor(interpreter, index)
  {:ok, output_indices} = TFLite.Interpreter.outputs(interpreter)

  # Look up the tensor behind this output to learn its element type and dims.
  tensor_info =
    TFLite.Interpreter.tensor(interpreter, Enum.at(output_indices, index))

  [1 | dims_without_batch] = TFLite.TFLiteTensor.dims(tensor_info)
  shape = List.to_tuple(dims_without_batch)

  Nx.reshape(Nx.from_binary(raw_output, tensor_info.type), shape)
end

# Runs one inference pass over `input_image_tensor` and returns the detections
# whose score is at least `score_threshold`.
#
# Each detection is a map `%{box: tuple, score: number, class_id: integer}`.
# Detections are returned in the same order the model emitted them.
detect_objects = fn interpreter, input_image_tensor, score_threshold ->
  # Run inference
  set_input_tensor.(interpreter, input_image_tensor)
  TFLite.Interpreter.invoke(interpreter)

  # Extract the outputs: index 0 = boxes, 1 = class ids, 2 = scores,
  # 3 = number of detections (read but not needed for the filtering below).
  boxes = get_output_tensor_at_index.(interpreter, 0) |> Nx.to_list()
  class_ids = get_output_tensor_at_index.(interpreter, 1) |> Nx.to_list()
  scores = get_output_tensor_at_index.(interpreter, 2) |> Nx.to_list()
  _num_detections = get_output_tensor_at_index.(interpreter, 3) |> Nx.to_number()

  # Walk the three lists in lockstep, keeping only confident detections.
  # A comprehension preserves the input order (no prepend-and-forget-to-reverse).
  for {box, score, class_id} <- Enum.zip([boxes, scores, class_ids]),
      score >= score_threshold do
    %{box: List.to_tuple(box), score: score, class_id: trunc(class_id)}
  end
end

# Build an interpreter from the in-memory model bytes, then run detection on
# the prepared input, keeping results that score 0.5 or higher.
{:ok, interpreter} = TFLite.Interpreter.new_from_buffer(model_buffer)
prediction_results = detect_objects.(interpreter, input_image_tensor, 0.5)
# Last expression of the cell: show the detections as a table.
Kino.DataTable.new(prediction_results)

Visualize predictions

  • Draw the detection results on the original image
# Dimensions of the original (un-resized) image; the third element is ignored.
{img_height, img_width, _} = Cv.Mat.shape(test_image_mat)

# Scales a relative `{y_min, x_min, y_max, x_max}` box to whole-pixel
# coordinates of the original image.
# Note the result order: `{x_min, y_max, x_max, y_min}`.
calc_prediction_box = fn {rel_y_min, rel_x_min, rel_y_max, rel_x_max} ->
  to_pixels = fn fraction, extent -> round(fraction * extent) end

  {
    to_pixels.(rel_x_min, img_width),
    to_pixels.(rel_y_max, img_height),
    to_pixels.(rel_x_max, img_width),
    to_pixels.(rel_y_min, img_height)
  }
end

# Draws one detection onto `acc_mat` and returns the updated mat: a green
# rectangle around the object plus a "<class name>: <score>%" label near the
# box's `y_min` edge.
#
# Takes the detection first and the mat second so it can be used directly as
# a reducer over the list of detections.
draw_result = fn %{class_id: class_id, score: score, box: box}, acc_mat ->
  {x_min, y_max, x_max, y_min} = calc_prediction_box.(box)
  class_name = Enum.at(class_names, class_id)
  score_percent = round(score * 100)

  box_start_point = {x_min, y_max}
  box_end_point = {x_max, y_min}
  box_color = {0, 255, 0}

  label_text = "#{class_name}: #{score_percent}%"

  # The label is anchored 10px before the box's y_min edge. For a box that
  # starts at (or beyond) the image border that anchor would fall outside the
  # image and the text would be clipped, so keep the anchor inside the image.
  min_label_y = 20
  label_start_point = {max(x_min + 6, 0), max(y_min - 10, min_label_y)}
  label_font_scale = 0.7
  label_color = {0, 255, 0}

  acc_mat
  |> Cv.rectangle(
    box_start_point,
    box_end_point,
    box_color,
    thickness: 2
  )
  |> Cv.putText(
    label_text,
    label_start_point,
    Cv.Constant.cv_FONT_HERSHEY_SIMPLEX(),
    label_font_scale,
    label_color,
    thickness: 2
  )
end

# Fold every detection onto the original image; the final annotated mat is the
# value of the cell.
Enum.reduce(prediction_results, test_image_mat, fn prediction, canvas ->
  draw_result.(prediction, canvas)
end)