Powered by AppSignal & Oban Pro

LLM を Elixir / Livebook で学ぶ 1

01_tokens_embeddings_attention.livemd

LLM を Elixir / Livebook で学ぶ 1

Mix.install([
  {:nx, "~> 0.9"},
  {:kino, "~> 0.15"},
  {:kino_vega_lite, "~> 0.1"}
])

トークン・埋め込み・Attention を目で理解する

このノートブックでは、LLM のいちばん大事な土台を Livebook 上で観察します。

今日は次の 4 つに絞ります。

  • トークン化すると何が起きるか
  • 埋め込みベクトルは何をしているか
  • Attention は「どの単語を見るか」をどう決めるか
  • GPT 系モデルで必要な causal mask は何か

この notebook の立ち位置

元の着想は Transformer / GPT の学習教材から得ていますが、ここで使う説明、図、例文、コードは Livebook 向けに新しく作ったものです。

GPU を使う大規模学習はまだ行わず、まずは「仕組みの見える化」に集中します。

準備

defmodule LLMScratch.Visuals do
  def bar_chart(rows, title, x_field, y_field, opts \\ []) do
    width = Keyword.get(opts, :width, 520)
    height = Keyword.get(opts, :height, 260)

    VegaLite.new(width: width, height: height, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:bar, tooltip: true)
    |> VegaLite.encode_field(
      :x, Atom.to_string(x_field),
      type: :nominal,
      title: Atom.to_string(x_field)
    )
    |> VegaLite.encode_field(
      :y, Atom.to_string(y_field),
      type: :quantitative,
      title: Atom.to_string(y_field)
    )
    |> Kino.VegaLite.new()
  end

  def line_chart(rows, title, x_field, y_field, opts \\ []) do
    width = Keyword.get(opts, :width, 520)
    height = Keyword.get(opts, :height, 260)
    color_field = Keyword.get(opts, :color_field)

    chart =
      VegaLite.new(width: width, height: height, title: title)
      |> VegaLite.data_from_values(rows)
      |> VegaLite.mark(:line, point: true)
      |> VegaLite.encode_field(:x, Atom.to_string(x_field),
        type: :quantitative,
        title: Atom.to_string(x_field)
      )
      |> VegaLite.encode_field(:y, Atom.to_string(y_field),
        type: :quantitative,
        title: Atom.to_string(y_field)
      )

    chart =
      if color_field do
        VegaLite.encode_field(chart, :color, Atom.to_string(color_field), type: :nominal)
      else
        chart
      end

    Kino.VegaLite.new(chart)
  end

  def heatmap(rows, title, opts \\ []) do
    width = Keyword.get(opts, :width, 520)
    height = Keyword.get(opts, :height, 260)

    VegaLite.new(width: width, height: height, title: title)
    |> VegaLite.data_from_values(rows)
    |> VegaLite.mark(:rect, tooltip: true)
    |> VegaLite.encode_field(
      :x, "key", type: :nominal, title: "注目されるトークン",
      sort: [field: "index_k"]
    )
    |> VegaLite.encode_field(
      :y, "query", type: :nominal, title: "見ているトークン",
      sort: [field: "index_q"]
    )
    |> VegaLite.encode_field(
      :color, "score", type: :quantitative, scale: [scheme: "blues"])
    |> Kino.VegaLite.new()
  end
end

まず全体像

Kino.Mermaid.new("""
flowchart LR
  A["文章"] --> B["トークン化"]
  B --> C["トークン ID"]
  C --> D["埋め込みベクトル"]
  D --> E["位置情報を加える"]
  E --> F["Attention"]
  F --> G["次トークンの予測"]
""")

大きな流れはとても単純です。

  1. 文章を小さな部品に分ける
  2. 部品を数値 ID に変える
  3. 数値 ID をベクトルに変える
  4. どのトークンをどれだけ見るかを Attention で決める
  5. 最後に「次に出そうなトークン」を予測する

1. トークン化

ここでは説明しやすいように、空白で区切られた簡単な文を使います。

sentence = "ねこ は ひるね が すき"
tokens = String.split(sentence, " ")
vocab =
  tokens
  |> Enum.uniq()
  |> Enum.sort()

token_to_id =
  vocab
  |> Enum.with_index()
  |> Map.new()

id_to_token =
  token_to_id
  |> Enum.map(fn {token, id} -> {id, token} end)
  |> Map.new()

token_ids = Enum.map(tokens, &token_to_id[&1])

%{
  vocab: vocab,
  token_to_id: token_to_id,
  token_ids: token_ids
}
token_rows =
  tokens
  |> Enum.with_index()
  |> Enum.map(fn {token, position} ->
    %{
      position: position,
      token: token,
      token_id: token_to_id[token]
    }
  end)

Kino.DataTable.new(token_rows)

