Image Classification
# Notebook dependencies: Bumblebee for the pretrained models, Nx/EXLA for tensor
# computation, and Kino for the interactive inputs and tables used below.
Mix.install(
[
{:bumblebee, "~> 0.5"},
# `override: true` makes this Nx requirement take precedence over the Nx
# requirements declared by the other dependencies.
{:nx, "~> 0.9", override: true},
{:exla, "~> 0.9"},
{:kino, "~> 0.14"}
],
# Environment variables applied while the dependencies are installed/compiled.
# NOTE(review): XLA_TARGET picks the XLA build (CUDA 12 here); confirm that
# EXLA_TARGET is still read by exla ~> 0.9 — it may be a leftover.
system_env: [
{"XLA_TARGET", "cuda12"},
{"EXLA_TARGET", "cuda"}
],
# Application config: make EXLA the default Nx backend for the whole notebook.
config: [nx: [default_backend: EXLA.Backend]]
)
設定
# Returns the backend Nx currently uses by default; handy for checking that the
# `default_backend` configured in the Mix.install call was picked up.
Nx.default_backend()
# Directory handed to Bumblebee as `cache_dir:` for every model/featurizer
# download in this notebook.
cache_dir = "/tmp/bumblebee_cache"
モデルのダウンロード
# The model and its featurizer are fetched from the same Hugging Face
# repository, using the notebook's cache directory for the downloaded files.
resnet_repo = {:hf, "microsoft/resnet-50", cache_dir: cache_dir}

# Both loaders return {:ok, value}; a failed load raises a MatchError here.
{:ok, resnet} = Bumblebee.load_model(resnet_repo)
{:ok, featurizer} = Bumblebee.load_featurizer(resnet_repo)
画像分類の実行
画像の準備
# Image upload widget labelled "IMAGE"; the input is asked for a 224x224 image.
image_input = Kino.Input.image("IMAGE", size: {224, 224})

# Convert the uploaded image into a {height, width, 3} tensor of u8 values.
# Assumes the input delivers raw 3-channel pixel data (no explicit `:format`
# is requested when the input is created) — TODO confirm if the format changes.
image =
  case Kino.Input.read(image_input) do
    nil ->
      # The input reads as nil until an image has been uploaded. Stop the
      # evaluation with a readable message instead of failing on a nil lookup.
      Kino.interrupt!(:error, "Upload an image to the IMAGE input, then re-evaluate")

    %{file_ref: file_ref, height: height, width: width} ->
      file_ref
      # Resolve the uploaded file reference to a path on disk and load its bytes.
      |> Kino.Input.file_path()
      |> File.read!()
      |> Nx.from_binary(:u8)
      |> Nx.reshape({height, width, 3})
  end

# Render the tensor back as an image to verify what was read.
Kino.Image.new(image)
手動推論
# Manual inference: run the featurizer and the model by hand, without a serving.
inputs = Bumblebee.apply_featurizer(featurizer, image)
outputs = Axon.predict(resnet.model, resnet.params, inputs)

outputs.logits
# Drop the size-1 axes so the logits form a flat vector of class scores.
|> Nx.squeeze()
|> Axon.Activations.softmax()
# Public Nx API; returns {top_values, top_indices}, the same shape of result
# the internal Bumblebee.Utils.Nx helper was used for.
|> Nx.top_k(k: 5)
|> then(fn {scores, class_ids} ->
  # Pair each class id with its score and translate the id to a readable label.
  Enum.zip_with(
    Nx.to_flat_list(scores),
    Nx.to_flat_list(class_ids),
    fn score, class_id ->
      [
        label: resnet.spec.id_to_label[class_id],
        score: score
      ]
    end
  )
end)
|> Kino.DataTable.new()
# Kept from the original: dbg/1 prints the intermediate value of every pipeline step.
|> dbg()
Nx.Serving による提供
# Bundle the model and featurizer into an image-classification serving.
serving = Bumblebee.Vision.image_classification(resnet, featurizer)

# Run the serving once on the uploaded image and tabulate its predictions.
serving
|> Nx.Serving.run(image)
|> Map.fetch!(:predictions)
|> Kino.DataTable.new()
時間計測
# Mean duration of ten consecutive serving runs, in microseconds
# (the unit reported by :timer.tc).
1..10
|> Enum.reduce(0, fn _run, total_us ->
  {elapsed_us, _result} = :timer.tc(fn -> Nx.Serving.run(serving, image) end)
  total_us + elapsed_us
end)
|> Kernel./(10)
他のモデル
# Builds an image-classification serving for the given Hugging Face repository id.
# The model and featurizer are both loaded from that repository, with downloads
# stored in the notebook's `cache_dir`. A failed load raises a MatchError.
serve_model = fn repository_id ->
  repository = {:hf, repository_id, cache_dir: cache_dir}

  {:ok, model} = Bumblebee.load_model(repository)
  {:ok, featurizer} = Bumblebee.load_featurizer(repository)

  Bumblebee.Vision.image_classification(model, featurizer)
end
# Classify the same image with three other checkpoints. Each expression stands
# alone so that every model produces its own predictions table.
"facebook/convnext-tiny-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> Map.fetch!(:predictions)
|> Kino.DataTable.new()

"google/vit-base-patch16-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> Map.fetch!(:predictions)
|> Kino.DataTable.new()

"facebook/deit-base-distilled-patch16-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> Map.fetch!(:predictions)
|> Kino.DataTable.new()