Let’s build GPT: from scratch, in code, spelled out… but with Axon
Mix.install(
[
{:nx, "~> 0.6", override: true},
{:axon, "~> 0.5"},
{:exla, "~> 0.5.1"},
{:kino, "~> 0.7.0"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Intro
TODO: add better intro…
Resources:
 https://www.youtube.com/watch?v=kCc8FmEb1nY
 https://colab.research.google.com/drive/1JMLa53HDuAi7ZBmqV7ZnA3c_fvtXnx?usp=sharing#scrollTo=wJpXpmjEYC_T
 https://github.com/wtedw/gptfromscratch/blob/main/gptdev.livemd
Prepare the data
text =
  __DIR__
  |> Path.join("input.txt")
  |> Path.expand()
  |> File.read!()
Our Vocabulary
chars =
  text
  |> String.codepoints()
  |> Enum.uniq()
  |> Enum.sort()
vocab_size = length(chars)
IO.inspect(Enum.join(chars), label: "chars")
IO.inspect(vocab_size, label: "vocab_size")
:ok
Basic Encoder / Decoder
We are going to develop a strategy to tokenize the input text: convert the raw text into a sequence of integers according to some vocabulary of possible elements.
In this case we are going to tokenize single chars: each char has a corresponding integer.
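As a quick illustration of the idea, here is a minimal sketch that builds the two lookup tables for a toy string instead of the real dataset. The string `"hello"` and the variable names are only assumptions for the example; it uses plain Elixir and needs neither Nx nor the input file.

```elixir
text = "hello"

# Sorted unique characters form the vocabulary: ["e", "h", "l", "o"]
chars = text |> String.codepoints() |> Enum.uniq() |> Enum.sort()

# string-to-integer and integer-to-string lookup tables
stoi = chars |> Enum.with_index() |> Map.new()
itos = Map.new(stoi, fn {ch, i} -> {i, ch} end)

encode = fn s -> s |> String.codepoints() |> Enum.map(&Map.fetch!(stoi, &1)) end
decode = fn ids -> Enum.map_join(ids, &Map.fetch!(itos, &1)) end

encoded = encode.("hello")
IO.inspect(encoded, label: "encoded")
IO.inspect(decode.(encoded), label: "decoded")
```

Encoding and then decoding gives back the original string, which is the property we rely on when turning generated integers back into text.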
defmodule Minidecoder do
  @text __DIR__
        |> Path.join("input.txt")
        |> Path.expand()
        |> File.read!()
  @chars String.codepoints(@text) |> Enum.uniq() |> Enum.sort()
  # string-to-integer
  @stoi Enum.reduce(@chars, %{}, fn ch, acc -> Map.put(acc, ch, Enum.count(acc)) end)
  # integer-to-string
  @itos Enum.reduce(@stoi, %{}, fn {ch, i}, acc -> Map.put(acc, i, ch) end)
  def vocab_size, do: length(@chars)
  def encode_char(char), do: @stoi[char]
  def decode_char(encoded_char), do: @itos[encoded_char]
  def encode(text) do
    text |> String.codepoints() |> Enum.map(&encode_char(&1))
  end
  def decode(encoded_list) do
    encoded_list |> Enum.map(&decode_char(&1)) |> Enum.join()
  end
  def tensor(text, opts \\ []) do
    Nx.tensor(encode(text), opts)
  end
end
vocab_size =
  Minidecoder.vocab_size()
  |> IO.inspect(label: "vocab size is")
Minidecoder.encode("hii there") |> IO.inspect(label: "Encoding")
Minidecoder.encode("hii there") |> Minidecoder.decode() |> IO.inspect(label: "Decoding")
:ok
data = Minidecoder.tensor(text, type: :s64)
Nx.shape(data) |> IO.inspect(label: "shape")
data[0..999] |> IO.inspect(label: "first 1000 chars")
:ok
Split the dataset into training and validation sets
# Use 90% of the data for training and the remaining 10% for validation
n = Kernel.round(Nx.size(data) * 0.9)
train_data = data[0..(n - 1)//1]
val_data = data[n..-1//1]
{train_data, val_data}
Data loader: batches of chunks of data
To train the transformer we split the training dataset into chunks to speed the training up.
block_size = 8
# Please note that the range `[0..block_size]` is equivalent to `[:block_size+1]` in Python
chunk = train_data[0..block_size]
IO.inspect(chunk, label: "chunk")
# In a chunk of 9 chars, there are 8 individual examples packed in there:
# 18 -> 47
# 18, 47 -> 56
# 18, 47, 56 -> 57
# …
x = train_data[0..(block_size - 1)]
y = train_data[1..block_size]
for t <- 0..(block_size - 1) do
  context = x[0..t] |> Nx.to_list() |> Enum.join(", ")
  target = y[t] |> Nx.to_number()
  IO.inspect("when input is [#{context}] the target is: #{target}")
end
:ok
Chunks Generation
seed = 1337
# how many independent sequences will we process in parallel?
batch_size = 4
# what is the maximum context length for predictions?
block_size = 8
get_batch = fn split ->
  key = Nx.Random.key(seed)
  data = if split == :train, do: train_data, else: val_data
  # Generate a list of `batch_size` random indices. Each index is the start
  # of a slice of length `block_size` used to build the `x` and `y` tensors,
  # which is why the upper bound for the indices is
  # `Nx.size(data) - block_size`.
  #
  # Example
  # ix = [5, 10, 40, 60]
  # x = [[46, 47, 51, 57, 43, 50, 44, 1], […], […], […]]
  # y = [[47, 51, 57, 43, 50, 44, 1, 40], […], […], […]]
  #
  # where
  # - x[[0, 0]] = 46 = data[5] (5 is the 1st index in the list of indices)
  # - x[[0, 1]] = 47 = data[5 + 1]
  # - … and so on
  {ix, _new_key} =
    Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32)
  ix = Nx.to_list(ix)
  # From each index in `ix`, slice the data and extract
  # a 1d tensor of size `block_size`, then stack them all together
  x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
  y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()
  {x, y}
end
{xb, yb} = get_batch.(:train)
IO.inspect("INPUTS")
IO.inspect(Nx.shape(xb))
IO.inspect(xb)
IO.inspect("TARGETS")
IO.inspect(Nx.shape(yb))
IO.inspect(yb)
IO.inspect("")
# batch dimension (B)
for b <- 0..(batch_size - 1) do
  # time dimension (T)
  for t <- 0..(block_size - 1) do
    context = xb[[b, 0..t]] |> Nx.to_list() |> Enum.join(", ")
    target = yb[[b, t]] |> Nx.to_number()
    IO.inspect("when input is [#{context}] the target is: #{target}")
  end
end
:ok
# Our input to the transformer
IO.inspect(xb)
:ok
Bigram Language Model
A bigram language model is a type of statistical language model that predicts the probability of a word in a sequence based on the previous word. (source https://www.analyticsvidhya.com/blog/2019/08/comprehensiveguidelanguagemodelnlppythoncode/)
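In our case the "words" are single characters. To make the definition concrete, here is a small count-based sketch in plain Elixir: it counts how often each character follows another in a toy string and turns the counts into probabilities. The toy text and the `next_probs` helper are made up for the example; the model in this notebook learns an equivalent table through an embedding layer instead of counting.

```elixir
text = "abab ac"

# Count how often each pair of consecutive characters occurs
counts =
  text
  |> String.codepoints()
  |> Enum.chunk_every(2, 1, :discard)
  |> Enum.frequencies()

# Estimate P(next | prev) by normalising the counts for a given previous char
next_probs = fn prev ->
  following = for {[^prev, nxt], n} <- counts, do: {nxt, n}
  total = following |> Enum.map(&elem(&1, 1)) |> Enum.sum()
  Map.new(following, fn {nxt, n} -> {nxt, n / total} end)
end

IO.inspect(next_probs.("a"), label: "P(next | a)")
```

In the toy text, "a" is followed by "b" twice and by "c" once, so the estimated probabilities are 2/3 and 1/3.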
# Hyperparameters
# B - batch dimension
batch_size = 4
# T - time dimension
block_size = 8
# inputs (xb) and targets (yb) are both {B, T} tensors of integers
shape =
  xb
  |> Nx.to_template()
  |> Nx.shape()
# An embedding layer initializes a kernel of shape `{vocab_size, embedding_size}`
# which acts as a lookup table for sequences of discrete tokens (e.g. sentences).
# Embeddings are typically used to obtain a dense representation of a sparse input space.
# https://hexdocs.pm/axon/0.5.1/Axon.html#embedding/4
bigram_model =
  Axon.input("sequence", shape: shape)
  # vocab_size = embedding_size = 65
  |> Axon.embedding(65, 65)
Axon.Display.as_graph(bigram_model, Nx.template({batch_size, block_size}, :f32),
direction: :top_down
)
{init_fn, predict_fn} = Axon.build(bigram_model, mode: :train)
params = init_fn.(Nx.to_template(xb), %{})
## PREDICTION
# We can predict what's coming next.
#
# `preds` are our `logits`, which are basically the scores
# for the next char in the sequence.
# We are predicting what's coming next based only on the
# identity of a single token. For now that is all we can do, because
# the tokens are not aware of the context (they don't talk to each other)
%{prediction: preds} = predict_fn.(params, xb)
IO.inspect(preds, label: "predictions")
## LOSS
# `preds` has shape `{b, t, c}`, where `c` stands for `channel`.
# But Axon expects the shape `{b * t, c}` when computing
# the loss.
{b, t, c} = Nx.shape(preds)
y_pred = Nx.reshape(preds, {b * t, c})
# Same for the targets: at the moment they have shape `{b, t}`
# and we want shape `{b * t, 1}`
y_true = Nx.reshape(yb, {:auto, 1})
# Let's inspect the shape
y_pred |> IO.inspect(label: "y_pred reshaped")
y_true |> IO.inspect(label: "y_true reshaped")
# Now we can compute the loss
#
# To measure the loss we can use the log loss which
# is implemented in Axon under the name `categorical_cross_entropy`
#
# The `loss` is the cross entropy between the targets `y_true` and
# the predictions `y_pred`.
#
# - we need to compute the loss across the entire minibatch,
#   therefore we set the option `reduction: :mean`
# - we need to set `from_logits: true` because `y_pred` is not normalized
# - we need to set `sparse: true` because the targets `y_true` are integer
#   values corresponding to the target class
#
# https://hexdocs.pm/axon/0.5.1/Axon.Losses.html#categorical_cross_entropy/3
loss =
Axon.Losses.categorical_cross_entropy(y_true, y_pred,
from_logits: true,
sparse: true,
reduction: :mean
)
# We expect the loss to be `-ln(1/65)` (about 4.17), where 65 is our vocabulary size,
# and we are getting really close…
IO.inspect(Nx.to_number(loss), label: "loss")
IO.inspect(-:math.log(1 / 65), label: "expected loss")
:ok
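Where does that expected value come from? An untrained model is roughly guessing uniformly, so every char gets probability 1/65, and the cross entropy of a single example is `-ln(p[target])`. The sketch below works this out by hand in plain Elixir (no Nx); the all-zero logits and the target index 42 are arbitrary choices for the example.

```elixir
vocab_size = 65

# Uniform prediction: every logit is equal, so softmax gives 1 / vocab_size
logits = List.duplicate(0.0, vocab_size)
exps = Enum.map(logits, &:math.exp/1)
sum = Enum.sum(exps)
probs = Enum.map(exps, &(&1 / sum))

# Cross entropy for one example is -ln(p[target])
target = 42
loss = -:math.log(Enum.at(probs, target))

IO.inspect(loss, label: "loss for a uniform guess")
IO.inspect(:math.log(vocab_size), label: "ln(vocab_size)")
```

Both values are the same, about 4.17, and with a uniform guess the result does not depend on which target we pick.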
Generating text with the bigram model
Multinomial distribution implementation
Before starting we need to define a function for sampling from a multinomial distribution, which will be used to generate the next char. It is not available in Nx, but it is possible to implement it using `Nx.Random.choice/4`.
👉 Many many many thanks to @wtedw for coming up with a nice and clean implementation 👏👏👏.
> Multinomial distribution models the probability of each combination of successes in a series of independent trials. Use this distribution when there are more than two possible mutually exclusive outcomes for each trial, and each outcome has a fixed probability of success.
 https://it.mathworks.com/help/stats/multinomialdistribution.html
 https://pytorch.org/docs/stable/generated/torch.multinomial.html
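To get an intuition for what "sampling from the distribution" means, here is a plain-Elixir sketch that draws one index from a list of probabilities by walking the cumulative sum. It is not the implementation used in this notebook (that one relies on `Nx.Random.choice/4`); the `CategoricalSketch` module is a hypothetical helper, and it uses Erlang's stateful `:rand` for brevity.

```elixir
defmodule CategoricalSketch do
  # Draw one index from `probs`, a non-empty list of probabilities summing to 1
  def sample(probs) do
    # Uniform float in [0.0, 1.0)
    u = :rand.uniform()
    last = length(probs) - 1

    probs
    |> Enum.with_index()
    |> Enum.reduce_while(0.0, fn {p, i}, cum ->
      cum = cum + p
      # Pick the first index whose cumulative probability exceeds `u`;
      # fall back to the last index to guard against rounding errors
      if u < cum or i == last, do: {:halt, i}, else: {:cont, cum}
    end)
  end
end

probs = [0.1, 0.1, 0.1, 0.1, 0.6]
samples = for _ <- 1..1_000, do: CategoricalSketch.sample(probs)
IO.inspect(Enum.frequencies(samples), label: "frequencies")
```

Index 4 should come up in roughly 60% of the draws, and an index with probability 0.0 is never selected.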
# source https://github.com/wtedw/gptfromscratch/blob/main/gptdev.livemd#multinomial
defmodule RandomPlus do
import Nx.Defn
defn multinomial(init_key, input, opts \\ []) do
opts = keyword!(opts, num_samples: 1)
num_samples = opts[:num_samples]
{b, c} = Nx.shape(input)
# initialize the result tensor `{b, num_samples}`
initial_tensor = Nx.broadcast(0, {b, num_samples})
# Generate the category tensors used to compute
# the multinomial distribution for each row in the
# `input` tensor.
# (it will then be used as index value
# to fetch the char from the embeddings table)
#
# e.g. if `c` is 4: `Nx.tensor([0, 1, 2, 3])`
category_iota = Nx.iota({c}, type: :s32)
{_i, _input, next_key, acc} =
while {i = 0, input, key = init_key, acc = initial_tensor}, i < b do
# Let's pluck one of the rows from the initial inputs,
# which represents a probability distribution (shape `{c}`)
i_batch_prob = input[i]
# Generates random samples from a tensor with specified probabilities
# `i_batch_prob`. The returned `i_samples` is `{num_samples}`.
{i_samples, next_key} =
Nx.Random.choice(key, category_iota, i_batch_prob, samples: num_samples)
# Reshape `i_samples` to `{1, num_samples}` and
# update the i-th row in the acc to hold the new samples
i_samples_reformatted = Nx.reshape(i_samples, {1, :auto})
acc = Nx.put_slice(acc, [i, 0], i_samples_reformatted)
{i + 1, input, next_key, acc}
end
{next_key, acc}
end
end
Let’s play a bit with it to understand how it works…
probs =
Nx.tensor([
[0.2, 0.3, 0.1, 0.15, 0.25],
[0.10, 0.10, 0.10, 0.10, 0.60],
[0.0, 0.0, 0.0, 0.0, 1.00]
])
# Given some batched probability distribution, sample 5 values
# This is for demonstration purposes (we'll only need to sample 1 char when generating text)
{_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 5)
IO.inspect(samples, label: "Num. samples 5")
{_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 1)
IO.inspect(samples, label: "Num. samples 1")
:ok
Text generation
The `generate` function takes the same inputs, which represent the current context of some chars in a batch. The job of `generate` is to take a `{B, T}` tensor and extend it to `{B, T + 1}`, `{B, T + 2}` and so on: it continues the generation along the time dimension for every sequence in the batch.
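The shape of that loop can be sketched in plain Elixir with a single sequence and a made-up "model". Here `next_token` is a hypothetical stand-in that only looks at the last token (like a bigram model would) and deterministically returns the following integer modulo 5; the real function samples from the predicted distribution instead.

```elixir
# Hypothetical stand-in for the model: the next token depends only on the last one
next_token = fn context -> rem(List.last(context) + 1, 5) end

generate = fn context, max_new_tokens ->
  Enum.reduce(1..max_new_tokens, context, fn _i, acc ->
    # {T} -> {T + 1}: append the newly predicted token to the running context
    acc ++ [next_token.(acc)]
  end)
end

IO.inspect(generate.([0], 4), label: "generated")
```

Starting from `[0]` and asking for 4 new tokens gives `[0, 1, 2, 3, 4]`: the context grows by one element per step.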
# Let's assemble the code above inside a module
# for convenience and add the `generate` function
defmodule BigramLanguageModel do
  import Nx.Defn
  def init(vocab_size) do
    # vocab_size = embedding_size
    Axon.input("sequence")
    |> Axon.embedding(vocab_size, vocab_size)
    |> Axon.build(mode: :train)
  end
  def forward(predict_fn, params, idx, targets \\ :none) do
    %{prediction: preds} = predict_fn.(params, idx)
    if targets == :none do
      {preds, nil}
    else
      # The reshape is computed twice (here and in `loss/2`) to keep things simple
      {b, t, c} = Nx.shape(preds)
      y_pred = Nx.reshape(preds, {b * t, c})
      loss = loss(targets, preds)
      {y_pred, loss}
    end
  end
  defn loss(targets, preds) do
    {b, t, c} = Nx.shape(preds)
    # `y_pred` are the `logits` in the video
    y_pred = Nx.reshape(preds, {b * t, c})
    y_true = Nx.reshape(targets, {:auto, 1})
    Axon.Losses.categorical_cross_entropy(y_true, y_pred,
      from_logits: true,
      sparse: true,
      reduction: :mean
    )
  end
  def generate(predict_fn, params, idx, opts) do
    opts = Keyword.validate!(opts, pnrg_key_seed: 1337, max_new_tokens: 1000)
    pnrg_key = opts[:pnrg_key_seed] |> Nx.Random.key()
    max_new_tokens = opts[:max_new_tokens]
    Enum.reduce(1..max_new_tokens, {pnrg_key, idx}, fn _i, {pnrg_key, acc} ->
      # get the prediction
      {y_pred, _loss} = BigramLanguageModel.forward(predict_fn, params, acc)
      # focus only on the last time step
      # becomes {B, C}
      logits = y_pred[[.., -1, ..]]
      # apply softmax to get probabilities
      # {B, C}
      probs = Axon.Activations.softmax(logits)
      # sample from the distribution
      # `{B, 1}`
      {next_pnrg_key, idx_next} = RandomPlus.multinomial(pnrg_key, probs, num_samples: 1)
      # {B, T + 1}
      {next_pnrg_key, Nx.concatenate([acc, idx_next], axis: 1)}
    end)
    # Keep only the generated text and discard the PRNG key
    |> then(fn {_next_pnrg_key, acc} -> acc end)
    # Convert our Nx.tensor to Elixir list
    |> Nx.to_list()
    # Decode the results
    |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end)
    # We have only 1 batch, so just pluck the 1st item
    |> List.first()
  end
end
# Text Generation
{init_fn, predict_fn} = BigramLanguageModel.init(vocab_size)
initial_params = init_fn.(Nx.to_template(xb), %{})
IO.inspect(initial_params, label: "initial_params")
{logits, loss} = BigramLanguageModel.forward(predict_fn, initial_params, xb, yb)
IO.inspect(Nx.shape(logits), label: "logits shape")
IO.inspect(loss, label: "loss")
# Create the `idx` tensor and
# kickoff the generation starting from `0`.
# `0` encodes the `\n` character which means
# that we are starting the sequence with a
# new line.
#
# equivalent to `torch.zeros(1, 1)`
idx = Nx.broadcast(Nx.tensor(0), {1, 1})
max_new_tokens = 100
seed = 1337
initial_pnrg_key = Nx.Random.key(seed)
BigramLanguageModel.generate(predict_fn, initial_params, idx, max_new_tokens: max_new_tokens)
|> IO.puts()
👆 At the moment the returned sequence is useless 🗑️
This is because:
 the model is “random” and not trained
 the `generate` function is written to be general: we accumulate the context and feed it to the model, but when it is time to predict the next char we only look at the last char of the context. That’s because it is a simple bigram model.
Training the Bigram Model (with Axon)
Let’s train the model to make it less random! But before starting we need to change the `get_batch` utility function to accept the PRNG (pseudo-random number generator) key as an argument, in order to get different batches every time. This is because `Nx.Random` is stateless (https://hexdocs.pm/nx/Nx.Random.html).
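If "stateless" sounds abstract, the same idea can be tried without Nx using the functional API of Erlang's `:rand` module, which is only an analogy for `Nx.Random` keys here. The generator state is an explicit value: reusing the same state always gives the same number, so we have to thread the returned state into the next call.

```elixir
state = :rand.seed_s(:exsss, {1, 2, 3})

# Same state in, same number out
{a, next_state} = :rand.uniform_s(100, state)
{b, _unused_state} = :rand.uniform_s(100, state)

# Passing the returned state forward advances the sequence
{c, _unused_state} = :rand.uniform_s(100, next_state)

IO.inspect({a, b, c}, label: "{a, b, c}")
```

`a` and `b` are always equal, which is exactly why a `get_batch` that rebuilds the key from the same seed on every call keeps returning the same batch.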
get_batch = fn split, batch_size, block_size, pnrg_key ->
  data = if split == :train, do: train_data, else: val_data
  {ix, new_pnrg_key} =
    Nx.Random.randint(pnrg_key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32)
  ix = Nx.to_list(ix)
  x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
  y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()
  %{batch: {x, y}, pnrg_key: new_pnrg_key}
end
batch_size = 32
block_size = 8
iterations = 10
seed = 1337
initial_pnrg_key = Nx.Random.key(seed)
# GENERATING THE BATCHES THIS WAY IS VERY SLOW!
# I kept the number of iterations low for that reason
# (it took 200 seconds to generate 500 batches)
%{batches: train_batches} =
  Enum.reduce(0..(iterations - 1), %{pnrg_key: initial_pnrg_key, batches: []}, fn _n, acc ->
    %{pnrg_key: pnrg_key, batches: batches} = acc
    %{batch: {xb, yb}, pnrg_key: new_pnrg_key} =
      get_batch.(:train, batch_size, block_size, pnrg_key)
    %{pnrg_key: new_pnrg_key, batches: batches ++ [{xb, yb}]}
  end)
model = BigramLanguageModel.init(vocab_size)
lr = 1.0e-3
optimizer = Axon.Optimizers.adamw(lr)
model
|> Axon.Loop.trainer(&BigramLanguageModel.loss/2, optimizer)
|> Axon.Loop.run(train_batches, %{}, epochs: 1, iterations: iterations, compiler: EXLA)
👆 The batch generation is rather slow: it takes 200 seconds to generate 500 batches. Let’s try to speed it up by making it lazy (via a stream) 💤 and by generating the batch in a numerical function.
Highly inspired by:
 @wtedw for the stream idea https://github.com/wtedw/gptfromscratch/blob/main/gptdev.livemd 👏
 @polvalente for the numerical function https://elixirforum.com/t/letsbuildgptwithnxhowtotranslatethiscodeinelixir/56143/7
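Before looking at the real data loader, here is a tiny sketch of how `Stream.resource/3` produces values lazily while threading a piece of state from one step to the next. An integer counter plays the role of the PRNG key and the `{n, n + 1}` tuples play the role of the `{xb, yb}` batches; both are placeholders for the example.

```elixir
batches =
  Stream.resource(
    # start_fun: the initial state (the PRNG key in the real loader)
    fn -> 0 end,
    # next_fun: emit one element and return the updated state
    fn n -> {[{n, n + 1}], n + 1} end,
    # after_fun: nothing to clean up
    fn _n -> :ok end
  )

# Nothing is computed until the stream is consumed
IO.inspect(Enum.take(batches, 3), label: "first 3 batches")
```

Only the three requested elements are ever generated, so an endless stream of batches costs nothing until the training loop asks for the next one.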
defmodule DataLoader do
import Nx.Defn
defn get_batch(data, key, opts \\ []) do
opts = keyword!(opts, [:batch_size, :block_size])
block_size = opts[:block_size]
batch_size = opts[:batch_size]
{n} = Nx.shape(data)
# Generate a `{batch_size, 1}` tensor of random numbers between
# `0` and `n - block_size`; they represent the start index of every
# block that composes the batch.
{ix, key} = Nx.Random.randint(key, 0, n - block_size, shape: {batch_size, 1})
# For each number, generate a sequence by summing
# a tensor `{1, block_size}`
# For example, if `block_size` is 8,
# `Nx.iota({1, 8})` will return:
#
# ```
# [
# [0, 1, 2, 3, 4, 5, 6, 7]
# ]
# ```
#
# And `ix + Nx.iota({1, block_size})` will return
# a new tensor `{batch_size, block_size}` because
# the addition will broadcast tensors whenever the
# dimensions do not match and broadcasting is possible.
# https://hexdocs.pm/nx/Nx.html#add/2
x_indices = ix + Nx.iota({1, block_size})
x = Nx.take(data, x_indices)
# Shift by 1 for `y` tensor to
# get the following char.
y = Nx.take(data, x_indices + 1)
{x, y, key}
end
end
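To see what the broadcasted addition `ix + Nx.iota({1, block_size})` produces, here is a plain-Elixir sketch that uses lists instead of tensors. The start indices are made up, and `block_size` is 4 to keep the output short.

```elixir
block_size = 4

# Stand-in for the random start indices `ix` of shape `{batch_size, 1}`
starts = [5, 10, 40]

# Stand-in for `Nx.iota({1, block_size})`
offsets = Enum.to_list(0..(block_size - 1))

# Every start index is combined with every offset: `{batch_size, block_size}`
x_indices = for s <- starts, do: Enum.map(offsets, &(&1 + s))

# The targets are the same positions shifted by one
y_indices = Enum.map(x_indices, fn row -> Enum.map(row, &(&1 + 1)) end)

IO.inspect(x_indices, label: "x_indices")
IO.inspect(y_indices, label: "y_indices")
```

Each row of `x_indices` is one block of consecutive positions, which `Nx.take/2` then uses to gather the chars from `data` in a single call instead of one slice per index.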
seed = 1337
get_streamed_batch = fn split, batch_size, block_size ->
  Stream.resource(
    fn -> Nx.Random.key(seed) end,
    fn pnrg_key ->
      data = if split == :train, do: train_data, else: val_data
      {xb, yb, new_pnrg_key} =
        DataLoader.get_batch(data, pnrg_key, block_size: block_size, batch_size: batch_size)
      {[{xb, yb}], new_pnrg_key}
    end,
    fn _ -> :ok end
  )
end
# Let's inspect 2 batches `[{xb, yb}, {xb, yb}]`
streamed_batches = get_streamed_batch.(:train, batch_size, block_size)
Enum.take(streamed_batches, 2)
# Hyperparams
batch_size = 32
block_size = 8
# increase number of steps for good results...
iterations = 100
model = BigramLanguageModel.init(vocab_size)
lr = 1.0e-3
optimizer = Axon.Optimizers.adamw(lr)
# Use Axon to construct the training loop
params =
  model
  |> Axon.Loop.trainer(&BigramLanguageModel.loss/2, optimizer)
  |> Axon.Loop.run(streamed_batches, %{}, epochs: 1, iterations: iterations, compiler: EXLA)
(Side note: the training loop feels a bit slow to me, but I should check how long the Python implementation takes. On my notebook (MacBook Pro, Intel i7) it takes about 6 seconds for 100 iterations.)
# And check the loss with the params obtained after the training loop
[{xb, yb}] = get_streamed_batch.(:train, batch_size, block_size) |> Enum.take(1)
{_logits, loss} = BigramLanguageModel.forward(predict_fn, params, xb, yb)
loss
# And finally generate some text after the training…
# Let's generate more token this time
max_new_tokens = 500
BigramLanguageModel.generate(predict_fn, params, idx, max_new_tokens: max_new_tokens)
|> IO.puts()
# For simplicity I let the model train for 10_000 iterations and dump the
# params in a file. Let's load it and see how the text generations looks like…
params10000 =
  __DIR__
  |> Path.join("bigram_model_params_10000")
  |> Path.expand()
  |> File.read!()
  |> Nx.deserialize()
[{xb, yb}] = get_streamed_batch.(:train, batch_size, block_size) |> Enum.take(1)
{_logits, loss} = BigramLanguageModel.forward(predict_fn, params10000, xb, yb)
IO.inspect(loss, label: "Loss after 10_000 iterations")
max_new_tokens = 500
BigramLanguageModel.generate(predict_fn, params10000, idx, max_new_tokens: max_new_tokens)
|> IO.puts()
The mathematical trick in self-attention (version 1)
# consider the following toy example
# batch, time, channel
# `channel` encodes "some" information at each point in the sequence
{b, t, c} = {4, 8, 2}
{x, _key} = Nx.Random.normal(Nx.Random.key(1337), shape: {b, t, c})
We have these tokens (up to 8 in each batch element) and we would like them to communicate with each other in a very specific way: information should only flow from the previous context to the current timestep. A token cannot receive any information from the future, because those are the tokens we are trying to predict. The simplest way to make the tokens communicate is to average all the preceding elements: it is a lossy form of communication, but that’s OK for the moment.
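On a single channel, "average of all the preceding elements" is just a running mean. A minimal plain-Elixir sketch with made-up values:

```elixir
xs = [1.0, 3.0, 5.0, 7.0]

running_mean =
  xs
  # cumulative sums: [1.0, 4.0, 9.0, 16.0]
  |> Enum.scan(&Kernel.+/2)
  # pair each sum with the number of elements seen so far
  |> Enum.with_index(1)
  |> Enum.map(fn {cumulative_sum, n} -> cumulative_sum / n end)

IO.inspect(running_mean, label: "running mean")
```

The element at position `t` only depends on positions `0..t`, so no information flows from the future; here the result is `[1.0, 2.0, 3.0, 4.0]`.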
# xbow - "bag of words" is a term used to indicate that we are averaging things up
# Let's initialize it to zeros
xbow = Nx.broadcast(Nx.tensor(0), {b, t, c})
# We want x[b, t] = mean_{i<=t} x[b, i]
xbow =
  for b_i <- 0..(b - 1), reduce: xbow do
    xbow ->
      for t_i <- 0..(t - 1), reduce: xbow do
        xbow ->
          # Shape `{t^, c}` where `t^` is how many elements there were in the past
          # (remember, range is inclusive in Elixir)
          xprev = x[[b_i, 0..t_i]]
          mean = Nx.mean(xprev, axes: [0])
          # `xbow` has rank 3, while `mean` has rank 1.
          # In order to inject the computed `mean` in the `xbow` tensor
          # via `Nx.put_slice/3` we need to make them compatible,
          # i.e. reshape the `mean` tensor to have the same rank.
          wrapped_mean = Nx.reshape(mean, {1, 1, :auto})
          Nx.put_slice(xbow, [b_i, t_i, 0], wrapped_mean)
      end
  end
Let’s take a look at the result:
x[0]
xbow[0]
# Let's "manually" compute the average of the first 2 elements of x (along axis 0)
avg = (Nx.to_number(x[0][0][0]) + Nx.to_number(x[0][1][0])) / 2
# This should be the same of the 2nd element in xbow (along axis 0)
IO.inspect("AVG: #{avg} ~= xbow[0][1][0]: #{Nx.to_number(xbow[0][1][0])}")
:ok
The mathematical trick in self-attention (version 2)
The previous implementation (version 1) is very inefficient. We can be more efficient by using a batched matrix multiplication in combination with the `Nx.tril` function to do this weighted aggregation. The weights are specified in a `{t, t}` tensor (`wei`), and we are doing weighted sums according to its lower-triangular form, so a token at the n-th position only receives information from the preceding tokens (and from itself).
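The trick can be checked by hand with nested lists. The sketch below builds a normalised lower-triangular matrix for `t = 3` and multiplies it by a toy `{3, 2}` matrix in plain Elixir; the values in `x` are made up for the example.

```elixir
t = 3

# Row i holds 1 / (i + 1) in columns 0..i and 0.0 afterwards
wei =
  for i <- 0..(t - 1) do
    for j <- 0..(t - 1), do: if(j <= i, do: 1.0 / (i + 1), else: 0.0)
  end

# Toy values: one row per time step, two channels
x = [[2.0, 7.0], [6.0, 4.0], [6.0, 5.0]]

# Transpose `x` so that each column can be zipped with a row of `wei`
columns = Enum.zip_with(x, & &1)

# Matrix product: xbow[i][k] = sum over j of wei[i][j] * x[j][k]
xbow =
  for row <- wei do
    for col <- columns, do: row |> Enum.zip_with(col, &(&1 * &2)) |> Enum.sum()
  end

IO.inspect(wei, label: "wei")
IO.inspect(xbow, label: "xbow")
```

The first row of the result is just `x[0]`, the second row is the average of the first two rows of `x`, and the third is the average of all three: the same running mean as in version 1, obtained with a single matrix product.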
# Example to see how batch matrix multiplication works
# - the tensor `a` is renamed `t1`
# - the tensor `b` is renamed `t2`
# - the tensor `c` is renamed `t3`
t1 = Nx.broadcast(1.0, {3, 3}) |> Nx.tril()
t1 = Nx.divide(t1, Nx.sum(t1, axes: [1], keep_axes: true))
{t2, _key} = Nx.Random.randint(Nx.Random.key(123), 1, 10, shape: {3, 2})
t2 = Nx.as_type(t2, :f32)
t3 = Nx.dot(t1, t2)
IO.inspect(t1, label: "t1 (a)")
IO.inspect(t2, label: "t2 (b)")
IO.inspect(t3, label: "t3 (c)")
:ok
wei = Nx.broadcast(1.0, {t, t}) |> Nx.tril()
wei = Nx.divide(wei, Nx.sum(wei, axes: [1], keep_axes: true))
# NOTE:
# We can not simply do `Nx.dot(wei, x)` as in pytorch
# because in this case `Nx.dot/2` does not automatically broadcast
# `wei` across batch dimension `b` of `x` tensor.
#
# Therefore, what we can do is to add the batch dimension
# to our `wei`, from `{t, t}` to `{b, t, t}` by broadcasting it,
# basically it copies the `{t, t}` weights tensor for each batch.
wei = Nx.broadcast(wei, {b, t, t})
# Given that `wei` is `{b, t, t}` and `x` is `{b, t, c}`,
# and that we want the returned tensor to have the same shape of `x`,
# we need to specify the contracting and batch axes
# https://hexdocs.pm/nx/Nx.html#dot/6dotproductbetweentwobatchedtensors
xbow2 = Nx.dot(wei, [2], [0], x, [1], [0])
# Let's check if the result is the same of the
# one obtained before (1 means true)
Nx.all_close(xbow, xbow2)
# Let's manually compare the 1st batch of `xbow` and `xbow2`
IO.inspect(xbow[0], label: "xbow[0]")
IO.inspect(xbow2[0], label: "xbow2[0]")
:ok
Alternative approach using vectorization
Given that we want to compute the dot product along the batch dimension `b`, we can take advantage of `Nx.vectorize`:
 https://hexdocs.pm/nx/vectorization.html
 https://elixirforum.com/t/understandingnxdot6contractingandbatchaxes/58407/2?u=nickgnd
wei = Nx.broadcast(1.0, {t, t}) |> Nx.tril()
wei = Nx.divide(wei, Nx.sum(wei, axes: [1], keep_axes: true))
# Here the magic… no need to specify contracting and batch axes anymore
# when computing the matrix multiplication.
vec_x = Nx.vectorize(x, :batch)
xbow2_vect = Nx.dot(wei, vec_x)
IO.inspect(xbow2_vect)
# Let's check if the result is the same of the
# one obtained before (1 means true)
Nx.all_close(xbow, Nx.devectorize(xbow2_vect, keep_names: false))
The mathematical trick in self-attention (version 3)
Use Softmax to normalize the weights tensor.
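Why does softmax give the same averaging weights as before? Masked positions are set to negative infinity, `exp` of negative infinity is 0, and the remaining zeros are turned into equal shares of 1. Here is a plain-Elixir sketch of that reasoning; the `:neg_infinity` atom and the `masked_softmax` function are hypothetical stand-ins for `Nx.Constants.neg_infinity()` and `Axon.Activations.softmax/1`.

```elixir
# Softmax over one row, treating `:neg_infinity` as exp(-inf) = 0.0
masked_softmax = fn row ->
  exps =
    Enum.map(row, fn
      :neg_infinity -> 0.0
      v -> :math.exp(v)
    end)

  sum = Enum.sum(exps)
  Enum.map(exps, &(&1 / sum))
end

t = 4

# Affinities start at 0.0; positions in the future are masked out
wei =
  for i <- 0..(t - 1) do
    for j <- 0..(t - 1), do: if(j <= i, do: 0.0, else: :neg_infinity)
  end

weights = Enum.map(wei, masked_softmax)
IO.inspect(weights, label: "weights after softmax")
```

Row `i` ends up with `1 / (i + 1)` in its first `i + 1` columns and 0.0 elsewhere, which is the normalised lower-triangular matrix of version 2. Once the affinities stop being all zeros, the same softmax will produce data-dependent weights.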
# The weights are initialized to 0.0 and we can think of it
# like an interaction strength or affinity, so basically is telling
# us how much from the past token we want to aggregate and average up.
# The weights value is going to change and be data dependent: the tokens will start
# to "look" at each other and find other tokens more or less interesting,
# that's the _affinity_.
wei = Nx.broadcast(0.0, {t, t})
tril = Nx.broadcast(1.0, {t, t}) |> Nx.tril()
# In particular, this line says that tokens from the future cannot
# communicate with the past: positions set to negative infinity get
# a weight of 0 after the softmax, so they won't be aggregated
wei = Nx.select(Nx.equal(tril, 0.0), Nx.Constants.neg_infinity(), wei)
IO.inspect(wei, label: "Initial weights")
# Equivalent result by using `Nx.triu/2`
# wei = Nx.broadcast(Nx.Constants.neg_infinity(), {t, t}) |> Nx.triu(k: 1)
# Normalize
wei = Axon.Activations.softmax(wei)
IO.inspect(wei, label: "Weights after softmax")
# Aggregation by matrix multiplication `wei @ x`.
# It aggregates their values depending on the affinity
# between tokens.
#
# NOTE:
# We can not simply do `Nx.dot(wei, x)` as in pytorch
# because in this case `Nx.dot/2` does not automatically broadcast
# `wei` across batch dimension `b` of `x` tensor.
#
# Therefore, what we can do is to add the batch dimension
# to our `wei`, from `{t, t}` to `{b, t, t}` by broadcasting it,
# basically it copies the `{t, t}` weights tensor for each batch.
wei = Nx.broadcast(wei, {b, t, t})
# Given that `wei` is `{b, t, t}` and `x` is `{b, t, c}`,
# and that we want the returned tensor to have the same shape of `x`,
# we need to specify the contracting and batch axes
# https://hexdocs.pm/nx/Nx.html#dot/6dotproductbetweentwobatchedtensors
xbow3 = Nx.dot(wei, [2], [0], x, [1], [0])
IO.inspect(xbow3, label: "xbow3")
# Let's check if the result is the same of the
# one obtained before (1 means true)
Nx.all_close(xbow, xbow3)