

Let’s build GPT: from scratch, in code, spelled out… but with Axon

Mix.install(
  [
    {:nx, "~> 0.6", override: true},
    {:axon, "~> 0.5"},
    {:exla, "~> 0.5.1"},
    {:kino, "~> 0.7.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Intro

TODO: add better intro…

Resources:

Prepare the data

text =
  __DIR__
  |> Path.join("input.txt")
  |> Path.expand()
  |> File.read!()

Our Vocabulary

chars =
  text
  |> String.codepoints()
  |> Enum.uniq()
  |> Enum.sort()

vocab_size = length(chars)

IO.inspect(Enum.join(chars), label: "chars")
IO.inspect(vocab_size, label: "vocab_size")

:ok

Basic Encoder / Decoder

We are going to develop a strategy to tokenize the input text: convert the raw text into a sequence of integers according to some vocabulary of possible elements.

In this case we are going to tokenize single chars: for each char there is a corresponding integer.
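Before the full module below, here is a minimal sketch of the idea with a tiny hypothetical vocabulary (the `toy_*` names are only for illustration; the real module derives its vocabulary from input.txt):

toy_chars = ["a", "b", "c"]

# string-to-integer and integer-to-string lookup tables
toy_stoi = toy_chars |> Enum.with_index() |> Map.new()
toy_itos = Map.new(toy_stoi, fn {ch, i} -> {i, ch} end)

# encode "cab" -> [2, 0, 1], then decode back -> "cab"
encoded = "cab" |> String.codepoints() |> Enum.map(&toy_stoi[&1])
decoded = encoded |> Enum.map(&toy_itos[&1]) |> Enum.join()
{encoded, decoded}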

defmodule Minidecoder do
  @text __DIR__
        |> Path.join("input.txt")
        |> Path.expand()
        |> File.read!()

  @chars String.codepoints(@text) |> Enum.uniq() |> Enum.sort()

  # string-to-integer
  @stoi Enum.reduce(@chars, %{}, fn ch, acc -> Map.put(acc, ch, Enum.count(acc)) end)

  # integer-to-string
  @itos Enum.reduce(@stoi, %{}, fn {ch, i}, acc -> Map.put(acc, i, ch) end)

  def vocab_size, do: length(@chars)

  def encode_char(char), do: @stoi[char]

  def decode_char(encoded_char), do: @itos[encoded_char]

  def encode(text) do
    text |> String.codepoints() |> Enum.map(&encode_char(&1))
  end

  def decode(encoded_list) do
    encoded_list |> Enum.map(&decode_char(&1)) |> Enum.join()
  end

  def tensor(text, opts \\ []) do
    Nx.tensor(encode(text), opts)
  end
end
vocab_size =
  Minidecoder.vocab_size()
  |> IO.inspect(label: "vocab size is")

Minidecoder.encode("hii there") |> IO.inspect(label: "Encoding")
Minidecoder.encode("hii there") |> Minidecoder.decode() |> IO.inspect(label: "Decoding")

:ok
data = Minidecoder.tensor(text, type: :s64)

Nx.shape(data) |> IO.inspect(label: "shape")
data[0..999] |> IO.inspect(label: "first 1000 chars")

:ok

Split the dataset into training and validation sets

# Use 90% of the data for training and the remaining 10% for validation
n = Kernel.round(Nx.size(data) * 0.9)

train_data = data[0..(n - 1)//1]
val_data = data[n..-1//1]

{train_data, val_data}

Data loader: batches of chunks of data

To train the transformer we split the training dataset into chunks to speed things up.

block_size = 8

# Please note that Elixir ranges are inclusive, so `train_data[0..block_size]` is equivalent to `data[:block_size+1]` in Python
chunk = train_data[0..block_size]

IO.inspect(chunk, label: "chunk")

# In a chunk of 9 chars, there are 8 individual examples packed in there:
# 18 -> 47
# 18, 47 -> 56
# 18, 47, 56 -> 57
# …

x = train_data[0..(block_size - 1)]
y = train_data[1..block_size]

for t <- 0..(block_size - 1) do
  context = x[0..t] |> Nx.to_list() |> Enum.join(", ")
  target = y[t] |> Nx.to_number()
  IO.inspect("when input is [#{context}] the target is: #{target}")
end

:ok

Chunks Generation

seed = 1337
# how many independent sequences will we process in parallel?
batch_size = 4
# what is the maximum context length for predictions?
block_size = 8

get_batch = fn split ->
  key = Nx.Random.key(seed)

  data = if split == :train, do: train_data, else: val_data

  # Generate `batch_size` random indices; each index is used to
  # "sample" the dataset and build the `x` and `y` tensors starting
  # from that index. Each slice has length `block_size`, which is why
  # we sample indices in `0..(Nx.size(data) - block_size)`.
  #
  # Example
  # ix = [5, 10, 40, 60]
  # x = [[46, 47, 51, 57, 43, 50, 44, 1], […], […], […]]
  # y = [[47, 51, 57, 43, 50, 44, 1, 40], […], […], […]]
  #
  # where
  # - x[[0, 0]] = 46 = data[5] (5 is the 1st index in the 1st list of indices)
  # - x[[0, 1]] = 47 = data[5 + 1]
  # - … and so on
  {ix, _new_key} =
    Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32)

  ix = Nx.to_list(ix)

  # From each index in `ix`, slice the data and extract
  # a 1d tensor of size `block_size`, then stack them all together
  x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
  y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()

  {x, y}
end

{xb, yb} = get_batch.(:train)

IO.inspect("INPUTS")
IO.inspect(Nx.shape(xb))
IO.inspect(xb)
IO.inspect("TARGETS")
IO.inspect(Nx.shape(yb))
IO.inspect(yb)

IO.inspect("--------------")

# batch dimension (B)
for b <- 0..(batch_size - 1) do
  # time dimension (T)
  for t <- 0..(block_size - 1) do
    context = xb[[b, 0..t]] |> Nx.to_list() |> Enum.join(", ")
    target = yb[[b, t]] |> Nx.to_number()
    IO.inspect("when input is [#{context}] the target is: #{target}")
  end
end

:ok
# Our input to the transformer

IO.inspect(xb)

:ok

Bigram Language Model

A bigram language model is a type of statistical language model that predicts the probability of a word in a sequence based on the previous word. (source https://www.analyticsvidhya.com/blog/2019/08/comprehensive-guide-language-model-nlp-python-code/)
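In our case the model works at character level, so it predicts the next char given only the previous char. For intuition, a purely count-based bigram table can be sketched directly from the encoded text before building the neural version below (an illustrative sketch only; `counts` and `prob` are names made up here):

# Count how often each character follows each other character
# (a slice of the text keeps this quick).
encoded_sample = text |> String.slice(0, 10_000) |> Minidecoder.encode()

counts =
  encoded_sample
  |> Enum.chunk_every(2, 1, :discard)
  |> Enum.frequencies_by(fn [prev, next] -> {prev, next} end)

# P(next | prev) = count(prev, next) / count(prev, anything)
prob = fn prev, next ->
  row_total =
    counts
    |> Enum.filter(fn {{p, _n}, _c} -> p == prev end)
    |> Enum.map(fn {_pair, c} -> c end)
    |> Enum.sum()

  Map.get(counts, {prev, next}, 0) / row_total
end

# e.g. how likely is "h" right after "t"?
prob.(Minidecoder.encode_char("t"), Minidecoder.encode_char("h"))

The embedding layer below learns an equivalent table of scores (logits) instead of explicit counts.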

# Hyperparameters
# B - batch dimension
batch_size = 4
# T - time dimension
block_size = 8

# inputs (xb) and targets (yb) are both {B, T} tensors of integers
shape =
  xb
  |> Nx.to_template()
  |> Nx.shape()

# An embedding layer initializes a kernel of shape `{vocab_size, embedding_size}`
# which acts as a lookup table for sequences of discrete tokens (e.g. sentences).
# Embeddings are typically used to obtain a dense representation of a sparse input space.
# https://hexdocs.pm/axon/0.5.1/Axon.html#embedding/4
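#
# Intuition (a rough sketch, not exactly how Axon wires it internally):
# for each integer token in the input, the embedding layer returns the
# corresponding row of its `{vocab_size, embedding_size}` kernel,
# roughly `Nx.take(kernel, xb)`.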

bigram_model =
  Axon.input("sequence", shape: shape)
  # vocab_size = embedding_size = 65 
  |> Axon.embedding(65, 65)

Axon.Display.as_graph(bigram_model, Nx.template({batch_size, block_size}, :f32),
  direction: :top_down
)
{init_fn, predict_fn} = Axon.build(bigram_model, mode: :train)
params = init_fn.(Nx.to_template(xb), %{})

## PREDICTION

# We can make predictions of what's coming next.
#
# `preds` are our `logits`, which are basically the scores
# for the next char in the sequence.
# We are predicting what's coming next based on the identity
# of a single token alone; at the moment that's all we can do because
# the tokens are not aware of the context (they don't talk to each other)
%{prediction: preds} = predict_fn.(params, xb)
IO.inspect(preds, label: "predictions")

## LOSS

# `preds` are of shape `{b, t, c}`, where `c` stands for `channel`.
# But Axon expects them to be of shape `{b * t, c}` when computing
# the loss.
{b, t, c} = Nx.shape(preds)
y_pred = Nx.reshape(preds, {b * t, c})

# Same for the targets: at the moment they have shape `{b, t}`
# and we want them of shape `{b * t, 1}`
y_true = Nx.reshape(yb, {:auto, 1})

# Let's inspect the shape
y_pred |> IO.inspect(label: "y_pred reshaped")
y_true |> IO.inspect(label: "y_true reshaped")

# Now we can compute the loss
#
# To measure the loss we can use the log loss which
# is implemented in Axon under the name `categorical_cross_entropy`
#
# The `loss` is the cross entropy between the targets `y_true` and
# the predictions `y_pred`.
#
# - we need to compute the loss across the entire minibatch
#   therefore we set the option `reduction: :mean`
# - we need to set `from_logits: true` because `y_pred` is not normalized
# - we need to set `sparse: true` because the inputs are integer values
#   corresponding to the target class
#
# https://hexdocs.pm/axon/0.5.1/Axon.Losses.html#categorical_cross_entropy/3
loss =
  Axon.Losses.categorical_cross_entropy(y_true, y_pred,
    from_logits: true,
    sparse: true,
    reduction: :mean
  )

# We expect the loss to be `-ln(1/65)` where 65 is our vocabulary size,
# and we are getting really close…

IO.inspect(Nx.to_number(loss), label: "loss")
IO.inspect(-:math.log(1 / 65), label: "expected loss")

:ok

Generating text with the bigram model

Multinomial distribution implementation

Before starting we need to define the function for the multinomial distribution, which will be used to sample the next char. It is not available in Nx, but it is possible to implement it using Nx.Random.choice/4.

👉 Many many many thanks to @wtedw for coming up with a nice and clean implementation 👏👏👏.

> Multinomial distribution models the probability of each combination of successes in a series of independent trials. Use this distribution when there are more than two possible mutually exclusive outcomes for each trial, and each outcome has a fixed probability of success.

# source https://github.com/wtedw/gpt-from-scratch/blob/main/gpt-dev.livemd#multinomial

defmodule RandomPlus do
  import Nx.Defn

  defn multinomial(init_key, input, opts \\ []) do
    opts = keyword!(opts, num_samples: 1)
    num_samples = opts[:num_samples]

    {b, c} = Nx.shape(input)
    # initialize the result tensor `{b, num_samples}`
    initial_tensor = Nx.broadcast(0, {b, num_samples})

    # Generate the category tensors used to compute
    # the multinomial distribution for each row in the
    # `input` tensor.
    # (it will then be used as index value
    # to fetch the char from the embeddings table)
    #
    # e.g. if `c` is 4: `Nx.tensor([0, 1, 2, 3])`
    category_iota = Nx.iota({c}, type: :s32)

    {_i, _input, next_key, acc} =
      while {i = 0, input, key = init_key, acc = initial_tensor}, i < b do
        # Let's pluck one of the rows from the input,
        # which represents a probability distribution (shape `{c}`)
        i_batch_prob = input[i]

        # Generates random samples from a tensor with specified probabilities
        # `i_batch_prob`. The returned `i_samples` is `{num_samples}`.
        {i_samples, next_key} =
          Nx.Random.choice(key, category_iota, i_batch_prob, samples: num_samples)

        # Reshape the `i_samples` to `{1, num_samples}` and
        # update ith row in the acc to hold the new samples
        i_samples_reformatted = Nx.reshape(i_samples, {1, :auto})
        acc = Nx.put_slice(acc, [i, 0], i_samples_reformatted)

        {i + 1, input, next_key, acc}
      end

    {next_key, acc}
  end
end

Let’s play a bit with it to understand how it works…

probs =
  Nx.tensor([
    [0.2, 0.3, 0.1, 0.15, 0.25],
    [0.10, 0.10, 0.10, 0.10, 0.60],
    [0.0, 0.0, 0.0, 0.0, 1.00]
  ])

# Given some batched probability distribution, sample 5 values
# This is for demonstration purposes (we'll only need to sample 1 char when generating text)
{_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 5)
IO.inspect(samples, label: "Num. samples 5")

{_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 1)
IO.inspect(samples, label: "Num. samples 1")

:ok

Text generation

The generate function takes the same kind of input, which represents the current context of some chars in a batch. The job of generate is to take a {B, T} tensor and extend it to {B, T + 1}, {B, T + 2} and so on: it continues the generation along the time dimension for every sequence in the batch.

# Let's assemble the code above inside a module
# for convenience and add the `generate` function

defmodule BigramLanguageModel do
  import Nx.Defn

  def init(vocab_size) do
    # vocab_size = embedding_size
    Axon.input("sequence")
    |> Axon.embedding(vocab_size, vocab_size)
    |> Axon.build(mode: :train)
  end

  def forward(predict_fn, params, idx, targets \\ :none) do
    %{prediction: preds} = predict_fn.(params, idx)

    if targets == :none do
      {preds, nil}
    else
      # Compute it twice for the sake of keeping it simple
      {b, t, c} = Nx.shape(preds)
      y_pred = Nx.reshape(preds, {b * t, c})

      loss = loss(targets, preds)

      {y_pred, loss}
    end
  end

  defn loss(targets, preds) do
    {b, t, c} = Nx.shape(preds)

    # `y_pred` are the `logits` in the video
    y_pred = Nx.reshape(preds, {b * t, c})
    y_true = Nx.reshape(targets, {:auto, 1})

    Axon.Losses.categorical_cross_entropy(y_true, y_pred,
      from_logits: true,
      sparse: true,
      reduction: :mean
    )
  end

  def generate(predict_fn, params, idx, opts) do
    opts = Keyword.validate!(opts, pnrg_key_seed: 1337, max_new_tokens: 1000)

    pnrg_key = opts[:pnrg_key_seed] |> Nx.Random.key()
    max_new_tokens = opts[:max_new_tokens]

    Enum.reduce(1..max_new_tokens, {pnrg_key, idx}, fn _i, {pnrg_key, acc} ->
      # get the prediction
      {y_pred, _loss} = BigramLanguageModel.forward(predict_fn, params, acc)

      # focus only on the last time step
      # becomes {B, C}
      logits = y_pred[[.., -1, ..]]

      # apply softmax to get probabilities
      # {B, C}
      probs = Axon.Activations.softmax(logits)

      # sample from the distribution
      # `{B, 1}`
      {next_pnrg_key, idx_next} = RandomPlus.multinomial(pnrg_key, probs, num_samples: 1)

      # {B, T + 1}
      {next_pnrg_key, Nx.concatenate([acc, idx_next], axis: 1)}
    end)
    # Keep only the generated text and discard the PRNG key
    |> then(fn {_next_pnrg_key, acc} -> acc end)
    # Convert our Nx.tensor to Elixir list
    |> Nx.to_list()
    # Decode the results
    |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end)
    # We have only 1 batch, so just take the 1st item
    |> List.first()
  end
end
# Text Generation

{init_fn, predict_fn} = BigramLanguageModel.init(vocab_size)
initial_params = init_fn.(Nx.to_template(xb), %{})
IO.inspect(initial_params, label: "initial_params")

{logits, loss} = BigramLanguageModel.forward(predict_fn, initial_params, xb, yb)
IO.inspect(Nx.shape(logits), label: "logits shape")
IO.inspect(loss, label: "loss")

# Create the `idx` tensor and
# kick-off the generation starting from `0`.
# `0` encodes the `\n` character which means
# that we are starting the sequence with a 
# new line. 
#
# equivalent to `torch.zeros(1, 1)`
idx = Nx.broadcast(Nx.tensor(0), {1, 1})

max_new_tokens = 100

seed = 1337
initial_pnrg_key = Nx.Random.key(seed)

BigramLanguageModel.generate(predict_fn, initial_params, idx, max_new_tokens: max_new_tokens)
|> IO.puts()

👆 At the moment the returned sequence is useless 🗑️

This is because:

  • the model is “random” and not trained
  • the generate function is written to be general: we accumulate the context and feed it to the model, but when it is time to predict the next char, we only look at the last char of the context. That’s because it is a simple Bigram model.

Training the Bigram Model (with Axon)

Let’s train the model to make it less random! But before starting we need to change the get_batch utility function to accept the PRNG (pseudo-random number generator) key as an argument, in order to get different batches every time. This is because Nx.Random is stateless (https://hexdocs.pm/nx/Nx.Random.html).
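A quick illustration of what stateless means (a minimal sketch): the same key always reproduces the same numbers, so we must thread the new key returned by each call into the next one.

key = Nx.Random.key(1337)

# Reusing the same key gives the exact same "random" tensor…
{a, _} = Nx.Random.randint(key, 0, 100, shape: {3})
{b, _} = Nx.Random.randint(key, 0, 100, shape: {3})

# …so we pass along the new key returned by each call to get fresh numbers.
{c, key2} = Nx.Random.randint(key, 0, 100, shape: {3})
{d, _} = Nx.Random.randint(key2, 0, 100, shape: {3})

# 1 means a == b; c and d almost surely differ
{Nx.equal(a, b) |> Nx.all(), Nx.equal(c, d) |> Nx.all()}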

get_batch = fn split, batch_size, block_size, pnrg_key ->
  data = if split == :train, do: train_data, else: val_data

  {ix, new_pnrg_key} =
    Nx.Random.randint(pnrg_key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32)

  ix = Nx.to_list(ix)

  x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
  y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()

  %{batch: {x, y}, pnrg_key: new_pnrg_key}
end
batch_size = 32
block_size = 8

iterations = 10
seed = 1337
initial_pnrg_key = Nx.Random.key(seed)

# GENERATING THE BATCHES IN THIS WAY IS VERY SLOW!
# I kept the number of iterations low for that reason
# (to generate 500 batches it took 200 seconds)
%{batches: train_batches} =
  Enum.reduce(0..(iterations - 1), %{pnrg_key: initial_pnrg_key, batches: []}, fn _n, acc ->
    %{pnrg_key: pnrg_key, batches: batches} = acc

    %{batch: {xb, yb}, pnrg_key: new_pnrg_key} =
      get_batch.(:train, batch_size, block_size, pnrg_key)

    %{pnrg_key: new_pnrg_key, batches: batches ++ [{xb, yb}]}
  end)

model = BigramLanguageModel.init(vocab_size)
lr = 1.0e-3
optimizer = Axon.Optimizers.adamw(lr)

model
|> Axon.Loop.trainer(&BigramLanguageModel.loss/2, optimizer)
|> Axon.Loop.run(train_batches, %{}, epochs: 1, iterations: iterations, compiler: EXLA)

👆 The batch generation is rather slow: it takes 200 seconds to generate 500 batches. Let’s try to speed it up by making it lazy (via streams) 💤 and generating the batches inside a numerical function.

Highly inspired by:

defmodule DataLoader do
  import Nx.Defn

  defn get_batch(data, key, opts \\ []) do
    opts = keyword!(opts, [:batch_size, :block_size])

    block_size = opts[:block_size]
    batch_size = opts[:batch_size]

    {n} = Nx.shape(data)

    # Generate a `{batch_size, 1}` tensor of random numbers between
    # `0` and `n - block_size`, they represent the start index for every
    # block that composes the batch.
    {ix, key} = Nx.Random.randint(key, 0, n - block_size, shape: {batch_size, 1})

    # For each number, generate a sequence by summing
    # a tensor `{1, block_size}`
    # For example, if `block_size` is 8,
    # `Nx.iota({1, 8})` will return:
    #
    # ```
    # [
    #   [0, 1, 2, 3, 4, 5, 6, 7]
    # ]
    # ```
    #
    # And `ix +  Nx.iota({1, block_size})` will return
    # a new tensor `{batch_size, block_size}` because
    # the addition will broadcast tensors whenever the
    # dimensions do not match and broadcasting is possible.
    # https://hexdocs.pm/nx/Nx.html#add/2
    x_indices = ix + Nx.iota({1, block_size})

    x = Nx.take(data, x_indices)

    # Shift by 1 for `y` tensor to
    # get the following char.
    y = Nx.take(data, x_indices + 1)

    {x, y, key}
  end
end

seed = 1337

get_streamed_batch = fn split, batch_size, block_size ->
  Stream.resource(
    fn -> Nx.Random.key(seed) end,
    fn pnrg_key ->
      data = if split == :train, do: train_data, else: val_data

      {xb, yb, new_pnrg_key} =
        DataLoader.get_batch(data, pnrg_key, block_size: block_size, batch_size: batch_size)

      {[{xb, yb}], new_pnrg_key}
    end,
    fn _ -> :ok end
  )
end
# Let's inspect 2 batches `[{xb, yb}, {xb, yb}]`

streamed_batches = get_streamed_batch.(:train, batch_size, block_size)
Enum.take(streamed_batches, 2)
# Hyperparams
batch_size = 32
block_size = 8
# increase number of steps for good results...
iterations = 100

model = BigramLanguageModel.init(vocab_size)
lr = 1.0e-3
optimizer = Axon.Optimizers.adamw(lr)

# Use Axon to construct the training loop
params =
  model
  |> Axon.Loop.trainer(&BigramLanguageModel.loss/2, optimizer)
  |> Axon.Loop.run(streamed_batches, %{}, epochs: 1, iterations: iterations, compiler: EXLA)

(Side note: I feel the training loop is a bit slow, but I should verify how long the Python implementation takes. On my notebook (MacBook Pro, Intel i7) it takes about 6 seconds for 100 iterations.)

# And check the loss with the params after the training loop iterations

[{xb, yb}] = get_streamed_batch.(:train, batch_size, block_size) |> Enum.take(1)
{_logits, loss} = BigramLanguageModel.forward(predict_fn, params, xb, yb)
loss
# And finally generate some text after the training…

# Let's generate more tokens this time
max_new_tokens = 500

BigramLanguageModel.generate(predict_fn, params, idx, max_new_tokens: max_new_tokens)
|> IO.puts()
# For simplicity I let the model train for 10_000 iterations and dumped the
# params to a file. Let's load it and see how the generated text looks…

params10000 =
  __DIR__
  |> Path.join("bigram_model_params_10000")
  |> Path.expand()
  |> File.read!()
  |> Nx.deserialize()

[{xb, yb}] = get_streamed_batch.(:train, batch_size, block_size) |> Enum.take(1)
{_logits, loss} = BigramLanguageModel.forward(predict_fn, params10000, xb, yb)
IO.inspect(loss, label: "Loss after 10_000 iterations")

max_new_tokens = 500

BigramLanguageModel.generate(predict_fn, params10000, idx, max_new_tokens: max_new_tokens)
|> IO.puts()

The mathematical trick in self-attention (version 1)

# consider the following toy example

# batch, time, channel
# `channel` encodes "some" information at each point in the sequence
{b, t, c} = {4, 8, 2}

{x, _key} = Nx.Random.normal(Nx.Random.key(1337), shape: {b, t, c})

We would like these tokens (up to 8 per sequence in a batch) to be coupled, letting them communicate in a very specific way: information should flow from the previous context to the current timestep, and no information can come from the future because that is what we are trying to predict. The simplest way to make the tokens communicate is to average all the preceding elements: it is a lossy form of communication, but it is ok for the moment.

# xbow - "bag of words" is the term used to indicate that we are averaging things up

# Let's initialize to zeroes
xbow = Nx.broadcast(Nx.tensor(0), {b, t, c})

# We want x[b, t] = mean_{i<=t} x[b, i]
xbow =
  for b_i <- 0..(b - 1), reduce: xbow do
    xbow ->
      for t_i <- 0..(t - 1), reduce: xbow do
        xbow ->
          # Shape `{t^, c}` where `t^` is how many elements there were in the past
          # (remember, range is inclusive in Elixir)
          xprev = x[[b_i, 0..t_i]]

          mean = Nx.mean(xprev, axes: [0])

          # `xbow` has rank 3, while `mean` has rank 1.
          # In order to inject the computed `mean` in `xbow` tensor
          # via `Nx.put_slice/3` we need to make them compatible,
          # aka, update `mean` tensor to have the same rank.
          wrapped_mean = Nx.reshape(mean, {1, 1, :auto})
          Nx.put_slice(xbow, [b_i, t_i, 0], wrapped_mean)
      end
  end

Let’s take a look at the result:

x[0]
xbow[0]
# Let's "manually" compute the average of the first 2 elements of x (along axis 0)
avg = (Nx.to_number(x[0][0][0]) + Nx.to_number(x[0][1][0])) / 2

# This should be the same of the 2nd element in xbow (along axis 0)
IO.inspect("AVG: #{avg} ~= xbow[0][1][0]: #{Nx.to_number(xbow[0][1][0])}")

:ok

The mathematical trick in self-attention (version 2)

The previous implementation (version 1) is very inefficient. We can be more efficient by using batch matrix multiplication in combination with the Nx.tril function to do this weighted aggregation. The weights are specified in a {t, t} tensor (wei), and we are basically doing weighted sums according to the triangular form of the wei tensor, which means that a token at the nth position will only get information from the preceding tokens.

# Example to see how the matrix multiply trick works
# - the tensor `a` is renamed `t1`
# - the tensor `b` is renamed `t2`
# - the tensor `c` is renamed `t3`
t1 = Nx.broadcast(1.0, {3, 3}) |> Nx.tril()
t1 = Nx.divide(t1, Nx.sum(t1, axes: [1], keep_axes: true))
{t2, _key} = Nx.Random.randint(Nx.Random.key(123), 1, 10, shape: {3, 2})
t2 = Nx.as_type(t2, :f32)
t3 = Nx.dot(t1, t2)

IO.inspect(t1, label: "t1 (a)")
IO.inspect(t2, label: "t2 (b)")
IO.inspect(t3, label: "t3 (c)")

:ok
wei = Nx.broadcast(1.0, {t, t}) |> Nx.tril()
wei = Nx.divide(wei, Nx.sum(wei, axes: [1], keep_axes: true))

# NOTE:
# We can not simply do `Nx.dot(wei, x)` as in pytorch
# because in this case `Nx.dot/2` does not automatically broadcast
# `wei` across batch dimension `b` of `x` tensor.
#
# Therefore, what we can do is to add the batch dimension
# to our `wei`, from `{t, t}` to `{b, t, t}` by broadcasting it,
# basically it copies the `{t, t}` weights tensor for each batch.
wei = Nx.broadcast(wei, {b, t, t})

# Given that `wei` is `{b, t, t}` and `x` is `{b, t, c}`,
# and that we want the returned tensor to have the same shape of `x`,
# we need to specify the contracting and batch axes
# https://hexdocs.pm/nx/Nx.html#dot/6-dot-product-between-two-batched-tensors
xbow2 = Nx.dot(wei, [2], [0], x, [1], [0])

# Let's check if the result is the same as the
# one obtained before (1 means true)
Nx.all_close(xbow, xbow2)
# Let's manually compare the 1st batch of `xbow` and `xbow2`
IO.inspect(xbow[0], label: "xbow[0]")
IO.inspect(xbow2[0], label: "xbow2[0]")

:ok

Alternative approach using vectorization

Given that we want to do the dot product along the batch dimension b, we can take advantage of Nx.vectorize.

wei = Nx.broadcast(1.0, {t, t}) |> Nx.tril()
wei = Nx.divide(wei, Nx.sum(wei, axes: [1], keep_axes: true))

# Here the magic… no need to specify contracting and batch axes anymore
# when computing the matrix multiplication.
vec_x = Nx.vectorize(x, :batch)
xbow2_vect = Nx.dot(wei, vec_x)
IO.inspect(xbow2_vect)

# Let's check if the result is the same as the
# one obtained before (1 means true)
Nx.all_close(xbow, Nx.devectorize(xbow2_vect, keep_names: false))

The mathematical trick in self-attention (version 3)

Use Softmax to normalize the weights tensor.

# The weights are initialized to 0.0 and we can think of them
# as an interaction strength or affinity: basically they tell us
# how much we want to aggregate and average up from the past tokens.
# The weight values are going to change and be data dependent: the tokens will
# start to "look" at each other and find other tokens more or less interesting,
# that's the _affinity_.
wei = Nx.broadcast(0.0, {t, t})

tril = Nx.broadcast(1.0, {t, t}) |> Nx.tril()

# In particular, this line masks out the future: positions where `tril`
# is 0 (the tokens after the current one) are set to negative infinity,
# so after softmax they get zero weight and won't be aggregated
wei = Nx.select(Nx.equal(tril, 0.0), Nx.Constants.neg_infinity(), wei)
IO.inspect(wei, label: "Initial weights")

# Equivalent result by using `Nx.triu/2`
# wei = Nx.broadcast(Nx.Constants.neg_infinity(), {t, t}) |> Nx.triu(k: 1)

# Normalize
wei = Axon.Activations.softmax(wei)
IO.inspect(wei, label: "Weights after softmax")

# Aggregation by matrix multiplication `wei @ x`.
# It aggregates their values depending on the affinity
# between tokens.
#
# NOTE:
# We can not simply do `Nx.dot(wei, x)` as in pytorch
# because in this case `Nx.dot/2` does not automatically broadcast
# `wei` across batch dimension `b` of `x` tensor.
#
# Therefore, what we can do is to add the batch dimension
# to our `wei`, from `{t, t}` to `{b, t, t}` by broadcasting it,
# basically it copies the `{t, t}` weights tensor for each batch.
wei = Nx.broadcast(wei, {b, t, t})

# Given that `wei` is `{b, t, t}` and `x` is `{b, t, c}`,
# and that we want the returned tensor to have the same shape of `x`,
# we need to specify the contracting and batch axes
# https://hexdocs.pm/nx/Nx.html#dot/6-dot-product-between-two-batched-tensors
xbow3 = Nx.dot(wei, [2], [0], x, [1], [0])
IO.inspect(xbow3, label: "xbow3")

# Let's check if the result is the same as the
# one obtained before (1 means true)
Nx.all_close(xbow, xbow3)