Powered by AppSignal & Oban Pro

LLM を Elixir / Livebook で学ぶ 4

04_tiny_bigram_language_model_training.livemd

LLM を Elixir / Livebook で学ぶ 4

Mix.install(
  [
    {:nx, "~> 0.9"},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.15"},
    {:kino_vega_lite, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

発展編: 極小の言語モデルを自分で学習する

ここでは、LLM の考え方をさらに一歩進めて、next-token prediction を実際に学習してみます。

ただし、いきなり Transformer 全体を作るのではなく、まずは 1 個前のトークンだけを見る bigram 言語モデル を使います。

この notebook で分かること:

  • 学習データはどんなペアに分解されるか
  • 損失が下がると、次トークン予測がどう変わるか
  • 小さなモデルでも「学習前」と「学習後」で生成が変わること
  • それでも bigram には大きな限界があること

この notebook の位置づけ

これは Transformer の代わり ではありません。むしろ逆で、

  • 「学習とは何か」を手で追いやすくする
  • Attention を入れる前の最小の next-token 学習を理解する

ための導入用モデルです。

準備

defmodule LLMScratch.Visuals do
  def line_chart(rows, title, x_field, y_field, opts \\ []) do
    width = Keyword.get(opts, :width, 620)
    height = Keyword.get(opts, :height, 280)

    VegaLite.new(width: width, height: height, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:line, point: true, tooltip: true)
    |> VegaLite.encode_field(
      :x, Atom.to_string(x_field),
      type: :quantitative,
      title: Atom.to_string(x_field)
    )
    |> VegaLite.encode_field(
      :y, Atom.to_string(y_field),
      type: :quantitative,
      title: Atom.to_string(y_field))
    |> Kino.VegaLite.new()
  end

  def heatmap(rows, title, opts \\ []) do
    width = Keyword.get(opts, :width, 620)
    height = Keyword.get(opts, :height, 360)

    VegaLite.new(width: width, height: height, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:rect, tooltip: true)
    |> VegaLite.encode_field(
      :x, "target", type: :nominal, title: "次トークン"
    )
    |> VegaLite.encode_field(
      :y, "source", type: :nominal, title: "現在トークン"
    )
    |> VegaLite.encode_field(
      :color, "score", type: :quantitative, scale: [scheme: "tealblues"]
    )
    |> Kino.VegaLite.new()
  end
end

defmodule TinyBigram do
  import Nx.Defn

  defn row_softmax(logits) do
    shifted = Nx.subtract(logits, Nx.reduce_max(logits, axes: [1], keep_axes: true))
    exps = Nx.exp(shifted)
    Nx.divide(exps, Nx.sum(exps, axes: [1], keep_axes: true))
  end

  defn loss(weights, input_ids, target_one_hot) do
    logits = Nx.take(weights, input_ids)
    probs = row_softmax(logits)

    target_one_hot
    |> Nx.multiply(Nx.log(Nx.add(probs, 1.0e-9)))
    |> Nx.sum(axes: [1])
    |> Nx.negate()
    |> Nx.mean()
  end

  defn update(weights, input_ids, target_one_hot, learning_rate) do
    grads = grad(weights, &loss(&1, input_ids, target_one_hot))
    weights - grads * learning_rate
  end

  defn probabilities(weights) do
    row_softmax(weights)
  end
end

1. 小さな学習コーパスを作る

今回は説明しやすいように、空白区切りの短い日本語文を小さく用意します。

corpus_sentences = [
  "ねこ は ひるね が すき",
  "ねこ は さかな が すき",
  "ねこ は まどべ で ひるね",
  "いぬ は さんぽ が すき",
  "いぬ は にわ を かける",
  "とり は そら を とぶ"
]

Kino.DataTable.new(Enum.map(corpus_sentences, &%{sentence: &1}))

2. 語彙とトークン ID

文のはじめと終わりが分かるように、<BOS><EOS> も入れます。

special_tokens = ["<BOS>", "<EOS>"]

vocab =
  corpus_sentences
  |> Enum.flat_map(&String.split(&1, " "))
  |> Kernel.++(special_tokens)
  |> Enum.uniq()
  |> Enum.sort()

token_to_id =
  vocab
  |> Enum.with_index()
  |> Map.new()

id_to_token =
  token_to_id
  |> Enum.map(fn {token, id} -> {id, token} end)
  |> Map.new()

%{
  vocab_size: length(vocab),
  vocab: vocab
}
token_rows =
  vocab
  |> Enum.with_index()
  |> Enum.map(fn {token, id} ->
    %{
      token: token,
      token_id: id
    }
  end)

Kino.DataTable.new(token_rows)

3. next-token 学習用のペアへ分解する

bigram モデルでは、現在のトークン -> 次のトークン のペアを学習します。

token_sequences =
  corpus_sentences
  |> Enum.map(fn sentence ->
    ["<BOS>"] ++ String.split(sentence, " ") ++ ["<EOS>"]
  end)

Kino.DataTable.new(Enum.map(token_sequences, &%{sequence: Enum.join(&1, " ")}))
training_pairs =
  token_sequences
  |> Enum.flat_map(fn sequence ->
    sequence
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.map(fn [source, target] ->
      %{
        source: source,
        target: target,
        source_id: token_to_id[source],
        target_id: token_to_id[target]
      }
    end)
  end)

Kino.DataTable.new(training_pairs)

4. 観察用に、出現回数の bigram 行列も作る

学習済みモデルの重みが、どのくらいこの分布に近づくかを見るための比較材料です。

pair_counts =
  training_pairs
  |> Enum.frequencies_by(fn pair -> {pair.source, pair.target} end)

count_rows =
  for source <- vocab, target <- vocab do
    %{
      source: source,
      target: target,
      score: Map.get(pair_counts, {source, target}, 0)
    }
  end

LLMScratch.Visuals.heatmap(count_rows, "学習コーパスの bigram 出現回数")

5. モデルに渡す tensor を作る

input_ids =
  training_pairs
  |> Enum.map(& &1.source_id)
  |> Nx.tensor(type: :s64)

target_ids =
  training_pairs
  |> Enum.map(& &1.target_id)
  |> Nx.tensor(type: :s64)

vocab_size = length(vocab)

target_one_hot =
  Nx.equal(Nx.new_axis(target_ids, 1), Nx.iota({vocab_size}))
  |> Nx.as_type(:f32)

%{
  input_shape: Nx.shape(input_ids),
  target_shape: Nx.shape(target_ids),
  target_one_hot_shape: Nx.shape(target_one_hot),
  vocab_size: vocab_size
}

6. 学習前の状態を見る

bigram モデルのパラメータは、[語彙数 x 語彙数] の行列です。

行は「今のトークン」、列は「次トークン候補」を表します。

:rand.seed(:exsss, {123, 456, 789})

initial_weights =
  1..(vocab_size * vocab_size)
  |> Enum.map(fn _ -> :rand.normal() * 0.01 end)
  |> Nx.tensor(type: {:f, 32})
  |> Nx.reshape({vocab_size, vocab_size})

Nx.shape(initial_weights)
initial_loss =
  TinyBigram.loss(initial_weights, input_ids, target_one_hot)
  |> Nx.to_number()

%{initial_loss: initial_loss}
summarize_predictions = fn probabilities ->
  vocab
  |> Enum.with_index()
  |> Enum.map(fn {source, source_id} ->
    top_predictions =
      probabilities[source_id]
      |> Nx.to_flat_list()
      |> Enum.zip(vocab)
      |> Enum.map(fn {prob, token} -> {token, prob} end)
      |> Enum.sort_by(fn {_token, prob} -> prob end, :desc)
      |> Enum.take(3)
      |> Enum.map(fn {token, prob} -> "#{token}(#{Float.round(prob, 3)})" end)
      |> Enum.join(" / ")

    %{
      source: source,
      top3: top_predictions
    }
  end)
end

initial_probabilities = TinyBigram.probabilities(initial_weights)
Kino.DataTable.new(summarize_predictions.(initial_probabilities))

学習前はほぼランダムなので、ねこ の次に何が来るかもまだ分かっていません。

7. 学習を回す

今回は小さな全データをまとめて使う、フルバッチ学習にします。

epochs = 400
learning_rate = 1.2e-1

{trained_weights, loss_rows} =
  Enum.reduce(0..epochs, {initial_weights, []}, fn epoch, {weights, acc} ->
    loss_value =
      TinyBigram.loss(weights, input_ids, target_one_hot)
      |> Nx.to_number()

    next_weights =
      TinyBigram.update(weights, input_ids, target_one_hot, learning_rate)

    row = %{epoch: epoch, loss: loss_value}

    {next_weights, [row | acc]}
  end)

loss_rows = Enum.reverse(loss_rows)

%{
  first_loss: hd(loss_rows).loss,
  last_loss: List.last(loss_rows).loss
}
LLMScratch.Visuals.line_chart(loss_rows, "学習中の損失", :epoch, :loss)

8. 学習後の予測を見る

trained_probabilities = TinyBigram.probabilities(trained_weights)
Kino.DataTable.new(summarize_predictions.(trained_probabilities))
trained_rows =
  for {source, i} <- Enum.with_index(vocab),
      {target, j} <- Enum.with_index(vocab) do
    %{
      source: source,
      target: target,
      score: Float.round(Nx.to_number(trained_probabilities[[i, j]]), 4)
    }
  end

LLMScratch.Visuals.heatmap(trained_rows, "学習後の next-token 確率")

<BOS> の行では、文頭に来やすい語へ確率が集まります。 の行では、その後ろに来た語の傾向が見えるはずです。

9. 学習前後で生成を比べる

bigram なので、毎回「今の 1 語」だけを見て次トークンを選びます。

greedy_generate = fn probabilities, max_steps ->
  next_token = fn source ->
    source_id = token_to_id[source]

    probabilities[source_id]
    |> Nx.to_flat_list()
    |> Enum.zip(vocab)
    |> Enum.map(fn {prob, token} -> {token, prob} end)
    |> Enum.reject(fn {token, _prob} -> token == "<BOS>" end)
    |> Enum.max_by(fn {_token, prob} -> prob end)
    |> elem(0)
  end

  1..max_steps
  |> Enum.reduce_while(["<BOS>"], fn _, acc ->
    source = List.last(acc)
    target = next_token.(source)

    if target == "<EOS>" do
      {:halt, acc ++ [target]}
    else
      {:cont, acc ++ [target]}
    end
  end)
end

format_generated = fn tokens ->
  tokens
  |> Enum.reject(&(&1 in ["<BOS>", "<EOS>"]))
  |> Enum.join(" ")
end

generation_rows = [
  %{
    stage: "before training",
    generated: format_generated.(greedy_generate.(initial_probabilities, 8))
  },
  %{
    stage: "after training",
    generated: format_generated.(greedy_generate.(trained_probabilities, 8))
  }
]

Kino.DataTable.new(
  generation_rows,
  keys: [:stage, :generated]
)

学習前は不自然な並びになりやすく、学習後はコーパスに寄った語順へ近づきます。

10. それでも bigram には限界がある

ここが大事です。

このモデルは「直前の 1 語」しか見ません。だから、

  • ねこ は のあとに何が来るか
  • もっと前に出てきた主語や話題
  • 長い依存関係

をほとんど扱えません。

LLM が Transformer を使うのは、こうした 長い文脈を読む必要 があるからです。

11. まとめ

この notebook の要点

  • next-token 学習は、現在トークン -> 次トークン の予測問題として書ける
  • 損失が下がると、次トークン確率の分布がデータに沿って整っていく
  • ごく小さいモデルでも、学習前後の生成差ははっきり見える
  • ただし bigram は 1 語しか見ないので、LLM の表現力には遠い

次の発展としては、

  • 2 語以上の文脈を見る trigram / 小さな MLP
  • Axon で埋め込み層つきのモデルを組む
  • 最終的に causal self-attention を入れる

という順で進むと理解しやすいです。