LLM を Elixir / Livebook で学ぶ 1
Mix.install([
{:nx, "~> 0.9"},
{:kino, "~> 0.15"},
{:kino_vega_lite, "~> 0.1"}
])
トークン・埋め込み・Attention を目で理解する
このノートブックでは、LLM のいちばん大事な土台を Livebook 上で観察します。
今日は次の 4 つに絞ります。
- トークン化すると何が起きるか
- 埋め込みベクトルは何をしているか
- Attention は「どの単語を見るか」をどう決めるか
-
GPT 系モデルで必要な
causal maskは何か
この notebook の立ち位置
元の着想は Transformer / GPT の学習教材から得ていますが、ここで使う説明、図、例文、コードは Livebook 向けに新しく作ったものです。
GPU を使う大規模学習はまだ行わず、まずは「仕組みの見える化」に集中します。
準備
defmodule LLMScratch.Visuals do
def bar_chart(rows, title, x_field, y_field, opts \\ []) do
width = Keyword.get(opts, :width, 520)
height = Keyword.get(opts, :height, 260)
VegaLite.new(width: width, height: height, title: title)
|> VegaLite.data_from_values(rows)
|> VegaLite.mark(:bar, tooltip: true)
|> VegaLite.encode_field(
:x, Atom.to_string(x_field),
type: :nominal,
title: Atom.to_string(x_field)
)
|> VegaLite.encode_field(
:y, Atom.to_string(y_field),
type: :quantitative,
title: Atom.to_string(y_field)
)
|> Kino.VegaLite.new()
end
def line_chart(rows, title, x_field, y_field, opts \\ []) do
width = Keyword.get(opts, :width, 520)
height = Keyword.get(opts, :height, 260)
color_field = Keyword.get(opts, :color_field)
chart =
VegaLite.new(width: width, height: height, title: title)
|> VegaLite.data_from_values(rows)
|> VegaLite.mark(:line, point: true)
|> VegaLite.encode_field(:x, Atom.to_string(x_field),
type: :quantitative,
title: Atom.to_string(x_field)
)
|> VegaLite.encode_field(:y, Atom.to_string(y_field),
type: :quantitative,
title: Atom.to_string(y_field)
)
chart =
if color_field do
VegaLite.encode_field(chart, :color, Atom.to_string(color_field), type: :nominal)
else
chart
end
Kino.VegaLite.new(chart)
end
def heatmap(rows, title, opts \\ []) do
width = Keyword.get(opts, :width, 520)
height = Keyword.get(opts, :height, 260)
VegaLite.new(width: width, height: height, title: title)
|> VegaLite.data_from_values(rows)
|> VegaLite.mark(:rect, tooltip: true)
|> VegaLite.encode_field(
:x, "key", type: :nominal, title: "注目されるトークン",
sort: [field: "index_k"]
)
|> VegaLite.encode_field(
:y, "query", type: :nominal, title: "見ているトークン",
sort: [field: "index_q"]
)
|> VegaLite.encode_field(
:color, "score", type: :quantitative, scale: [scheme: "blues"])
|> Kino.VegaLite.new()
end
end
まず全体像
Kino.Mermaid.new("""
flowchart LR
A["文章"] --> B["トークン化"]
B --> C["トークン ID"]
C --> D["埋め込みベクトル"]
D --> E["位置情報を加える"]
E --> F["Attention"]
F --> G["次トークンの予測"]
""")
大きな流れはとても単純です。
- 文章を小さな部品に分ける
- 部品を数値 ID に変える
- 数値 ID をベクトルに変える
- どのトークンをどれだけ見るかを Attention で決める
- 最後に「次に出そうなトークン」を予測する
1. トークン化
ここでは説明しやすいように、空白で区切られた簡単な文を使います。
sentence = "ねこ は ひるね が すき"
tokens = String.split(sentence, " ")
vocab =
tokens
|> Enum.uniq()
|> Enum.sort()
token_to_id =
vocab
|> Enum.with_index()
|> Map.new()
id_to_token =
token_to_id
|> Enum.map(fn {token, id} -> {id, token} end)
|> Map.new()
token_ids = Enum.map(tokens, &token_to_id[&1])
%{
vocab: vocab,
token_to_id: token_to_id,
token_ids: token_ids
}
token_rows =
tokens
|> Enum.with_index()
|> Enum.map(fn {token, position} ->
%{
position: position,
token: token,
token_id: token_to_id[token]
}
end)
Kino.DataTable.new(token_rows)
LLM にとって、最初から「ねこ」の意味が分かっているわけではありません。まずは単に ID に置き換えて、計算できる形に直しています。
2. one-hot 表現から埋め込みへ
one-hot は「その語がどこにあるか」だけを表すベクトルです。
vocab_size = length(vocab)
sequence_one_hot =
token_ids
|> Enum.map(fn id ->
Nx.equal(Nx.iota({vocab_size}), Nx.tensor(id))
end)
|> Nx.stack()
one-hot は便利ですが、語と語の近さを表しにくいです。そこで LLM では、より密な実数ベクトルである埋め込みを使います。
ここでは説明用に、2 次元の手作り埋め込みを置きます。
embedding_map = %{
"ねこ" => [0.90, 0.15],
"は" => [0.10, 0.05],
"ひるね" => [0.95, 0.85],
"が" => [0.12, 0.10],
"すき" => [0.88, 0.70]
}
embedding_rows =
vocab
|> Enum.map(fn token ->
[x, y] = embedding_map[token]
%{
token: token,
dim_1: x,
dim_2: y
}
end)
Kino.DataTable.new(
embedding_rows,
keys: [:token, :dim_1, :dim_2]
)
embedding_tensor =
tokens
|> Enum.map(&embedding_map[&1])
|> Nx.tensor(type: {:f, 32})
embedding_tensor
埋め込みが学習されると、「似た働きをするトークン」や「一緒に現れやすいトークン」が近いベクトルになることがあります。
3. Attention の最小例
Attention は「今見ているトークンが、どのトークンをどれだけ参照するか」を決める仕組みです。
今回は すき が、文中のどこを見そうかを簡単な内積で観察します。
softmax_1d = fn scores ->
exps = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores)))
Nx.divide(exps, Nx.sum(exps))
end
scaled_dot_attention = fn query, keys, values ->
d_model = keys |> Nx.shape() |> elem(1)
scores = Nx.divide(Nx.dot(keys, [1], query, [0]), :math.sqrt(d_model))
weights = softmax_1d.(scores)
output = Nx.dot(weights, [0], values, [0])
{output, weights}
end
query_index = Enum.find_index(tokens, &(&1 == "すき"))
query = embedding_tensor[query_index]
{attention_output, attention_weights} =
scaled_dot_attention.(query, embedding_tensor, embedding_tensor)
%{
query_token: Enum.at(tokens, query_index),
output_vector: attention_output,
weights: attention_weights
}
weight_rows =
tokens
|> Enum.with_index()
|> Enum.map(fn {token, index} ->
%{
token: token,
weight: Nx.to_number(attention_weights[index])
}
end)
LLMScratch.Visuals.bar_chart(weight_rows, "`すき` がどのトークンを見るか", :token, :weight)
この例では、すき と意味的に近く置いた ひるね や ねこ への重みが大きくなりやすくなります。
もちろん実際の LLM では、埋め込みは人手ではなく学習で作られます。
4. なぜスケーリングが必要か
埋め込み次元が大きいと、内積の値が大きくなりやすく、softmax が極端に尖ることがあります。
:rand.seed(:exsss, {11, 22, 33})
large_scores =
1..12
|> Enum.map(fn _ -> :rand.normal() * 8 end)
small_scores =
1..12
|> Enum.map(fn _ -> :rand.normal() end)
large_softmax = softmax_1d.(Nx.tensor(large_scores))
small_softmax = softmax_1d.(Nx.tensor(small_scores))
score_rows =
Enum.with_index(large_scores)
|> Enum.map(fn {score, index} ->
%{
index: index,
large_scale: score,
large_softmax: Nx.to_number(large_softmax[index]),
small_softmax: Nx.to_number(small_softmax[index])
}
end)
Kino.DataTable.new(score_rows)
softmax_rows =
score_rows
|> Enum.flat_map(fn row ->
[
%{index: row.index, value: row.large_softmax, series: "値が大きいとき"},
%{index: row.index, value: row.small_softmax, series: "値が小さいとき"}
]
end)
LLMScratch.Visuals.line_chart(
softmax_rows,
"softmax の尖り方の比較",
:index,
:value,
color_field: :series
)
この偏りを少しやわらげるために、Transformer では通常 sqrt(d_k) で割る scaled dot-product attention を使います。
5. Causal Mask を可視化する
GPT のような decoder-only モデルでは、未来の単語を先読みしてはいけません。
たとえば 3 番目のトークンは、4 番目や 5 番目を見てはいけません。これを実現するのが causal mask です。
decoder_tokens = ["わたし", "は", "きょう", "本", "を", "読む"]
decoder_embedding_map = %{
"わたし" => [0.80, 0.10, 0.20],
"は" => [0.10, 0.05, 0.05],
"きょう" => [0.55, 0.85, 0.20],
"本" => [0.92, 0.70, 0.15],
"を" => [0.12, 0.12, 0.02],
"読む" => [0.75, 0.88, 0.92]
}
decoder_embeddings =
decoder_tokens
|> Enum.map(&decoder_embedding_map[&1])
|> Nx.tensor(type: {:f, 32})
decoder_embeddings
causal_mask_rows =
for {query, i} <- Enum.with_index(decoder_tokens),
{key, j} <- Enum.with_index(decoder_tokens) do
%{
index_q: i,
index_k: j,
query: query,
key: key,
score: if(j > i, do: 1.0, else: 0.0)
}
end
LLMScratch.Visuals.heatmap(causal_mask_rows, "1 の場所は未来なので見てはいけない")
row_softmax = fn matrix ->
exps =
Nx.exp(
Nx.subtract(matrix, Nx.reduce_max(matrix, axes: [1], keep_axes: true))
)
Nx.divide(exps, Nx.sum(exps, axes: [1], keep_axes: true))
end
score_matrix =
Nx.divide(
Nx.dot(decoder_embeddings, [1], Nx.transpose(decoder_embeddings), [0]),
:math.sqrt(3)
)
mask_tensor =
for i <- 0..(length(decoder_tokens) - 1) do
for j <- 0..(length(decoder_tokens) - 1) do
j > i
end
end
|> Nx.tensor(type: {:u, 8})
masked_scores =
Nx.select(mask_tensor, Nx.broadcast(-1.0e9, Nx.shape(score_matrix)), score_matrix)
masked_weights = row_softmax.(masked_scores)
masked_weight_rows =
for {query, i} <- Enum.with_index(decoder_tokens),
{key, j} <- Enum.with_index(decoder_tokens) do
%{
query: query,
key: key,
score: Nx.to_number(masked_weights[[i, j]])
}
end
LLMScratch.Visuals.heatmap(masked_weight_rows, "mask 適用後の attention 重み")
ヒートマップを見ると、各トークンは自分より右側を見ていません。これが「次トークン予測」が成立する理由のひとつです。
6. 位置情報はどう入るのか
Attention だけだと、単語の並び順が弱くなります。そこで位置情報を加えます。
ここでは固定の sin / cos 位置エンコーディングを小さく作ってみます。
positional_encoding = fn max_position, d_model ->
for pos <- 0..(max_position - 1) do
for dim <- 0..(d_model - 1) do
angle = pos / :math.pow(10_000, (2 * div(dim, 2)) / d_model)
if rem(dim, 2) == 0 do
:math.sin(angle)
else
:math.cos(angle)
end
end
end
end
pe_matrix = positional_encoding.(12, 8)
pe_matrix |> Enum.take(4)
pe_rows =
pe_matrix
|> Enum.with_index()
|> Enum.flat_map(fn {row, position} ->
row
|> Enum.take(4)
|> Enum.with_index()
|> Enum.map(fn {value, dim} ->
%{
position: position,
value: value,
dimension: "dim_#{dim}"
}
end)
end)
LLMScratch.Visuals.line_chart(
pe_rows,
"位置エンコーディングの波形",
:position,
:value,
color_field: :dimension
)
位置ごとに違う波形を足すことで、モデルは「同じ単語でも何番目にあるか」を区別しやすくなります。
7. GPT 系モデルでは何が起きているか
Kino.Mermaid.new("""
flowchart TD
A["入力トークン"] --> B["Token Embedding"]
A --> C["Position Embedding"]
B --> D["足し合わせる"]
C --> D
D --> E["Causal Self-Attention"]
E --> F["MLP"]
F --> G["Linear"]
G --> H["次トークンの確率"]
""")
実際の GPT はもっと深いネットワークですが、基本の考え方はこの繰り返しです。
- 埋め込みで数値化する
- 自分より前の文脈を Attention で読む
- 各位置で次トークンの確率を出す
8. ここまでのまとめ
今日の要点
- トークン化は、文章を計算可能な部品へ分解する作業
- 埋め込みは、トークンを意味のある実数ベクトルへ変換する層
- Attention は、各トークンがどこを見るべきかを重みとして学ぶ仕組み
- GPT では causal mask によって未来のトークンを見ない
- 位置エンコーディングで単語の順序情報を補う
次のノートブックでは、ここで見た考え方が実際の GPT 系モデルの生成にどう現れるかを Bumblebee で触っていきます。