Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Fill-mask

livebooks/bumblebee/fill_mask.livemd

Fill-mask

# Notebook setup: installs Bumblebee, Nx, EXLA and Kino into this runtime.
# `override: true` makes Mix use this Nx requirement even when another
# dependency in the list declares a different Nx requirement.
Mix.install(
  [
    {:bumblebee, "~> 0.5"},
    {:nx, "~> 0.9", override: true},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.14"}
  ],
  # Application config applied at install time: EXLA.Backend becomes Nx's
  # default tensor backend for the whole notebook.
  config: [nx: [default_backend: EXLA.Backend]]
)

設定

# Local directory where downloaded Hugging Face files are cached between runs.
cache_dir = Path.join("/tmp", "bumblebee_cache")

モデルのダウンロード

# Load the BERT checkpoint and its tokenizer from the Hugging Face Hub,
# caching the files under cache_dir. Both loads use the same repository spec;
# a failed load raises a MatchError on the {:ok, _} pattern.
bert_repo = {:hf, "bert-base-uncased", cache_dir: cache_dir}

{:ok, bert} = Bumblebee.load_model(bert_repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(bert_repo)

サービング（Nx.Serving）の作成

# Wrap the loaded BERT model and tokenizer into a fill-mask Nx.Serving.
serving = bert |> Bumblebee.Text.fill_mask(tokenizer)

マスクされた文章の準備

# Text field pre-filled with an example sentence containing a [MASK] placeholder.
text_input =
  Kino.Input.text("マスクされた文章", default: "The most important thing in life is [MASK].")

# Current value of the text field, used as the inference input.
text = text_input |> Kino.Input.read()

推論の実行

# Run the BERT fill-mask serving on the input text and render its
# predictions as an interactive Kino table.
%{predictions: predictions} = Nx.Serving.run(serving, text)
Kino.DataTable.new(predictions)

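他のモデル

同じ手順で ALBERT（albert-base-v2）と RoBERTa（roberta-base）のサービングを作成し、同じ文章で推論します。なお、RoBERTa のマスクトークンは `[MASK]` ではなく `<mask>` です。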

# Builds a fill-mask Nx.Serving for the given Hugging Face repository id.
# Files are cached under cache_dir. Raises a MatchError if either the model
# or the tokenizer fails to load.
serve_model = fn repository_id ->
  repository = {:hf, repository_id, cache_dir: cache_dir}

  {:ok, model} = Bumblebee.load_model(repository)
  {:ok, tokenizer} = Bumblebee.load_tokenizer(repository)

  Bumblebee.Text.fill_mask(model, tokenizer)
end
"albert-base-v2"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&Kino.DataTable.new(&1.predictions))
"roberta-base"
|> serve_model.()
|> Nx.Serving.run(text)
|> then(&Kino.DataTable.new(&1.predictions))