# Text generation
# Install dependencies for this notebook.
# XLA_TARGET=cuda114 compiles XLA with CUDA 11.4 support so EXLA runs on GPU,
# and Nx is configured to use EXLA.Backend as its default backend.
Mix.install(
  [
    {:bumblebee, "~> 0.1"},
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"},
    {:kino, "~> 0.8"}
  ],
  system_env: [{"XLA_TARGET", "cuda114"}],
  config: [nx: [default_backend: EXLA.Backend]]
)
## 設定
# Local directory where downloaded model/tokenizer files are cached,
# so repeated evaluations skip the Hugging Face download.
cache_dir = "/tmp/bumblebee_cache"
## モデルのダウンロード
# Download the GPT-2 model parameters and its matching tokenizer from the
# Hugging Face hub ("gpt2" repository), caching both under cache_dir.
{:ok, gpt2} = Bumblebee.load_model({:hf, "gpt2", cache_dir: cache_dir})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2", cache_dir: cache_dir})
## サービスの提供
# Build a text-generation serving from the model and tokenizer;
# each request generates at most 10 new tokens.
serving = Bumblebee.Text.generation(gpt2, tokenizer, max_new_tokens: 10)
## 補完する文章の準備
# Render a text input widget for the prompt, read its current value,
# then run the generation serving on it and show the generated results.
text_input = Kino.Input.text("TEXT", default: "Robots have gained human rights and")
text = Kino.Input.read(text_input)

Map.get(Nx.Serving.run(serving, text), :results)
## 時間計測
# Measure the average wall-clock latency (in microseconds) of one
# generation request, over `sample_count` runs.
#
# The run count was previously hard-coded twice (`1..10` and `/ 10`),
# which risked the two drifting apart; bind it once instead.
sample_count = 10

total_time =
  Enum.reduce(1..sample_count, 0, fn _i, acc ->
    # :timer.tc/3 returns {elapsed_microseconds, result}.
    {elapsed, _result} = :timer.tc(Nx.Serving, :run, [serving, text])
    acc + elapsed
  end)

total_time / sample_count
## 他のモデル
# Build a text-generation serving (max 10 new tokens) for an arbitrary
# Hugging Face repository, reusing the shared cache directory.
serve_model = fn repo ->
  {:ok, model} = Bumblebee.load_model({:hf, repo, cache_dir: cache_dir})
  {:ok, repo_tokenizer} = Bumblebee.load_tokenizer({:hf, repo, cache_dir: cache_dir})
  Bumblebee.Text.generation(model, repo_tokenizer, max_new_tokens: 10)
end
# Run the same prompt through the larger GPT-2 variants for comparison.
Map.get(Nx.Serving.run(serve_model.("gpt2-medium"), text), :results)

Map.get(Nx.Serving.run(serve_model.("gpt2-large"), text), :results)