Notesclub

Image Classification

image_classification_exla_cuda.livemd

Image Classification

# Install the notebook's dependencies.
#
# Each dependency is pinned to the minor release line this notebook was written
# against. For pre-1.0 packages a bare `~> 0.x` requirement accepts every later
# 0.y release, and later releases may change APIs this notebook relies on
# (for example the map returned when reading a `Kino.Input.image/2` input, or the
# set of accepted `XLA_TARGET` values).
Mix.install(
  [
    {:bumblebee, "~> 0.1.0"},
    {:nx, "~> 0.4.0"},
    {:exla, "~> 0.4.0"},
    {:kino, "~> 0.8.0"}
  ],
  system_env: [
    # Selects the precompiled XLA build for CUDA 11.4.
    # NOTE(review): presumably this must match the CUDA version installed on the host — confirm.
    {"XLA_TARGET", "cuda114"}
  ],
  config: [
    # Make EXLA.Backend the default Nx backend for tensors created in this notebook.
    nx: [
      default_backend: EXLA.Backend
    ]
  ]
)

設定

# Display the active default Nx backend (expected to be EXLA.Backend, per the Mix.install config).
Nx.default_backend()
# Directory passed as `cache_dir:` to the Bumblebee loaders below, so downloaded model files are stored here.
cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

# The model and its featurizer come from the same Hugging Face repository,
# with downloaded files cached under `cache_dir`.
resnet_repo = {:hf, "microsoft/resnet-50", cache_dir: cache_dir}
{:ok, resnet} = Bumblebee.load_model(resnet_repo)
{:ok, featurizer} = Bumblebee.load_featurizer(resnet_repo)

画像分類の実行

画像の準備

# Image upload widget; `size:` asks Kino to resize the picked image to 224x224.
image_input = Kino.Input.image("IMAGE", size: {224, 224})

# Turn the uploaded image into a {height, width, 3} u8 tensor.
# Assumes the input's default `:rgb` format, i.e. 3 raw bytes per pixel — TODO confirm
# if a different `format:` is ever passed above.
# `Kino.Input.read/1` yields nil until an image has been uploaded; fail with an
# explicit message instead of an obscure error on `nil.data`.
image =
  case Kino.Input.read(image_input) do
    nil ->
      raise "no image uploaded yet: pick an image in the IMAGE input and re-evaluate this cell"

    %{data: data, height: height, width: width} ->
      data
      |> Nx.from_binary(:u8)
      |> Nx.reshape({height, width, 3})
  end

# Preview the tensor as an image.
Kino.Image.new(image)

手動推論

# Manual inference without Nx.Serving: preprocess the image with the featurizer,
# run the Axon model, and show the five highest-scoring labels as a table.
inputs = Bumblebee.apply_featurizer(featurizer, image)
outputs = Axon.predict(resnet.model, resnet.params, inputs)
# The trailing dbg/0 prints the value produced by each step of this pipeline.
outputs.logits
# Drop size-1 axes; presumably the logits are {1, num_labels} for this single image — TODO confirm.
|> Nx.squeeze()
# Convert logits into probabilities.
|> Axon.Activations.softmax()
# Keep the five best entries as {scores, class_ids}.
# NOTE(review): Bumblebee.Utils.Nx looks like an internal helper module — confirm it exists in the installed Bumblebee version.
|> Bumblebee.Utils.Nx.top_k(k: 5)
|> then(fn {scores, class_ids} ->
  # Pair each score with its class id and look up the label name in the model spec.
  scores
  |> Nx.to_flat_list()
  |> Enum.zip(Nx.to_flat_list(class_ids))
  |> Enum.map(fn {score, class_id} ->
    [
      label: resnet.spec.id_to_label[class_id],
      score: score
    ]
  end)
end)
|> Kino.DataTable.new()
|> dbg()

Nx.Serving による提供

# Same classification, packaged as an Nx.Serving built from the model and featurizer;
# its `predictions` are rendered as a table.
serving = Bumblebee.Vision.image_classification(resnet, featurizer)
Kino.DataTable.new(Nx.Serving.run(serving, image).predictions)

時間計測

# Average latency of `Nx.Serving.run/2` on `image`, in microseconds
# (`:timer.tc/3` reports elapsed time in microseconds).
# The divisor is the number of measurements actually taken, so the run count
# only has to be changed in the range.
1..10
|> Enum.map(fn _ ->
  {time, _} = :timer.tc(Nx.Serving, :run, [serving, image])
  time
end)
|> then(&(Enum.sum(&1) / length(&1)))

他のモデル

# Build an image-classification Nx.Serving for the given Hugging Face repository id.
# The model and featurizer are loaded from the same repository (cached under `cache_dir`);
# a failed load raises a MatchError.
serve_model = fn repository_id ->
  source = {:hf, repository_id, cache_dir: cache_dir}

  {:ok, loaded_model} = Bumblebee.load_model(source)
  {:ok, loaded_featurizer} = Bumblebee.load_featurizer(source)

  Bumblebee.Vision.image_classification(loaded_model, loaded_featurizer)
end
# Classify the same image with ConvNeXt-Tiny and show its predictions.
Kino.DataTable.new(
  Nx.Serving.run(serve_model.("facebook/convnext-tiny-224"), image).predictions
)
# Classify the same image with ViT-Base (patch 16) and show its predictions.
Kino.DataTable.new(
  Nx.Serving.run(serve_model.("google/vit-base-patch16-224"), image).predictions
)
# Classify the same image with distilled DeiT-Base and show its predictions.
Kino.DataTable.new(
  Nx.Serving.run(serve_model.("facebook/deit-base-distilled-patch16-224"), image).predictions
)