Powered by AppSignal & Oban Pro

LLM を Elixir / Livebook で学ぶ 2

02_decoder_generation_with_bumblebee.livemd

LLM を Elixir / Livebook で学ぶ 2

Mix.install(
  [
    {:bumblebee, "~> 0.5"},
    {:nx, "~> 0.9", override: true},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.15"},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

事前学習済み GPT 系モデルで「次トークン予測」を体感する

前のノートブックでは、トークン化・埋め込み・Attention を図で見ました。

このノートブックでは、Bumblebee を使って GPT 系モデルを実際に動かしながら、

  • 入力がどうトークンに分かれるか
  • 生成設定を変えると何が変わるか
  • 生成トークン数が増えると時間がどう変わるか

を確認します。

方針

ここでは教材として扱いやすい gpt2 を使います。

  • 仕組みを観察することが目的
  • 長時間学習はしない
  • 例文や解説はこの notebook 用の独自内容

準備

defmodule LLMScratch.Visuals do
  def bar_chart(rows, title, x_field, y_field, opts \\ []) do
    width = Keyword.get(opts, :width, 560)
    height = Keyword.get(opts, :height, 280)

    VegaLite.new(width: width, height: height, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:bar, tooltip: true)
    |> VegaLite.encode_field(:x, Atom.to_string(x_field), type: :nominal, title: Atom.to_string(x_field))
    |> VegaLite.encode_field(:y, Atom.to_string(y_field), type: :quantitative, title: Atom.to_string(y_field))
    |> Kino.VegaLite.new()
  end

  def line_chart(rows, title, x_field, y_field) do
    VegaLite.new(width: 560, height: 280, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:line, point: true, tooltip: true)
    |> VegaLite.encode_field(:x, Atom.to_string(x_field), type: :quantitative, title: Atom.to_string(x_field))
    |> VegaLite.encode_field(:y, Atom.to_string(y_field), type: :quantitative, title: Atom.to_string(y_field))
    |> Kino.VegaLite.new()
  end
end

GPT 系モデルの生成ループ

Kino.Mermaid.new("""
flowchart LR
  A["プロンプト"] --> B["トークン化"]
  B --> C["モデルへ入力"]
  C --> D["次トークンの確率"]
  D --> E["1 トークン選ぶ"]
  E --> F["末尾に追加"]
  F --> C
""")

ポイントは、モデルが一度に文章全体を完成させているわけではなく、1 トークンずつ続きとして足している ことです。

モデルのダウンロード

cache_dir = "/tmp/bumblebee_cache"
repo = {:hf, "gpt2", cache_dir: cache_dir}
{:ok, gpt2} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

1. プロンプトをトークンとして見る

gpt2 は英語向けのモデルなので、まずは短い英語プロンプトで観察します。

prompt_input =
  Kino.Input.textarea("PROMPT",
    default: "The sky above the city looked calm, but"
  )
prompt = Kino.Input.read(prompt_input)
prompt
tokenized = Bumblebee.apply_tokenizer(tokenizer, prompt)
token_ids = tokenized["input_ids"][[0]] |> Nx.to_flat_list()

token_rows =
  token_ids
  |> Enum.with_index()
  |> Enum.map(fn {token_id, position} ->
    %{
      position: position,
      token_id: token_id,
      piece: Bumblebee.Tokenizer.decode(tokenizer, [token_id])
    }
  end)

Kino.DataTable.new(
  token_rows,
  keys: [:position, :token_id, :piece]
)

同じ文章でも、私たちが思う「単語」とモデルの「トークン」は一致しないことがあります。これがサブワード分割の感覚です。

2. プロンプトごとの長さを比べる

トークン数は、そのまま計算量やコンテキスト長の感覚につながります。

prompt_examples = [
  "The sky above the city looked calm, but",
  "Elixir is pleasant for interactive experiments because",
  "A robot entered the library and asked",
  "Write a short poem about rain and memory."
]

prompt_length_rows =
  prompt_examples
  |> Enum.map(fn text ->
    input_ids = Bumblebee.apply_tokenizer(tokenizer, text)["input_ids"][[0]] |> Nx.to_flat_list()

    %{
      prompt: text,
      token_count: length(input_ids)
    }
  end)

Kino.DataTable.new(prompt_length_rows)
LLMScratch.Visuals.bar_chart(prompt_length_rows, "プロンプトごとのトークン数", :prompt, :token_count)

文章の長さが同じでも、トークン数は同じとは限りません。モデルの見え方は「文字数」ではなく「トークン数」に近いです。

3. 生成設定を変えてみる

ここでは 2 つの設定を比べます。

  • greedy: 毎回もっとも確率の高い候補を選びやすい
  • creative: 少しランダム性を入れて広がりを持たせる
greedy_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 40,
    temperature: 0.7
  )

creative_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 40,
    temperature: 1.1,
    strategy: %{
      type: :contrastive_search,
      top_k: 40,
      alpha: 0.6
    }
  )