LLM にとって、最初から「ねこ」の意味が分かっているわけではありません。まずは単に ID に置き換えて、計算できる形に直しています。

2. one-hot 表現から埋め込みへ

one-hot は「その語がどこにあるか」だけを表すベクトルです。

vocab_size = length(vocab)

sequence_one_hot =
  token_ids
  |> Enum.map(fn id ->
    Nx.equal(Nx.iota({vocab_size}), Nx.tensor(id))
  end)
  |> Nx.stack()

one-hot は便利ですが、語と語の近さを表しにくいです。そこで LLM では、より密な実数ベクトルである埋め込みを使います。

ここでは説明用に、2 次元の手作り埋め込みを置きます。

embedding_map = %{
  "ねこ" => [0.90, 0.15],
  "は" => [0.10, 0.05],
  "ひるね" => [0.95, 0.85],
  "が" => [0.12, 0.10],
  "すき" => [0.88, 0.70]
}

embedding_rows =
  vocab
  |> Enum.map(fn token ->
    [x, y] = embedding_map[token]

    %{
      token: token,
      dim_1: x,
      dim_2: y
    }
  end)

Kino.DataTable.new(
  embedding_rows,
  keys: [:token, :dim_1, :dim_2]
)
embedding_tensor =
  tokens
  |> Enum.map(&embedding_map[&1])
  |> Nx.tensor(type: {:f, 32})

embedding_tensor

埋め込みが学習されると、「似た働きをするトークン」や「一緒に現れやすいトークン」が近いベクトルになることがあります。

3. Attention の最小例

Attention は「今見ているトークンが、どのトークンをどれだけ参照するか」を決める仕組みです。

今回は すき が、文中のどこを見そうかを簡単な内積で観察します。

softmax_1d = fn scores ->
  exps = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores)))
  Nx.divide(exps, Nx.sum(exps))
end

scaled_dot_attention = fn query, keys, values ->
  d_model = keys |> Nx.shape() |> elem(1)
  scores = Nx.divide(Nx.dot(keys, [1], query, [0]), :math.sqrt(d_model))
  weights = softmax_1d.(scores)
  output = Nx.dot(weights, [0], values, [0])
  {output, weights}
end
query_index = Enum.find_index(tokens, &(&1 == "すき"))
query = embedding_tensor[query_index]

{attention_output, attention_weights} =
  scaled_dot_attention.(query, embedding_tensor, embedding_tensor)

%{
  query_token: Enum.at(tokens, query_index),
  output_vector: attention_output,
  weights: attention_weights
}
weight_rows =
  tokens
  |> Enum.with_index()
  |> Enum.map(fn {token, index} ->
    %{
      token: token,
      weight: Nx.to_number(attention_weights[index])
    }
  end)

LLMScratch.Visuals.bar_chart(weight_rows, "`すき` がどのトークンを見るか", :token, :weight)

この例では、すき と意味的に近く置いた ひるねねこ への重みが大きくなりやすくなります。

もちろん実際の LLM では、埋め込みは人手ではなく学習で作られます。

4. なぜスケーリングが必要か

埋め込み次元が大きいと、内積の値が大きくなりやすく、softmax が極端に尖ることがあります。

:rand.seed(:exsss, {11, 22, 33})

large_scores =
  1..12
  |> Enum.map(fn _ -> :rand.normal() * 8 end)

small_scores =
  1..12
  |> Enum.map(fn _ -> :rand.normal() end)

large_softmax = softmax_1d.(Nx.tensor(large_scores))
small_softmax = softmax_1d.(Nx.tensor(small_scores))

score_rows =
  Enum.with_index(large_scores)
  |> Enum.map(fn {score, index} ->
    %{
      index: index,
      large_scale: score,
      large_softmax: Nx.to_number(large_softmax[index]),
      small_softmax: Nx.to_number(small_softmax[index])
    }
  end)

Kino.DataTable.new(score_rows)
softmax_rows =
  score_rows
  |> Enum.flat_map(fn row ->
    [
      %{index: row.index, value: row.large_softmax, series: "値が大きいとき"},
      %{index: row.index, value: row.small_softmax, series: "値が小さいとき"}
    ]
  end)

LLMScratch.Visuals.line_chart(
  softmax_rows,
  "softmax の尖り方の比較",
  :index,
  :value,
  color_field: :series
)

この偏りを少しやわらげるために、Transformer では通常 sqrt(d_k) で割る scaled dot-product attention を使います。

5. Causal Mask を可視化する

GPT のような decoder-only モデルでは、未来の単語を先読みしてはいけません。

たとえば 3 番目のトークンは、4 番目や 5 番目を見てはいけません。これを実現するのが causal mask です。

decoder_tokens = ["わたし", "は", "きょう", "本", "を", "読む"]

