Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Text generation

text_generation.livemd

Text generation

# Install the dependencies for this notebook and make EXLA the default
# Nx backend so tensor work is compiled and executed via XLA.
Mix.install(
  [
    # Bumblebee: load and serve pre-trained models from Hugging Face.
    {:bumblebee, "~> 0.5"},
    # Nx: numerical computing core; override: true pins this version
    # even if another dependency requires a different one.
    {:nx, "~> 0.9", override: true},
    # EXLA: XLA-backed compiler/backend for Nx.
    {:exla, "~> 0.9"},
    # Kino: Livebook widgets (used for the text input below).
    {:kino, "~> 0.14"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

設定

cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

# Fetch the GPT-2 model, its tokenizer, and its generation settings from
# the Hugging Face hub, caching all downloads under cache_dir.
repo = {:hf, "gpt2", cache_dir: cache_dir}

{:ok, gpt2} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Cap each completion at 10 newly generated tokens.
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 10)

サービスの提供

serving = Bumblebee.Text.generation(gpt2, tokenizer, generation_config)

補完する文章の準備

# Interactive prompt input; the model will complete the default sentence.
text_input = Kino.Input.text("TEXT", default: "Robots have gained human rights and")
text = Kino.Input.read(text_input)

# Run the prompt through the serving and pull out the generated results.
output = Nx.Serving.run(serving, text)
Map.get(output, :results)

他のモデル

# Build a text-generation serving for an arbitrary Hugging Face repository id,
# reusing the shared cache_dir and the same 10-new-token cap as above.
serve_model = fn repository_id ->
  repo = {:hf, repository_id, cache_dir: cache_dir}

  {:ok, model} = Bumblebee.load_model(repo)
  {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

  generation_config =
    repo
    |> Bumblebee.load_generation_config()
    |> then(fn {:ok, config} -> Bumblebee.configure(config, max_new_tokens: 10) end)

  Bumblebee.Text.generation(model, tokenizer, generation_config)
end
# Repeat the same completion with larger GPT-2 variants for comparison.
medium_serving = serve_model.("gpt2-medium")
medium_output = Nx.Serving.run(medium_serving, text)
Map.get(medium_output, :results)

large_serving = serve_model.("gpt2-large")
large_output = Nx.Serving.run(large_serving, text)
Map.get(large_output, :results)