LLM を Elixir / Livebook で学ぶ 4
Mix.install(
[
{:nx, "~> 0.9"},
{:exla, "~> 0.9"},
{:kino, "~> 0.15"},
{:kino_vega_lite, "~> 0.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
発展編: 極小の言語モデルを自分で学習する
ここでは、LLM の考え方をさらに一歩進めて、next-token prediction を実際に学習してみます。
ただし、いきなり Transformer 全体を作るのではなく、まずは 1 個前のトークンだけを見る bigram 言語モデル を使います。
この notebook で分かること:
- 学習データはどんなペアに分解されるか
- 損失が下がると、次トークン予測がどう変わるか
- 小さなモデルでも「学習前」と「学習後」で生成が変わること
- それでも bigram には大きな限界があること
この notebook の位置づけ
これは Transformer の代わり ではありません。むしろ逆で、
- 「学習とは何か」を手で追いやすくする
- Attention を入れる前の最小の next-token 学習を理解する
ための導入用モデルです。
準備
defmodule LLMScratch.Visuals do
def line_chart(rows, title, x_field, y_field, opts \\ []) do
width = Keyword.get(opts, :width, 620)
height = Keyword.get(opts, :height, 280)
VegaLite.new(width: width, height: height, title: title)
|> VegaLite.data_from_values(rows)
|> VegaLite.mark(:line, point: true, tooltip: true)
|> VegaLite.encode_field(
:x, Atom.to_string(x_field),
type: :quantitative,
title: Atom.to_string(x_field)
)
|> VegaLite.encode_field(
:y, Atom.to_string(y_field),
type: :quantitative,
title: Atom.to_string(y_field))
|> Kino.VegaLite.new()
end
def heatmap(rows, title, opts \\ []) do
width = Keyword.get(opts, :width, 620)
height = Keyword.get(opts, :height, 360)
VegaLite.new(width: width, height: height, title: title)
|> VegaLite.data_from_values(rows)
|> VegaLite.mark(:rect, tooltip: true)
|> VegaLite.encode_field(
:x, "target", type: :nominal, title: "次トークン"
)
|> VegaLite.encode_field(
:y, "source", type: :nominal, title: "現在トークン"
)
|> VegaLite.encode_field(
:color, "score", type: :quantitative, scale: [scheme: "tealblues"]
)
|> Kino.VegaLite.new()
end
end
defmodule TinyBigram do
import Nx.Defn
defn row_softmax(logits) do
shifted = Nx.subtract(logits, Nx.reduce_max(logits, axes: [1], keep_axes: true))
exps = Nx.exp(shifted)
Nx.divide(exps, Nx.sum(exps, axes: [1], keep_axes: true))
end
defn loss(weights, input_ids, target_one_hot) do
logits = Nx.take(weights, input_ids)
probs = row_softmax(logits)
target_one_hot
|> Nx.multiply(Nx.log(Nx.add(probs, 1.0e-9)))
|> Nx.sum(axes: [1])
|> Nx.negate()
|> Nx.mean()
end
defn update(weights, input_ids, target_one_hot, learning_rate) do
grads = grad(weights, &loss(&1, input_ids, target_one_hot))
weights - grads * learning_rate
end
defn probabilities(weights) do
row_softmax(weights)
end
end
1. 小さな学習コーパスを作る
今回は説明しやすいように、空白区切りの短い日本語文を小さく用意します。
corpus_sentences = [
"ねこ は ひるね が すき",
"ねこ は さかな が すき",
"ねこ は まどべ で ひるね",
"いぬ は さんぽ が すき",
"いぬ は にわ を かける",
"とり は そら を とぶ"
]
Kino.DataTable.new(Enum.map(corpus_sentences, &%{sentence: &1}))
2. 語彙とトークン ID
文のはじめと終わりが分かるように、<BOS> と <EOS> も入れます。
special_tokens = ["<BOS>", "<EOS>"]
vocab =
corpus_sentences
|> Enum.flat_map(&String.split(&1, " "))
|> Kernel.++(special_tokens)
|> Enum.uniq()
|> Enum.sort()
token_to_id =
vocab
|> Enum.with_index()
|> Map.new()
id_to_token =
token_to_id
|> Enum.map(fn {token, id} -> {id, token} end)
|> Map.new()
%{
vocab_size: length(vocab),
vocab: vocab
}
token_rows =
vocab
|> Enum.with_index()
|> Enum.map(fn {token, id} ->
%{
token: token,
token_id: id
}
end)
Kino.DataTable.new(token_rows)
3. next-token 学習用のペアへ分解する
bigram モデルでは、現在のトークン -> 次のトークン のペアを学習します。
token_sequences =
corpus_sentences
|> Enum.map(fn sentence ->
["<BOS>"] ++ String.split(sentence, " ") ++ ["<EOS>"]
end)
Kino.DataTable.new(Enum.map(token_sequences, &%{sequence: Enum.join(&1, " ")}))
training_pairs =
token_sequences
|> Enum.flat_map(fn sequence ->
sequence
|> Enum.chunk_every(2, 1, :discard)
|> Enum.map(fn [source, target] ->
%{
source: source,
target: target,
source_id: token_to_id[source],
target_id: token_to_id[target]
}
end)
end)
Kino.DataTable.new(training_pairs)
4. 観察用に、出現回数の bigram 行列も作る
学習済みモデルの重みが、どのくらいこの分布に近づくかを見るための比較材料です。
pair_counts =
training_pairs
|> Enum.frequencies_by(fn pair -> {pair.source, pair.target} end)
count_rows =
for source <- vocab, target <- vocab do
%{
source: source,
target: target,
score: Map.get(pair_counts, {source, target}, 0)
}
end
LLMScratch.Visuals.heatmap(count_rows, "学習コーパスの bigram 出現回数")
5. モデルに渡す tensor を作る
input_ids =
training_pairs
|> Enum.map(& &1.source_id)
|> Nx.tensor(type: :s64)
target_ids =
training_pairs
|> Enum.map(& &1.target_id)
|> Nx.tensor(type: :s64)
vocab_size = length(vocab)
target_one_hot =
Nx.equal(Nx.new_axis(target_ids, 1), Nx.iota({vocab_size}))
|> Nx.as_type(:f32)
%{
input_shape: Nx.shape(input_ids),
target_shape: Nx.shape(target_ids),
target_one_hot_shape: Nx.shape(target_one_hot),
vocab_size: vocab_size
}
6. 学習前の状態を見る
bigram モデルのパラメータは、[語彙数 x 語彙数] の行列です。
行は「今のトークン」、列は「次トークン候補」を表します。
:rand.seed(:exsss, {123, 456, 789})
initial_weights =
1..(vocab_size * vocab_size)
|> Enum.map(fn _ -> :rand.normal() * 0.01 end)
|> Nx.tensor(type: {:f, 32})
|> Nx.reshape({vocab_size, vocab_size})
Nx.shape(initial_weights)
initial_loss =
TinyBigram.loss(initial_weights, input_ids, target_one_hot)
|> Nx.to_number()
%{initial_loss: initial_loss}
summarize_predictions = fn probabilities ->
vocab
|> Enum.with_index()
|> Enum.map(fn {source, source_id} ->
top_predictions =
probabilities[source_id]
|> Nx.to_flat_list()
|> Enum.zip(vocab)
|> Enum.map(fn {prob, token} -> {token, prob} end)
|> Enum.sort_by(fn {_token, prob} -> prob end, :desc)
|> Enum.take(3)
|> Enum.map(fn {token, prob} -> "#{token}(#{Float.round(prob, 3)})" end)
|> Enum.join(" / ")
%{
source: source,
top3: top_predictions
}
end)
end
initial_probabilities = TinyBigram.probabilities(initial_weights)
Kino.DataTable.new(summarize_predictions.(initial_probabilities))
学習前はほぼランダムなので、ねこ の次に何が来るかもまだ分かっていません。
7. 学習を回す
今回は小さな全データをまとめて使う、フルバッチ学習にします。
epochs = 400
learning_rate = 1.2e-1
{trained_weights, loss_rows} =
Enum.reduce(0..epochs, {initial_weights, []}, fn epoch, {weights, acc} ->
loss_value =
TinyBigram.loss(weights, input_ids, target_one_hot)
|> Nx.to_number()
next_weights =
TinyBigram.update(weights, input_ids, target_one_hot, learning_rate)
row = %{epoch: epoch, loss: loss_value}
{next_weights, [row | acc]}
end)
loss_rows = Enum.reverse(loss_rows)
%{
first_loss: hd(loss_rows).loss,
last_loss: List.last(loss_rows).loss
}
LLMScratch.Visuals.line_chart(loss_rows, "学習中の損失", :epoch, :loss)
8. 学習後の予測を見る
trained_probabilities = TinyBigram.probabilities(trained_weights)
Kino.DataTable.new(summarize_predictions.(trained_probabilities))
trained_rows =
for {source, i} <- Enum.with_index(vocab),
{target, j} <- Enum.with_index(vocab) do
%{
source: source,
target: target,
score: Float.round(Nx.to_number(trained_probabilities[[i, j]]), 4)
}
end
LLMScratch.Visuals.heatmap(trained_rows, "学習後の next-token 確率")
<BOS> の行では、文頭に来やすい語へ確率が集まります。は の行では、その後ろに来た語の傾向が見えるはずです。
9. 学習前後で生成を比べる
bigram なので、毎回「今の 1 語」だけを見て次トークンを選びます。
greedy_generate = fn probabilities, max_steps ->
next_token = fn source ->
source_id = token_to_id[source]
probabilities[source_id]
|> Nx.to_flat_list()
|> Enum.zip(vocab)
|> Enum.map(fn {prob, token} -> {token, prob} end)
|> Enum.reject(fn {token, _prob} -> token == "<BOS>" end)
|> Enum.max_by(fn {_token, prob} -> prob end)
|> elem(0)
end
1..max_steps
|> Enum.reduce_while(["<BOS>"], fn _, acc ->
source = List.last(acc)
target = next_token.(source)
if target == "<EOS>" do
{:halt, acc ++ [target]}
else
{:cont, acc ++ [target]}
end
end)
end
format_generated = fn tokens ->
tokens
|> Enum.reject(&(&1 in ["<BOS>", "<EOS>"]))
|> Enum.join(" ")
end
generation_rows = [
%{
stage: "before training",
generated: format_generated.(greedy_generate.(initial_probabilities, 8))
},
%{
stage: "after training",
generated: format_generated.(greedy_generate.(trained_probabilities, 8))
}
]
Kino.DataTable.new(
generation_rows,
keys: [:stage, :generated]
)
学習前は不自然な並びになりやすく、学習後はコーパスに寄った語順へ近づきます。
10. それでも bigram には限界がある
ここが大事です。
このモデルは「直前の 1 語」しか見ません。だから、
-
ねこ はのあとに何が来るか - もっと前に出てきた主語や話題
- 長い依存関係
をほとんど扱えません。
LLM が Transformer を使うのは、こうした 長い文脈を読む必要 があるからです。
11. まとめ
この notebook の要点
-
next-token 学習は、
現在トークン -> 次トークンの予測問題として書ける - 損失が下がると、次トークン確率の分布がデータに沿って整っていく
- ごく小さいモデルでも、学習前後の生成差ははっきり見える
- ただし bigram は 1 語しか見ないので、LLM の表現力には遠い
次の発展としては、
- 2 語以上の文脈を見る trigram / 小さな MLP
-
Axonで埋め込み層つきのモデルを組む - 最終的に causal self-attention を入れる
という順で進むと理解しやすいです。