decoder_embedding_map = %{
  "わたし" => [0.80, 0.10, 0.20],
  "は" => [0.10, 0.05, 0.05],
  "きょう" => [0.55, 0.85, 0.20],
  "本" => [0.92, 0.70, 0.15],
  "を" => [0.12, 0.12, 0.02],
  "読む" => [0.75, 0.88, 0.92]
}

decoder_embeddings =
  decoder_tokens
  |> Enum.map(&decoder_embedding_map[&1])
  |> Nx.tensor(type: {:f, 32})

decoder_embeddings
causal_mask_rows =
  for {query, i} <- Enum.with_index(decoder_tokens),
      {key, j} <- Enum.with_index(decoder_tokens) do
    %{
      index_q: i,
      index_k: j,
      query: query,
      key: key,
      score: if(j > i, do: 1.0, else: 0.0)
    }
  end

LLMScratch.Visuals.heatmap(causal_mask_rows, "1 の場所は未来なので見てはいけない")
row_softmax = fn matrix ->
  exps =
    Nx.exp(
      Nx.subtract(matrix, Nx.reduce_max(matrix, axes: [1], keep_axes: true))
    )

  Nx.divide(exps, Nx.sum(exps, axes: [1], keep_axes: true))
end

score_matrix =
  Nx.divide(
    Nx.dot(decoder_embeddings, [1], Nx.transpose(decoder_embeddings), [0]),
    :math.sqrt(3)
  )

mask_tensor =
  for i <- 0..(length(decoder_tokens) - 1) do
    for j <- 0..(length(decoder_tokens) - 1) do
      j > i
    end
  end
  |> Nx.tensor(type: {:u, 8})

masked_scores =
  Nx.select(mask_tensor, Nx.broadcast(-1.0e9, Nx.shape(score_matrix)), score_matrix)

masked_weights = row_softmax.(masked_scores)

masked_weight_rows =
  for {query, i} <- Enum.with_index(decoder_tokens),
      {key, j} <- Enum.with_index(decoder_tokens) do
    %{
      query: query,
      key: key,
      score: Nx.to_number(masked_weights[[i, j]])
    }
  end

LLMScratch.Visuals.heatmap(masked_weight_rows, "mask 適用後の attention 重み")

ヒートマップを見ると、各トークンは自分より右側を見ていません。これが「次トークン予測」が成立する理由のひとつです。

6. 位置情報はどう入るのか

Attention だけだと、単語の並び順が弱くなります。そこで位置情報を加えます。

ここでは固定の sin / cos 位置エンコーディングを小さく作ってみます。

positional_encoding = fn max_position, d_model ->
  for pos <- 0..(max_position - 1) do
    for dim <- 0..(d_model - 1) do
      angle = pos / :math.pow(10_000, (2 * div(dim, 2)) / d_model)

      if rem(dim, 2) == 0 do
        :math.sin(angle)
      else
        :math.cos(angle)
      end
    end
  end
end

pe_matrix = positional_encoding.(12, 8)
pe_matrix |> Enum.take(4)
pe_rows =
  pe_matrix
  |> Enum.with_index()
  |> Enum.flat_map(fn {row, position} ->
    row
    |> Enum.take(4)
    |> Enum.with_index()
    |> Enum.map(fn {value, dim} ->
      %{
        position: position,
        value: value,
        dimension: "dim_#{dim}"
      }
    end)
  end)

LLMScratch.Visuals.line_chart(
  pe_rows,
  "位置エンコーディングの波形",
  :position,
  :value,
  color_field: :dimension
)

位置ごとに違う波形を足すことで、モデルは「同じ単語でも何番目にあるか」を区別しやすくなります。

7. GPT 系モデルでは何が起きているか

Kino.Mermaid.new("""
flowchart TD
  A["入力トークン"] --> B["Token Embedding"]
  A --> C["Position Embedding"]
  B --> D["足し合わせる"]
  C --> D
  D --> E["Causal Self-Attention"]
  E --> F["MLP"]
  F --> G["Linear"]
  G --> H["次トークンの確率"]
""")

実際の GPT はもっと深いネットワークですが、基本の考え方はこの繰り返しです。

  • 埋め込みで数値化する
  • 自分より前の文脈を Attention で読む
  • 各位置で次トークンの確率を出す

8. ここまでのまとめ

今日の要点

  • トークン化は、文章を計算可能な部品へ分解する作業
  • 埋め込みは、トークンを意味のある実数ベクトルへ変換する層
  • Attention は、各トークンがどこを見るべきかを重みとして学ぶ仕組み
  • GPT では causal mask によって未来のトークンを見ない
  • 位置エンコーディングで単語の順序情報を補う

次のノートブックでは、ここで見た考え方が実際の GPT 系モデルの生成にどう現れるかを Bumblebee で触っていきます。