Text classification

text_classification_exla_cuda.livemd

Mix.install(
  [
    {:bumblebee, "~> 0.5"},
    {:nx, "~> 0.9", override: true},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.14"}
  ],
  system_env: [
    {"XLA_TARGET", "cuda12"},
    {"EXLA_TARGET", "cuda"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Configuration

cache_dir = "/tmp/bumblebee_cache"

Downloading the model

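# Fine-tuned sentiment classification model from the Hugging Face Hub
{:ok, bertweet} =
  Bumblebee.load_model({
    :hf,
    "finiteautomata/bertweet-base-sentiment-analysis",
    cache_dir: cache_dir
  })

# The tokenizer is loaded from the base BERTweet repository
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({
    :hf,
    "vinai/bertweet-base",
    cache_dir: cache_dir
  })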

Creating the serving

serving = Bumblebee.Text.text_classification(bertweet, tokenizer)

Preparing the text to classify

text_input = Kino.Input.text("TEXT", default: "I was young then.")
text = Kino.Input.read(text_input)

Running inference

serving
|> Nx.Serving.run(text)
|> then(&Kino.DataTable.new(&1.predictions))
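The cell above shows every label in a table. If only the most likely label is needed, it can be picked out of the `predictions` list. The following is a minimal sketch: the `result` map is a hand-written stand-in for what `Nx.Serving.run/2` returns, and its labels and scores are made-up placeholder values, not real model output.

```elixir
# Illustrative stand-in for the map returned by Nx.Serving.run/2.
# The labels and scores are placeholder values, not real model output.
result = %{
  predictions: [
    %{label: "POS", score: 0.2},
    %{label: "NEU", score: 0.7},
    %{label: "NEG", score: 0.1}
  ]
}

# Pick the prediction with the highest score.
top = Enum.max_by(result.predictions, & &1.score)

IO.puts("#{top.label} (score: #{top.score})")
```

In the notebook itself, `result` would be replaced by `Nx.Serving.run(serving, text)`.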

Measuring execution time

# :timer.tc/3 returns {elapsed_microseconds, result}
1..10
|> Enum.map(fn _ ->
  {time, _} = :timer.tc(Nx.Serving, :run, [serving, text])
  time
end)
# Average over the 10 runs, in microseconds
|> then(&(Enum.sum(&1) / 10))
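The first call to a serving can be slower than later ones, for example when it triggers compilation, so it may be useful to exclude a warm-up run from the average and report milliseconds instead of microseconds. The following is a rough sketch under those assumptions. `Bench` and `average_ms` are hypothetical names that do not appear in the original notebook, and the measured function here is a trivial stand-in so the snippet runs without a model.

```elixir
defmodule Bench do
  # Hypothetical helper, not part of the original notebook.
  # Runs `fun` `warmup` times without timing it, then returns the
  # average duration of `runs` timed calls in milliseconds.
  def average_ms(fun, runs \\ 10, warmup \\ 1)
      when is_function(fun, 0) and is_integer(runs) and runs > 0 and
             is_integer(warmup) and warmup >= 0 do
    # The explicit step keeps the range empty when warmup is 0.
    Enum.each(1..warmup//1, fn _ -> fun.() end)

    total_us =
      1..runs
      |> Enum.map(fn _ ->
        # :timer.tc/1 returns {elapsed_microseconds, result}
        {time_us, _result} = :timer.tc(fun)
        time_us
      end)
      |> Enum.sum()

    total_us / runs / 1000
  end
end

# Stand-in workload so the example runs on its own.
avg_ms = Bench.average_ms(fn -> Enum.sum(1..1_000) end)
IO.puts("average: #{avg_ms} ms")
```

In the notebook, the stand-in would be replaced with `fn -> Nx.Serving.run(serving, text) end`.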