greedy_serving = Bumblebee.Text.generation(gpt2, tokenizer, greedy_config)
creative_serving = Bumblebee.Text.generation(gpt2, tokenizer, creative_config)
greedy_result =
  greedy_serving
  |> Nx.Serving.run(prompt)
  |> Map.get(:results)

creative_result =
  creative_serving
  |> Nx.Serving.run(prompt)
  |> Map.get(:results)

Kino.Layout.tabs([
  greedy: Kino.DataTable.new(greedy_result),
  creative: Kino.DataTable.new(creative_result)
])

初学者向けの観察ポイントは次の通りです。

  • temperature を下げると、出力が安定しやすい
  • temperature を上げると、予測の裾を拾いやすい
  • top_k を入れると、極端に低い候補を切り落としやすい

4. 同じ入力でも生成の感じが変わる

プロンプトの書き方も出力を大きく変えます。

prompt_variants = [
  "The sky above the city looked calm, but",
  "Continue this as a mystery story: The sky above the city looked calm, but",
  "Continue this as a cheerful children's story: The sky above the city looked calm, but"
]

variant_results =
  prompt_variants
  |> Enum.map(fn text ->
    output =
      creative_serving
      |> Nx.Serving.run(text)
      |> Map.get(:results)

    %{
      prompt: text,
      result: inspect(output)
    }
  end)

Kino.DataTable.new(variant_results)

ここで見たいのは、「モデルの知識」だけでなく「指示文の置き方」でも振る舞いが変わる、という点です。

5. 生成トークン数と時間の関係

トークンを 1 つずつ足すので、max_new_tokens を増やすと時間も伸びやすくなります。

benchmark_prompt = "Elixir notebooks are useful because"

timing_rows =
  [10, 20, 40, 60]
  |> Enum.map(fn max_new_tokens ->
    config =
      Bumblebee.configure(generation_config,
        max_new_tokens: max_new_tokens,
        temperature: 0.9
      )

    serving = Bumblebee.Text.generation(gpt2, tokenizer, config)
    {microseconds, _result} = :timer.tc(Nx.Serving, :run, [serving, benchmark_prompt])

    %{
      max_new_tokens: max_new_tokens,
      seconds: microseconds / 1_000_000
    }
  end)

Kino.DataTable.new(timing_rows)
LLMScratch.Visuals.line_chart(timing_rows, "生成トークン数と実行時間", :max_new_tokens, :seconds)

環境によって秒数はかなり変わりますが、傾向としては 出力を長くするほど待ち時間も増える と考えておくと理解しやすいです。

6. 理論とのつながり

前回の内容とつなげると、生成の内部ではざっくり次のことが起きています。

  • 入力文がトークン列になる
  • 各トークンが埋め込みベクトルになる
  • causal self-attention で左側の文脈だけを見る
  • 最後の位置から「次トークン候補」の確率を出す
  • 1 つ選んで末尾に足し、また同じ処理をする

7. まとめ

この notebook の要点

  • GPT 系モデルは、次トークン予測を繰り返して文章を伸ばす
  • トークン数は、計算量とコンテキスト長の感覚に直結する
  • temperaturetop_k で生成の性格をある程度調整できる
  • プロンプトの書き方も出力品質に大きく影響する

次のノートブックでは、モデルを「賢く、役に立つ応答に寄せる」ためのアラインメントの考え方を、軽量な例で学びます。