
Let’s build GPT from scratch! w/ Nx and Axon

gpt-dev.livemd

Mix.install(
  [
    {:nx, "~> 0.5.3"},
    {:req, "~> 0.3.6"},
    {:kino_bumblebee, "~> 0.3.0"},
    {:exla, "~> 0.5.1"},
    {:table_rex, "~> 3.1.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
:ok

Introduction

This notebook covers Andrej Karpathy’s video Let’s build GPT: from scratch, in code, spelled out. We’ll start off building a simple bigram model and iteratively build up to the decoder-only transformer.

Note: this notebook was created to experiment with Elixir’s ML libraries, so the following code is probably not idiomatic Nx/Axon code and doesn’t take full advantage of their capabilities.

References

  • Karpathy’s companion notebook can be found here

  • Thanks to Lorenzo Sinisi for the initial livebook code

Prepare data

Let’s first prepare our Shakespeare data.

file_path = Path.absname("./input.txt")

text =
  if File.exists?(file_path) do
    IO.puts("File loaded from memory: #{file_path}")
    File.read!(file_path)
  else
    IO.puts(
      "File loaded from git: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    )

    Req.get!(
      "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    ).body
  end
File loaded from git: https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger for bread, not in thirst for revenge.\n\nSecond Citizen:\nWould you proceed especially against Caius Marcius?\n\nAll:\nAgainst him first: he's a very dog to the commonalty.\n\nSecond Citizen:\nConsider you what services he has done for his country?\n\nFirst Citizen:\nVery well; and could be content to give him good\nreport fort, but that he pays himself with being proud.\n\nSecond Citizen:\nNay, but speak not maliciously.\n\nFirst Citizen:\nI say unto you, what he hath done famously, he did\nit to that end: though soft-conscienced men can be\ncontent to say it was for his country he did it to\nplease his mother and to be partly proud; which he\nis, even till the altitude of his virtue.\n\nSecond Citizen:\nWhat he cannot help in his nature, you account a\nvice in him. You must in no way say he is covetous.\n\nFirst Citizen:\nIf I must not, I need not be barren of accusations;\nhe hath faults, with surplus, to tire in repetition.\nWhat shouts are these? The other side o' the city\nis risen: why stay we prating here? to the Capitol!\n\nAll:\nCome, come.\n\nFirst Citizen:\nSoft! who comes here?\n\nSecond Citizen:\nWorthy Menenius Agrippa; one that hath always loved\nthe people.\n\nFirst Citizen:\nHe's one honest enough: would all the rest were so!\n\nMENENIUS:\nWhat work's, my countrymen, in hand? where go you\nWith bats and clubs? The matter? speak, I pray you.\n\nFirst Citizen:\nOur business is not unknown to the senate; they have\nhad inkling this fortnight what we intend to do,\nwhich now we'll show 'em in deeds. They say poor\nsuitors have strong breaths: they shall know we\nhave strong arms too.\n\nMENENIUS:\nWhy, masters, my good friends, mine honest neighbours,\nWill you undo yourselves?\n\nFirst Citizen:\nWe cannot, sir, we are undone already.\n\nMENENIUS:\nI tell you, friends, most charitable care\nHave the patricians of you. For your wants,\nYour suffering in this dearth, you may as well\nStrike at the heaven with your staves as lift them\nAgainst the Roman state, whose course will on\nThe way it takes, cracking ten thousand curbs\nOf more strong link asunder than can ever\nAppear in your impediment. For the dearth,\nThe gods, not the patricians, make it, and\nYour knees to them, not arms, must help. Alack,\nYou are transported by calamity\nThither where more attends you, and you slander\nThe helms o' the state, who care for you like fathers,\nWhen you curse them as enemies.\n\nFirst Citizen:\nCare for us! True, indeed! 
They ne'er cared for us\nyet: suffer us to famish, and their store-houses\ncrammed with grain; make edicts for usury, to\nsupport usurers; repeal daily any wholesome act\nestablished against the rich, and provide more\npiercing statutes daily, to chain up and restrain\nthe poor. If the wars eat us not up, they will; and\nthere's all the love they bear us.\n\nMENENIUS:\nEither you must\nConfess yourselves wondrous malicious,\nOr be accused of folly. I shall tell you\nA pretty tale: it may be you have heard it;\nBut, since it serves my purpose, I will venture\nTo stale 't a little more.\n\nFirst Citizen:\nWell, I'll hear it, sir: yet you must not think to\nfob off our disgrace with a tale: but, an 't please\nyou, deliver.\n\nMENENIUS:\nThere was a time when all " <> ...

Basic Encoder / Decoder

defmodule Minidecoder do
  @chars text |> String.codepoints() |> Enum.uniq() |> Enum.sort()
  @vocab_size Enum.count(@chars)
  def vocab_size, do: @vocab_size

  @stoi Enum.reduce(@chars, %{}, fn ch, acc -> Map.put(acc, ch, Enum.count(acc)) end)
  @itos Enum.reduce(@stoi, %{}, fn {ch, i}, acc -> Map.put(acc, i, ch) end)

  def encode_char(char), do: @stoi[char]

  def decode_char(encoded_char), do: @itos[encoded_char]

  def encode(text) do
    text |> String.codepoints() |> Enum.map(&encode_char/1)
  end

  def decode(encoded_list) do
    encoded_list |> Enum.map(&decode_char/1) |> Enum.join()
  end

  def tensor(text) do
    Nx.tensor(encode(text))
  end
end

vocab_size =
  Minidecoder.vocab_size()
  |> IO.inspect(label: "vocab size is")

Minidecoder.tensor(text)
vocab size is: 65

14:50:17.553 [info] TfrtCpuClient created.
#Nx.Tensor<
  s64[1115394]
  EXLA.Backend
  [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, ...]
>

Encoded Training + Validation Data

data = Minidecoder.tensor(text)
n = Kernel.round(Nx.size(data) * 0.9)
# take the first n tokens (the first 90%) for training
train_data = Nx.slice(data, [0], [n])
# take the remaining tokens (from index n to the end) for validation
val_data = Nx.slice(data, [n], [Nx.size(data) - n])
{train_data, val_data}
{#Nx.Tensor<
   s64[1003855]
   EXLA.Backend
   [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, ...]
 >,
 #Nx.Tensor<
   s64[111539]
   EXLA.Backend
   [0, 0, 19, 30, 17, 25, 21, 27, 10, 0, 19, 53, 53, 42, 1, 51, 53, 56, 56, 53, 61, 6, 1, 52, 43, 47, 45, 46, 40, 53, 59, 56, 1, 14, 39, 54, 58, 47, 57, 58, 39, 8, 0, 0, 14, 13, 28, 32, ...]
 >}

Training Data

To speed up training, we’re going to batch our training data. It’ll look like this.

x = [
  ["h", "e", "l", "l", "o"],
  [" ", "w", "o", "r", "l"]
]

y = [
  ["e", "l", "l", "o", " "],
  ["w", "o", "r", "l", "d"]
]

We’ll insert a linear layer between our x and y. After training, the model should learn these associations (see the sketch after this list):

  • “h” -> “e”
  • “e” -> “l”
  • ..
  • “w” -> “o”
  • “o” -> “r”
  • etc
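To make the x/y relationship concrete, here’s a throwaway sketch (assuming only the Minidecoder module defined above): y is just x shifted one character to the right, so position t in x is trained to predict position t in y.

# Illustrative only: the target sequence is the input shifted by one character
chunk = Minidecoder.tensor("hello world")

x_demo = Nx.slice(chunk, [0], [5]) |> Nx.to_list() |> Minidecoder.decode()
y_demo = Nx.slice(chunk, [1], [5]) |> Nx.to_list() |> Minidecoder.decode()

{x_demo, y_demo}
# => {"hello", "ello "}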

Training Data Generator

The Axon training loop expects an Enumerable or Stream for its training data. We’ll use Stream.resource/3 to repeatedly generate random slices of our training data. Every time a batch is generated, the stream also threads a new random key through to the next generation. This keeps our model outputs reproducible.

We’ll experiment with different batch sizes and block sizes, so we’ll wrap this Stream in a closure.

seed = 1337

get_batch_stream = fn batch_size, block_size, split ->
  Stream.resource(
    # initialization function
    fn ->
      Nx.Random.key(seed)
    end,
    # generation function
    fn key ->
      data = if(split == :train, do: train_data, else: val_data)

      {ix, new_key} =
        Nx.Random.randint(key, 0, Nx.size(data) - block_size, shape: {batch_size}, type: :u32)

      ix = Nx.to_list(ix)

      x = Enum.map(ix, fn i -> Nx.slice(data, [i], [block_size]) end) |> Nx.stack()
      y = Enum.map(ix, fn i -> Nx.slice(data, [i + 1], [block_size]) end) |> Nx.stack()

      # Reshape y from {b, t} into a single {b * t} vector
      # We do this to match the shape of y_true during training
      # https://hexdocs.pm/axon/Axon.Losses.html#categorical_cross_entropy/3
      {b, t} = Nx.shape(y)

      # or Nx.flatten
      flattened_y = Nx.reshape(y, {b * t})

      out_data = {x, flattened_y}

      {[out_data], new_key}
    end,
    # termination function
    fn _ -> :ok end
  )
end

train_batch_stream = get_batch_stream.(4, 8, :train)
train_batch_stream |> Enum.take(1)
[
  {#Nx.Tensor<
     s64[4][8]
     EXLA.Backend
     [
       [46, 47, 51, 57, 43, 50, 44, 1],
       [26, 19, 1, 30, 21, 15, 20, 13],
       [41, 43, 42, 1, 39, 1, 58, 56],
       [1, 42, 53, 1, 46, 43, 56, 43]
     ]
   >,
   #Nx.Tensor<
     s64[32]
     EXLA.Backend
     [47, 51, 57, 43, 50, 44, 1, 40, 19, 1, 30, 21, 15, 20, 13, 30, 43, 42, 1, 39, 1, 58, 56, 39, 42, 53, 1, 46, 43, 56, 43, 6]
   >}
]

Simple Bigram Model

Let’s assume we have a well-trained bigram model.

Given an input tensor of size {1, 4}, the output might look something like this.

# batch_size = 1, block_size = 4
input = [[h, e, l, l]]

# batch_size = 1, block_size = 4, vocab_size = 65
output = [[[65], [65], [65], [65]]]

  • Each index in these [65]-sized tensors corresponds to an encoded character from our Shakespeare vocabulary
  • The likelihood of an encoded character appearing next in the sequence is given by its value inside the [65]-sized tensor.

To predict the next character in our sequence, we’ll look at the last [65]-sized tensor in our output. Right now the values are just raw, unnormalized predictions for our 65 possible characters. We’ll feed this tensor (called logits) into softmax to get a probability distribution that we can sample the next character from.
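Here’s a tiny, self-contained illustration of that logits -> softmax step (the numbers are made up, and the vocab is shrunk to 4 characters for readability):

# Raw, unnormalized logits for 4 hypothetical characters
logits = Nx.tensor([2.0, 1.0, 0.0, -1.0])

# Softmax turns them into a probability distribution that sums to 1
Axon.Activations.softmax(logits)
# => approximately [0.644, 0.237, 0.087, 0.032]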

# Hyperparameters
batch_size = 4
block_size = 8

bigram_model =
  Axon.input("sequence")
  |> Axon.embedding(65, 65)

Axon.Display.as_graph(bigram_model, Nx.template({batch_size, block_size}, :f32),
  direction: :top_down
)
graph TD;
35[/"sequence (:input) {4, 8}"/];
36["embedding_0 (:embedding) {4, 8, 65}"];
35 --> 36;

Training the bigram model

# We'll use this for other models further along in the notebook
defmodule CommonTrain do
  import Nx.Defn

  defn custom_predict_fn(model_predict_fn, params, input) do
    %{prediction: preds} = out = model_predict_fn.(params, input)
    {b, t, c} = Nx.shape(preds)
    reshaped = Nx.reshape(preds, {b * t, c})
    %{out | prediction: reshaped}
  end

  def custom_loss_fn(y_true, y_pred) do
    Axon.Losses.categorical_cross_entropy(y_true, y_pred,
      from_logits: true,
      sparse: true,
      reduction: :mean
    )
  end
end
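As a quick shape sanity check (toy values, purely illustrative): with sparse: true, the loss expects integer class indices of shape {n} for y_true and logits of shape {n, num_classes} for y_pred. This is exactly why we flatten y in the batch stream and reshape the predictions to {b * t, c}.

# Two predictions over a made-up 3-character vocab
y_true = Nx.tensor([2, 0])
y_pred = Nx.tensor([[0.1, 0.2, 3.0], [2.5, 0.1, 0.3]])

# Both logits favor the correct class, so the loss is a small scalar (~0.15)
CommonTrain.custom_loss_fn(y_true, y_pred)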

{init_fn, predict_fn} = Axon.build(bigram_model, mode: :train)
custom_predict_fn = &amp;CommonTrain.custom_predict_fn(predict_fn, &amp;1, &amp;2)
custom_loss_fn = &amp;CommonTrain.custom_loss_fn(&amp;1, &amp;2)
train_batch_stream = get_batch_stream.(4, 8, :train)

params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw())
  |> Axon.Loop.run(train_batch_stream, %{}, epochs: 1, iterations: 10000, compiler: EXLA)

15:25:05.635 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 9950, loss: 2.8657069
%{
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[65][65]
      EXLA.Backend
      [
        [1.954552173614502, -4.385025978088379, -4.377076148986816, -4.385700702667236, -4.394716262817383, -0.9468418955802917, -4.390252590179443, -3.949601650238037, -4.3867998123168945, -2.314349412918091, -4.3820013999938965, -4.39186429977417, -4.384054660797119, 1.2772890329360962, 0.4298079311847687, 0.24700672924518585, -0.18589963018894196, -0.6697914600372314, 0.2670493721961975, -0.1326778531074524, 0.452250212430954, 0.8470256328582764, -1.1264326572418213, -0.38984325528144836, -0.031074173748493195, 0.3561098277568817, -0.02344430610537529, 0.016382480040192604, -0.12044396251440048, -1.0433743000030518, -0.5009018778800964, 0.44456756114959717, 1.4647704362869263, -1.013350009918213, -1.372435212135315, 0.9725027084350586, -4.384570598602295, -0.23563559353351593, -4.383942604064941, -1.2093874216079712, -1.6285748481750488, -1.3509862422943115, -1.6671699285507202, -2.2631280422210693, -1.6090916395187378, -1.9188610315322876, -1.212599277496338, -1.7029807567596436, ...],
        ...
      ]
    >
  }
}

Generating text with the bigram model, w/ argmax

Let’s implement a naive way of generating text using Nx.argmax. Every time we make a prediction, argmax will pick the single most probable character our model thinks should come next.

generate_fn = fn model, params, init_seq, max_new_tokens ->
  Enum.reduce(1..max_new_tokens, init_seq, fn _i, acc ->
    {_b, t} = Nx.shape(acc)

    # Cap the input sequence length at block_size
    context_length = min(t, block_size)
    context_range = -context_length..-1
    context_slice = acc[[.., context_range]]

    # Predict next char
    preds = Axon.predict(model, params, context_slice)
    logits = preds[[.., -1, ..]]
    probs = Axon.Activations.softmax(logits)
    # {b, 1}
    batch_char = Nx.argmax(probs, axis: 1, keep_axis: true)

    Nx.concatenate([acc, batch_char], axis: -1)
  end)
end

# init_seq = Nx.broadcast(0, {1, 1})
init_seq = Nx.iota({1, 5})
max_new_tokens = 500

generate_fn.(bigram_model, params, init_seq, max_new_tokens)
# Convert our Nx.tensor to Elixir list
|> Nx.to_list()
# Decode the results
|> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end)
# Our input is just 1 batch, so grab the first one
|> List.first()
|> IO.puts()

 !$&cour the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
:ok

Why the repetition?

Let’s look inside our linear layer. We should see something like this:

  • “t” is likely to produce “h”
  • “h” is likely to produce “e”
  • “e” is likely to produce “ “
  • “ ” is likely to produce “t”

get_top_predictions = fn char, table, num_chars ->
  encoded_char = Minidecoder.encode_char(char)
  predictions = table[encoded_char]

  predictions
  |> Nx.to_list()
  |> Enum.with_index(fn element, index -> {index, element} end)
  |> Enum.map(fn {idx, logit} -> {Minidecoder.decode_char(idx), logit} end)
  |> Enum.sort(fn {_x_idx, x_res}, {_y_idx, y_res} -> x_res >= y_res end)
  |> Enum.take(num_chars)
end

table = params["embedding_0"]["kernel"]

[
  t: get_top_predictions.("t", table, 3),
  h: get_top_predictions.("h", table, 3),
  e: get_top_predictions.("e", table, 3),
  _: get_top_predictions.(" ", table, 3)
]
[
  t: [{"h", 2.5696206092834473}, {" ", 2.204694986343384}, {"o", 1.090488314628601}],
  h: [{"e", 2.453754425048828}, {"a", 1.7331098318099976}, {"i", 1.4247827529907227}],
  e: [{" ", 2.5171456336975098}, {"r", 1.7074227333068848}, {"n", 1.1833429336547852}],
  _: [{"t", 1.9793766736984253}, {"a", 1.3801567554473877}, {"h", 1.2791123390197754}]
]

Multinomial

To avoid repetitive text, we want to randomly sample the next character from our model’s predicted distribution. We can do this with Nx.Random.choice/4. But our model’s output shape is {b, t, vocab_size}. PyTorch’s torch.multinomial can work with batches, but (afaik) there’s no equivalent function in the Nx library, so we’ll need to write a custom function that stacks the results of Nx.Random.choice/4.

defmodule RandomPlus do
  import Nx.Defn

  defn multinomial(init_key, input, opts \\ []) do
    opts = keyword!(opts, num_samples: 1)
    num_samples = opts[:num_samples]

    {b, c} = Nx.shape(input)
    initial_tensor = Nx.broadcast(0, {b, num_samples})
    category_iota = Nx.iota({c}, type: :s32)

    {_i, _input, next_key, acc} =
      while {i = 0, input, key = init_key, acc = initial_tensor}, i < b do
        # Becomes {C}, represents probability distribution
        i_batch_prob = input[i]

        {i_samples, next_key} =
          Nx.Random.choice(key, category_iota, i_batch_prob, samples: num_samples)

        # Update the ith row of acc to hold the new samples
        # (start the slice at column 0; the original [i, i] only worked because put_slice clamps indices)
        i_samples_reformatted = Nx.reshape(i_samples, {1, :auto})
        acc = Nx.put_slice(acc, [i, 0], i_samples_reformatted)
        {i + 1, input, next_key, acc}
      end

    {next_key, acc}
  end
end

probs =
  Nx.tensor([
    [0.2, 0.3, 0.1, 0.15, 0.25],
    [0.10, 0.10, 0.10, 0.10, 0.60],
    [0.0, 0.0, 0.0, 0.0, 1.00]
  ])

# Given some batched probability distribution, sample 5 values
# This is for demonstration purposes (we'll only need to sample 1 char when generating text)
{_key, samples} = RandomPlus.multinomial(Nx.Random.key(1337), probs, num_samples: 5)
samples
#Nx.Tensor<
  s64[3][5]
  EXLA.Backend
  [
    [1, 4, 0, 4, 4],
    [2, 1, 1, 2, 4],
    [4, 4, 4, 4, 4]
  ]
>

Generating text with the bigram model, w/ multinomial

Let’s see what happens if we use multinomial now.

generate_fn = fn model, params, init_seq, key, max_new_tokens ->
  Enum.reduce(1..max_new_tokens, {key, init_seq}, fn _i, {key, acc} ->
    {_b, t} = Nx.shape(acc)

    # Cap the input sequence length at block_size
    context_length = min(t, block_size)
    context_range = -context_length..-1
    context_slice = acc[[.., context_range]]

    # Predict next batch of chars (when we generate text, batch_size = 1)
    preds = Axon.predict(model, params, context_slice)
    logits = preds[[.., -1, ..]]
    probs = Axon.Activations.softmax(logits)
    {next_key, batch_char} = RandomPlus.multinomial(key, probs, num_samples: 1)

    {next_key, Nx.concatenate([acc, batch_char], axis: -1)}
  end)
  |> then(fn {_next_key, acc} -> acc end)
end

init_seq = Nx.broadcast(0, {1, 1})
key = Nx.Random.key(1337)
max_new_tokens = 1000

generate_fn.(bigram_model, params, init_seq, key, max_new_tokens)
|> Nx.to_list()
|> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end)
|> List.first()
|> IO.puts()

S:
'S: ses t Pis ave, lef. j'ICay thiles Won.AUmoliveng'llofrrouseff he:
Tho ff wous ke
Ker!
lthean areel.
Whon ofilok Alil t thom.Xfabj!

Ifo hendride mou.
TrvicQu ous d mitesors; YCAns qgar matamES:
RISthore qu. ue thuspipKI K:
We che. nd y manthamean the fo,
Ar!
I oout. cowieayouroflllothalveedrgrme
d patit f 3
CK har

UELl;
If chankn ourinowoftipor hendvis?u,
WAgorvan;
Ho hos,
EvS:
TOVck fodonrQ$&-bQhons s her, 'd.

tyolatoresces of$Qy; opy thTho fopum f.
CHNRUCres meowea d s
Thetsos on psth or
PRecon limy t t:
's by indngrs, pXG bntr f hs:
Thays thomea ELESSus cs; at,
t sanLIt,
MI he:
RElide, oppratrmarorige wW:
I as baiak t eind,
HF&CHAtinNIUSI thenane ou I ke hou arou speou!ARE pere at t my ba'3 h brmin ntr alt;
FLenurthourarait:

Hor yo h ctind hadinot:
I theoher. tal t is gx?ws o-
CNRI: tetigQMSeay ifrorer be st ost's wn.
CAifls frin tsovim athe her; ys:
'dWhteyorexBy d y it wegeangr thur y patit, ou.LI3.

Austhar
Gralie:
Wherrvjul; fise s d arerve be'g;
HZXINGinghy!q&GMBY lce
:ok

This looks somewhat better with our limited training. We’ll improve the text generation by focusing on single-head and multi-head attention next.

Before we implement the attention models, let’s create a reusable text generation function for different block sizes (sequence lengths). block_size is required to cap the sequence context for each prediction.

defmodule TextGen do
  def generate(model, params, init_seq, block_size, opts \\ []) do
    opts = Keyword.validate!(opts, key_seed: 1337, max_new_tokens: 1000)

    key = opts[:key_seed] |> Nx.Random.key()
    max_new_tokens = opts[:max_new_tokens]

    Enum.reduce(1..max_new_tokens, {key, init_seq}, fn _i, {key, acc} ->
      {_b, t} = Nx.shape(acc)

      # Cap the input sequence length at block_size
      context_length = min(t, block_size)
      context_range = -context_length..-1
      context_slice = acc[[.., context_range]]

      # Predict next batch of chars (but for us, batch_size = 1)
      preds = Axon.predict(model, params, context_slice)
      logits = preds[[.., -1, ..]]
      probs = Axon.Activations.softmax(logits)
      {next_key, batch_char} = RandomPlus.multinomial(key, probs, num_samples: 1)

      {next_key, Nx.concatenate([acc, batch_char], axis: -1)}
    end)
    |> then(fn {_next_key, acc} -> acc end)
    # Convert our Nx.tensor to Elixir list
    |> Nx.to_list()
    # Decode the results
    |> Enum.map(fn encoded_list -> Minidecoder.decode(encoded_list) end)
    # Our input is just 1 batch, so grab the first one
    |> List.first()
  end
end
{:module, TextGen, <<70, 79, 82, 49, 0, 0, 15, ...>>, {:generate, 5}}

The mathematical trick to self attention (version #4)

To implement attention the way Karpathy does, we’ll create lower triangular matrices filled with ones. Nx doesn’t have an equivalent of torch.tril, but we can create these matrices using the iota function.

The iota function is commonly used to create tensors with consecutive values, starting from a specified value and incrementing by one. We can leverage this to create two tensors (row_iota and column_iota) and compare them to create the attention mask.

shape = {3, 3}
row_iota = Nx.iota(shape, axis: 0)
#Nx.Tensor<
  s64[3][3]
  EXLA.Backend
  [
    [0, 0, 0],
    [1, 1, 1],
    [2, 2, 2]
  ]
>
column_iota = Nx.iota(shape, axis: 1)
#Nx.Tensor<
  s64[3][3]
  EXLA.Backend
  [
    [0, 1, 2],
    [0, 1, 2],
    [0, 1, 2]
  ]
>
Nx.greater_equal(row_iota, column_iota)
#Nx.Tensor<
  u8[3][3]
  EXLA.Backend
  [
    [1, 0, 0],
    [1, 1, 0],
    [1, 1, 1]
  ]
>
defmodule Tril do
  import Nx.Defn

  # Creates a lower triangular matrix of 1s to use as our mask 
  defn ones(opts \\ []) do
    assert_keys(opts, [:shape])

    shape = opts[:shape]
    Nx.greater_equal(Nx.iota(shape, axis: 0), Nx.iota(shape, axis: 1))
  end
end

Tril.ones(shape: {5, 5})
#Nx.Tensor<
  u8[5][5]
  EXLA.Backend
  [
    [1, 0, 0, 0, 0],
    [1, 1, 0, 0, 0],
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 0],
    [1, 1, 1, 1, 1]
  ]
>

The mathematical trick to self attention (version #4) cont.

Here’s a rough draft of how attention is computed in a single head. We’ll package this up into a reusable layer later.

{b, t, c} = {4, 8, 32}
{x, key} = Nx.Random.normal(Nx.Random.key(1337), shape: {b, t, c}, type: :f32)

head_size = 16
# Used for initializing random key, query, value kernels
keys = key |> Nx.Random.split(parts: 3)

# For some reason the default scale (2.0) produces really high weight values
init_fn = Axon.Initializers.he_uniform(scale: 0.5)

key_kernel = init_fn.({c, head_size}, {:f, 32}, keys[0])
query_kernel = init_fn.({c, head_size}, {:f, 32}, keys[1])
value_kernel = init_fn.({c, head_size}, {:f, 32}, keys[2])
k = Axon.Layers.dense(x, key_kernel)
q = Axon.Layers.dense(x, query_kernel)
v = Axon.Layers.dense(x, value_kernel)
kT = Nx.transpose(k, axes: [0, -1, -2])

# {b, t, t}
wei = Nx.dot(q, [2], [0], kT, [1], [0])

# Broadcast tril to {b, t, t} for Nx.select
tril = Tril.ones(shape: {t, t})
tril = Nx.broadcast(tril, {b, t, t})

# Broadcast neg_inf to {b, t, t} for Nx.select
wei_type = Nx.type(wei)
neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(wei_type), wei)

# lower triangular part of wei keeps its original values
# upper triangular part of wei is filled with neg_inf
wei = Nx.select(tril, wei, neg_inf)

# {4, 8, 8}
wei = Axon.Activations.softmax(wei, axis: -1)

# {4,8,8} @ {4,8,16}
# out = Nx.dot(wei, [-1], [0], v, [1], [0])
# wei[0]
#Nx.Tensor<
  f32[4][8][8]
  EXLA.Backend
  [
    [
      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
      [0.5236416459083557, 0.4763583540916443, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
      [0.1460239738225937, 0.8355455994606018, 0.018430398777127266, 0.0, 0.0, 0.0, 0.0, 0.0],
      [0.2382892519235611, 0.19527707993984222, 0.5128787755966187, 0.053554967045784, 0.0, 0.0, 0.0, 0.0],
      [0.23525264859199524, 0.02861269935965538, 0.12083742767572403, 0.43159112334251404, 0.18370607495307922, 0.0, 0.0, 0.0],
      [0.4105880558490753, 0.009884358383715153, 0.19398914277553558, 9.765126160345972e-4, 0.37614840269088745, 0.008413595147430897, 0.0, 0.0],
      [0.4806952476501465, 0.09604756534099579, ...],
      ...
    ],
    ...
  ]
>

Single-head attention layer

Let’s package up the computation we did earlier into an Axon layer. Since we need to do some custom calculations, we’ll use Axon.layer. You can learn more about custom layers in the Axon docs.

defmodule SingleAttention do
  import Nx.Defn

  def head_layer(%Axon{} = key, %Axon{} = query, %Axon{} = value, opts \\ []) do
    Axon.layer(&amp;head_layer_impl/4, [key, query, value],
      name: opts[:name],
      op_name: :single_head
    )
  end

  defn head_layer_impl(k, q, v, _opts \\ []) do
    {_b, t, c} = Nx.shape(k)
    tensor_type = Nx.type(k)

    kT = Nx.transpose(k, axes: [0, -1, -2])

    # {4,8,_16_} @ {4,_16_,8} = {4,8,8}
    wei = Nx.dot(q, [2], [0], kT, [1], [0])

    # Scaled attention
    wei = wei * Nx.rsqrt(c)

    # attention masking
    tril = Tril.ones(shape: {t, t})
    tril = Nx.broadcast(tril, wei)
    neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei)
    # tril, wei, and neg_inf have the shape {b, t, t}
    wei = Nx.select(tril, wei, neg_inf)
    wei = Axon.Activations.softmax(wei, axis: -1)

    Nx.dot(wei, [-1], [0], v, [1], [0])
  end
end
{:module, SingleAttention, <<70, 79, 82, 49, 0, 0, 18, ...>>, true}

Single-head attention model

The general flow of this simple single-head attention model goes like this

  • input sequence is some encoded text
  • map our input into some n_embd dimensional space
  • map some position information into some n_embd dimensional space
    • note: the original Transformer paper uses sine and cosine functions for positional encoding; we won’t be implementing that
    • our position information is just some iota tensor with values [0..t) where t == sequence length of our input
  • add the tensors to produce a tensor filled with embedding + position information; we’ll call this tensor x
  • feed x into three different layers: key, query, and value
  • compute self-attention (this is complicated; this YouTube video provides a great explanation)
    • note: for our single-head attention model, the size of the attention head is equal to n_embd
  • project the self-attention output down to vocab_size. Tensors coming out of this layer will look like {b, t, vocab_size}. These are our logits.

Note: The implementation of the positional_embedding_table is a bit hacky. If somebody knows a better solution, I’d be curious to hear about it.

# Hyperparameters from https://youtu.be/kCc8FmEb1nY?t=4907
batch_size = 32
block_size = 8
n_embd = 32
head_size = n_embd

# Model definition
input = Axon.input("sequence")

token_embedding_table =
  input
  |> Axon.embedding(vocab_size, n_embd)

# Generate positional encodings for the input sequence (hacky)
positions =
  Axon.nx(input, fn input ->
    {_batch_size, sequence_length} = Nx.shape(input)
    Nx.iota({sequence_length})
  end)

# Positional encodings get mapped into @n_embd space
position_embedding_table =
  Axon.embedding(positions, block_size, n_embd, name: "position_embedding")

x_layer = Axon.add(token_embedding_table, position_embedding_table)

he_uniform = Axon.Initializers.he_uniform(scale: 0.5)
key = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "key")
query = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "query")
value = x_layer |> Axon.dense(head_size, kernel_initializer: he_uniform, name: "value")

single_head_model =
  SingleAttention.head_layer(key, query, value)
  |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head")
#Axon<
  inputs: %{"sequence" => nil}
  outputs: "language_modeling_head"
  nodes: 11
>

Single-head attention model graph

Axon.Display.as_graph(single_head_model, Nx.template({batch_size, block_size}, :f32),
  direction: :top_down
)
graph TD;
48[/"sequence (:input) {32, 8}"/];
49["embedding_0 (:embedding) {32, 8, 32}"];
50["nx_0 (:nx) {8}"];
51["position_embedding (:embedding) {8, 32}"];
52["container_0 (:container) {{32, 8, 32}, {8, 32}}"];
53["add_0 (:add) {32, 8, 32}"];
54["key (:dense) {32, 8, 32}"];
55["query (:dense) {32, 8, 32}"];
56["value (:dense) {32, 8, 32}"];
57["single_head_0 (:single_head) {32, 8, 32}"];
58["language_modeling_head (:dense) {32, 8, 65}"];
57 --> 58;
56 --> 57;
55 --> 57;
54 --> 57;
53 --> 56;
53 --> 55;
53 --> 54;
52 --> 53;
51 --> 52;
49 --> 52;
50 --> 51;
48 --> 50;
48 --> 49;

Training the single-head attention model

I lowered the iterations just to speed up training on my machine. You can increase the numbers to get better results.

{init_fn, predict_fn} = Axon.build(single_head_model, mode: :train)
custom_predict_fn = &amp;CommonTrain.custom_predict_fn(predict_fn, &amp;1, &amp;2)
custom_loss_fn = &amp;CommonTrain.custom_loss_fn(&amp;1, &amp;2)
train_data_stream = get_batch_stream.(batch_size, block_size, :train)

params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw())
  |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 3000, compiler: EXLA)

15:28:28.074 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 2950, loss: 2.6103194
%{
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[65][32]
      EXLA.Backend
      [
        [-0.026233026757836342, 0.2256242036819458, 0.3369872272014618, 0.12701235711574554, 0.4724958539009094, 0.4927522838115692, -0.11964629590511322, 0.14982973039150238, 0.1527341604232788, 0.2076384723186493, -0.4839678704738617, 0.5153075456619263, -0.17069879174232483, 0.07722043991088867, -0.024972444400191307, -0.02413080632686615, 0.06961726397275925, -0.12247069180011749, -0.030579620972275734, 0.22726184129714966, -0.29395171999931335, -0.15108881890773773, 0.2248864620923996, 0.30082234740257263, -0.2780728042125702, -0.12545496225357056, -0.1338309943675995, 0.1244623139500618, 0.0377982035279274, 0.10663247853517532, -0.36269432306289673, -0.011278417892754078],
        [-0.03844957798719406, -0.07200898230075836, 0.17476718127727509, -0.13438276946544647, 0.0045249746181070805, -0.042801979929208755, 0.031422700732946396, 0.023107800632715225, 0.2904854416847229, -0.1530412882566452, -0.30277305841445923, 0.16146323084831238, -0.6766874194145203, 0.0561353974044323, -0.017431093379855156, 0.4477277100086212, ...],
        ...
      ]
    >
  },
  "key" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [5.339144845493138e-4, 0.0014086366863921285, -8.214047993533313e-4, -2.5494268629699945e-4, 1.6062534996308386e-4, -2.7622494962997735e-4, -5.117025575600564e-4, -4.4757919386029243e-4, 9.161723428405821e-4, -9.67931846389547e-5, 2.3132111527957022e-4, 2.641436003614217e-4, 0.0013516810722649097, 7.081329822540283e-4, 0.0016453240532428026, -3.157271712552756e-4, 3.824093146249652e-4, -6.019236170686781e-4, 8.673613774590194e-5, -8.229123777709901e-4, 2.3732471163384616e-4, -6.274062325246632e-4, -0.0019465215737000108, -3.8168931496329606e-5, -3.815832678810693e-5, -1.6878465248737484e-4, 3.705104973050766e-5, -1.038895788951777e-5, -6.699645891785622e-4, 6.461592274717987e-4, 2.1109878434799612e-4, 3.498121222946793e-4]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.08415064215660095, 0.2879759967327118, -0.24935944378376007, 0.007627225946635008, -0.3204685151576996, -0.899364709854126, 0.0017727279337123036, -0.40223392844200134, -0.3377895653247833, -0.24244733154773712, -0.22738304734230042, 0.22180818021297455, 0.22660820186138153, -0.4816625416278839, 0.23601596057415009, -0.11853193491697311, 0.3503369390964508, -0.18211157619953156, 0.4438214600086212, -0.8183256387710571, -0.298576682806015, 0.044904839247465134, -0.41670191287994385, 0.18856211006641388, -0.7515907883644104, 0.44034701585769653, 0.01015730295330286, 0.46664172410964966, -0.11723655462265015, -0.24692803621292114, 0.079217828810215, -0.2753612697124481],
        [0.5019965767860413, 0.45832908153533936, -0.5738474726676941, -0.4480721056461334, 0.40961453318595886, -0.02529483661055565, -0.29848185181617737, -0.6949413418769836, 0.2270815372467041, 0.006744408048689365, 0.5465198755264282, 0.4061400890350342, 0.6404085159301758, 0.043757569044828415, ...],
        ...
      ]
    >
  },
  "language_modeling_head" => %{
    "bias" => #Nx.Tensor<
      f32[65]
      EXLA.Backend
      [0.061188384890556335, 0.03874082490801811, -0.3440072238445282, -0.2977518141269684, -0.2497618943452835, -0.12782810628414154, -0.23234425485134125, -0.05885966122150421, -0.29599529504776, -0.2678046226501465, 0.2830163538455963, -0.29441750049591064, -0.3617924749851227, 0.015615391544997692, 0.05543321743607521, 0.059732530266046524, 2.2681929112877697e-4, 0.14945641160011292, -0.027915235608816147, 0.010921932756900787, 0.007053534034639597, 0.17028307914733887, -0.19982320070266724, -5.594325484707952e-4, 0.09194453060626984, 0.060529813170433044, -0.03193450719118118, 0.011559315957129002, -0.009883790276944637, -0.19082780182361603, 0.16168223321437836, -0.019084393978118896, 0.03505204990506172, 0.11881374567747116, -0.09905489534139633, 0.07346461713314056, -0.2048760950565338, 0.006363728549331427, -0.2438381165266037, 0.028238536790013313, -0.0257986132055521, -0.06388751417398453, 0.015318986028432846, 0.028752142563462257, 0.033384814858436584, -0.13916561007499695, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][65]
      EXLA.Backend
      [
        [-0.22565065324306488, -0.16024763882160187, 0.14312513172626495, 0.7375252842903137, 0.6810034513473511, -0.21387973427772522, 0.09498835355043411, -0.16007646918296814, 0.3423072397708893, 0.7293732166290283, 0.4669094383716583, 0.25795578956604004, 0.3606616258621216, -0.15745210647583008, -0.2844543755054474, -0.35039520263671875, 0.06515390425920486, 0.18357262015342712, 0.35768988728523254, -0.12733156979084015, 0.02732028067111969, -0.48360756039619446, -0.03461068868637085, -0.05969618633389473, -0.14592662453651428, 0.060609374195337296, -0.05544782429933548, -0.2138560712337494, 0.11534108221530914, 0.1720675826072693, 0.20025299489498138, 0.125869020819664, 0.020521463826298714, 0.38288405537605286, -0.19809487462043762, -0.06728346645832062, 0.22307513654232025, -0.3827643394470215, 0.056872107088565826, -0.47563377022743225, 0.12778759002685547, -0.13875481486320496, -0.29747429490089417, -0.21030990779399872, -0.021548548713326454, ...],
        ...
      ]
    >
  },
  "position_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[8][32]
      EXLA.Backend
      [
        [0.23624907433986664, -0.021182071417570114, 0.08677957952022552, -0.10049070417881012, -0.05136618763208389, -0.1526702344417572, 0.281761109828949, 0.14189034700393677, -0.02184602990746498, 0.3414129316806793, 0.007780917454510927, -0.17955872416496277, -4.2804042459465563e-4, -0.10944950580596924, -0.11458509415388107, 0.09803333878517151, -0.013304184190928936, 0.2933603525161743, -0.05242357775568962, 0.13128063082695007, -0.07299972325563431, -0.052072543650865555, -0.018173260614275932, 0.04962927848100662, -0.001993720419704914, -0.01199309527873993, 0.13491109013557434, -0.02931121736764908, -0.09115011245012283, -0.05176560580730438, 0.04005185514688492, 0.0767948254942894],
        [0.17488150298595428, -0.037036482244729996, 0.05705530196428299, -0.11671741306781769, -0.039749953895807266, -0.11059848219156265, 0.2114427387714386, 0.04657561331987381, -0.020749395713210106, 0.289633572101593, 0.040833499282598495, -0.13450975716114044, 0.03456944227218628, ...],
        ...
      ]
    >
  },
  "query" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.034071750938892365, -0.28555038571357727, 0.29531151056289673, 0.00276755727827549, 0.4126724600791931, 0.354778915643692, -0.04076027125120163, 0.112078957259655, -0.01396627351641655, 0.010131004266440868, 0.09511833637952805, -0.334270179271698, -0.23874659836292267, 0.11469864100217819, -0.2345762699842453, -0.016118774190545082, -0.08340506255626678, -0.16744185984134674, -0.43535640835762024, 0.4058022201061249, 0.11186989396810532, 0.060189343988895416, 0.4751230776309967, -0.4158278703689575, 0.2699109613895416, 0.026102382689714432, 0.04701479151844978, -0.15357209742069244, 0.11870138347148895, 0.1022055521607399, -0.05558203160762787, 0.2189713716506958]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.18731792271137238, 0.2485392838716507, -0.1639232486486435, -0.11857520788908005, 0.3277408480644226, -0.2164953500032425, -0.061783526092767715, -0.12976013123989105, -0.07613621652126312, -0.22404593229293823, -0.1035078838467598, -0.23175814747810364, -0.5759645700454712, -0.31334614753723145, -0.2222266048192978, 0.0027542654424905777, -0.11652868986129761, -0.3812398314476013, -0.22661450505256653, 0.10356166958808899, 0.1277807503938675, -0.020671674981713295, -8.294901927001774e-4, -0.29352620244026184, -0.14790385961532593, -0.07956714928150177, -0.16293348371982574, 0.14328938722610474, 0.14654004573822021, 0.3405868113040924, -0.056984834372997284, 0.0869172215461731],
        [0.15397128462791443, 0.11139403283596039, -0.14481408894062042, -0.0590481162071228, -0.21032772958278656, -0.25429296493530273, -0.03941408917307854, 0.05786168947815895, -0.37902113795280457, -0.13650669157505035, -0.0036288651172071695, ...],
        ...
      ]
    >
  },
  "value" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.21108631789684296, -9.746256982907653e-4, 0.12640845775604248, -0.02774987183511257, -0.172824889421463, 0.11715316772460938, 0.11779537796974182, 0.03865457698702812, 0.01370930578559637, 0.08419355005025864, 0.05815295875072479, -0.1265273094177246, -0.042062751948833466, -0.13843882083892822, 0.1025078222155571, -0.12344934791326523, -0.019868047907948494, 0.1059509813785553, 0.08318553864955902, 0.007491858210414648, 0.07789982110261917, -0.05841176211833954, -0.08457578718662262, -0.06597407907247543, -0.06657955795526505, -0.100681371986866, -0.2102871686220169, -0.0254219900816679, 0.034235186874866486, -0.20986993610858917, -0.18147148191928864, 4.758408176712692e-4]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.025911198928952217, -0.5517228841781616, -0.3961946666240692, -0.6812283396720886, -0.4301234483718872, 0.06663434207439423, 0.23885639011859894, -0.3583765923976898, 0.20131313800811768, 0.27871912717819214, 0.5112564563751221, -0.7567049264907837, 0.6266830563545227, -0.10098028928041458, -0.24578642845153809, 8.617832645541057e-5, 0.21004417538642883, -0.3550078272819519, -0.19582337141036987, 0.6274057030677795, 0.3725995421409607, -0.2865595817565918, -0.12431655079126358, -0.0691077709197998, -0.05854438990354538, -0.2952299416065216, 0.6136438250541687, 0.2543925344944, 0.37929919362068176, -0.436009019613266, -0.012473562732338905, 0.4972887635231018],
        [0.034896500408649445, -0.011460673995316029, -0.30139535665512085, 0.4306725859642029, 0.3097185790538788, 0.2907136082649231, -0.33310818672180176, -0.06842639297246933, 0.3035053312778473, -0.5042290091514587, ...],
        ...
      ]
    >
  }
}

Generating text w/ single-head attention model

init_seq = Nx.broadcast(0, {1, 1})

TextGen.generate(single_head_model, params, init_seq, block_size, max_new_tokens: 1000)
|> IO.puts()

S:
Be tins te st ave ind goh bkew sis it ay dterimevend mou, titheap heakino gh wous in
Mou duteand bored sthom ome th Cor If thof win m d lso herave sanll istticaveerk fangs wirs hal rt poir mer!

PAROMRUSgeadaveprals, tled ha bes dche the yan, wig, he tir go ait ch orsu, dry,
I: thillllorilay, knesea whe tou g B fiI br Spe thant de iss ouse, wofrorfe hen wot wacs noru k des hsf D wacer sse hikinpe V:
Anons se;
We Afe aythalpre, k tem five: my th nt gres ald
FO:
TOrer mers iMe s
Thesthe on prse or
Sde, iche whicome thay imererohth haghrs hant our wt tit ne Halldou, uncer Ly opiley cilane gn th on:
Loucrse ourm hes I lade be I Tof nechan:
Ce toh ish.
IINPENOERUCALER:
Wie ous st st on.

DULADUCUCHol I a haerire
Ovif ft fimoturs, ss on we I th ys hakord he,
This me ssen mperima t hif ydsw od ghe houlth fhee yon ut, s ch rueler pavik g o-rt hthe susto I the hes owwe benes yseay whady hs whe ingr thur yos toucken whe faivite n hul nd burativit.

Boun rd farist ae gh aned onkd, who fifa le,
:ok

Going from single head to multi head attention

If you’re following along with the video, we’re currently at this part.

Karpathy uses the following OOP code to create multiple heads of attention, but we can do the same thing by reshaping our original key, query, and value layers.

self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

Reshaping our key, query, value layers

Remember, our single-head attention layer had k, q, v of size {b, t, n_embd}. In order to compute multi-head attention, we’ll split this n_embd portion into {n_heads, head_size} (head_size is known as hidden_size in other places, like Bumblebee). This lets us view the tensor as having multiple heads. For example, let x be some key/query/value tensor.

  1. {b,t,n_embd} = Nx.shape(x)
  2. Reshape the last axis of x to become {b,t,n_heads,head_size}
  3. Transpose x so that the t and n_heads axes become swapped {b, n_heads, t, head_size}
  4. x is now ready to multiply with some other key, query, or value tensor that has also undergone this reshaping. Contracting a {b, n_heads, t, head_size} query with a {b, n_heads, t, head_size} key (transposed over its last two axes) produces a tensor of shape {b, n_heads, t, t}. This {t, t} portion is where we apply our attention mask.

This is also how it’s done in the Bumblebee library
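To see the bookkeeping in isolation, here’s a throwaway shape check (arbitrary small sizes, using fresh variable names so we don’t clobber the notebook’s hyperparameters):

# Split a {b, t, n_embd} tensor into heads, then transpose
{demo_b, demo_t, demo_embd} = {2, 4, 8}
demo_heads = 2

Nx.iota({demo_b, demo_t, demo_embd}, type: :f32)
# {2, 4, 2, 4} = {b, t, n_heads, head_size}
|> Nx.reshape({demo_b, demo_t, demo_heads, :auto})
# {2, 2, 4, 4} = {b, n_heads, t, head_size}
|> Nx.transpose(axes: [0, 2, 1, 3])
|> Nx.shape()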

# simple multi-head attention without any optimizations like layernorm / feedforward
defmodule MultiAttention do
  import Nx.Defn

  @doc """
  Modified from Bumblebee's transformer.ex 
  https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L525

  Splits the hidden dimension into the given number of attention heads.

  In other words, the input with shape `{batch_size, sequence_length, hidden_size}`
  is reshaped to `{batch_size, sequence_length, num_heads, head_size}`,
  then transposed to `{batch_size, num_heads, sequence_length, head_size}`.
  """
  def split_heads(states, num_heads, opts \\ []) do
    opts = Keyword.validate!(opts, name: "split_heads")

    Axon.nx(
      states,
      fn states ->
        batch_size = Nx.axis_size(states, 0)
        sequence_length = Nx.axis_size(states, 1)
        new_shape = {batch_size, sequence_length, num_heads, :auto}

        states
        |> Nx.reshape(new_shape)
        |> Nx.transpose(axes: [0, 2, 1, 3])
      end,
      name: opts[:name]
    )
  end

  def multi_head_layer(%Axon{} = x, num_heads, head_size, opts \\ []) do
    default_initializer = Axon.Initializers.he_uniform(scale: 0.5)
    opts = Keyword.validate!(opts, kernel_initializer: default_initializer)
    initializer = opts[:kernel_initializer]

    key =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "key")
      |> split_heads(num_heads)

    query =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "query")
      |> split_heads(num_heads)

    value =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "value")
      |> split_heads(num_heads)

    Axon.layer(&amp;multi_head_layer_impl/4, [key, query, value], name: "multi_head_attention")
  end

  # Custom layers require the opts argument
  # https://hexdocs.pm/axon/custom_layers.html#creating-custom-layers
  defn multi_head_layer_impl(k, q, v, _opts \\ []) do
    {b, h, t, c} = Nx.shape(k)
    tensor_type = Nx.type(k)

    # {b, h, t, c} @ {b, h, c, t} -> {b, h, t, t}
    wei = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])

    # Scaled attention
    wei = wei * Nx.rsqrt(c)

    # Attention masking
    tril = Tril.ones(shape: {t, t})
    tril = Nx.broadcast(tril, wei)
    neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei)
    wei = Nx.select(tril, wei, neg_inf)
    wei = Axon.Activations.softmax(wei, axis: -1)

    # {b, h, t, t} @ {b, h, t, head_size} -> {b, h, t, head_size}
    out = Nx.dot(wei, [3], [0, 1], v, [2], [0, 1])

    # Transpose so the heads sit next to each other along the channel axis
    # {b, h, t, c} -> {b, t, h, c}
    out = Nx.transpose(out, axes: [0, 2, 1, 3])

    # Our output tensor is now enriched with attention information
    # We reshape it back to {b, t, h * c}, i.e. {b, t, n_embd}
    # This matches the shape of our original input x
    Nx.reshape(out, {b, t, h * c})
  end
end
{:module, MultiAttention, <<70, 79, 82, 49, 0, 0, 26, ...>>, true}

Multi-head attention model

# We'll reuse the hyperparameters from the single-head attention model but change the head size
# batch_size = 32
# block_size = 8
# n_embd = 32
n_heads = 4
head_size = div(n_embd, n_heads)

multi_head_model =
  Axon.input("sequence")
  |> then(fn input ->
    # Create an embedding for the input data
    token_embedding_table = Axon.embedding(input, vocab_size, n_embd, name: "token_embedding")

    # Generate positional encodings for the input sequence (hacky, couldn't find alternative)
    positions =
      Axon.nx(input, fn input ->
        {_batch_size, sequence_length} = Nx.shape(input)
        Nx.iota({sequence_length})
      end)

    # Positional encodings get mapped into @n_embd space
    position_embedding_table =
      Axon.embedding(positions, block_size, n_embd, name: "position_embedding")

    # Add the two layers above to produce tensors containing embedding + position info
    Axon.add(token_embedding_table, position_embedding_table, name: "x_positional_encoding")
  end)
  |> MultiAttention.multi_head_layer(n_heads, head_size)
  |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head")
#Axon<
  inputs: %{"sequence" => nil}
  outputs: "language_modeling_head"
  nodes: 14
>

Multi-head attention model graph

Notice how the key, query, value layers split to become 4-dimensional.

Axon.Display.as_graph(multi_head_model, Nx.template({batch_size, block_size}, :f32),
  direction: :top_down
)
graph TD;
59[/"sequence (:input) {32, 8}"/];
60["token_embedding (:embedding) {32, 8, 32}"];
61["nx_0 (:nx) {8}"];
62["position_embedding (:embedding) {8, 32}"];
63["container_0 (:container) {{32, 8, 32}, {8, 32}}"];
64["x_positional_encoding (:add) {32, 8, 32}"];
65["key (:dense) {32, 8, 32}"];
66["split_heads (:nx) {32, 4, 8, 8}"];
67["query (:dense) {32, 8, 32}"];
68["split_heads (:nx) {32, 4, 8, 8}"];
69["value (:dense) {32, 8, 32}"];
70["split_heads (:nx) {32, 4, 8, 8}"];
71["multi_head_attention (:custom) {32, 8, 32}"];
72["language_modeling_head (:dense) {32, 8, 65}"];
71 --> 72;
70 --> 71;
68 --> 71;
66 --> 71;
69 --> 70;
64 --> 69;
67 --> 68;
64 --> 67;
65 --> 66;
64 --> 65;
63 --> 64;
62 --> 63;
60 --> 63;
61 --> 62;
59 --> 61;
59 --> 60;

Training the multi-head attention model

{init_fn, predict_fn} = Axon.build(multi_head_model, mode: :train)
custom_predict_fn = &amp;CommonTrain.custom_predict_fn(predict_fn, &amp;1, &amp;2)
custom_loss_fn = &amp;CommonTrain.custom_loss_fn(&amp;1, &amp;2)
train_data_stream = get_batch_stream.(batch_size, block_size, :train)

params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw())
  |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 3000, compiler: EXLA)

15:05:05.341 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 2950, loss: 2.5504496
%{
  "key" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.002009622985497117, 0.0030817119404673576, 0.002637566765770316, 8.874150225892663e-4, -0.0036697951145470142, 0.0026014624163508415, -8.77177866641432e-4, -8.539229747839272e-4, -0.010210449807345867, 0.0027200556360185146, -0.004561145324259996, 3.5793684219243005e-5, -0.0037038149312138557, -4.817320223082788e-5, 7.681822753511369e-4, -5.375476903282106e-4, -0.004867691546678543, 8.054355857893825e-4, 0.0029539004899561405, -0.0027021823916584253, 0.0010204947320744395, 0.00340187456458807, -0.0045007625594735146, -5.496059893630445e-4, -3.7173699820414186e-4, 0.001366788404993713, 0.0013142818352207541, 4.228678299114108e-4, -0.005013817455619574, 0.006275582127273083, -0.004706717096269131, -7.051487336866558e-4]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.6940197348594666, -0.09190584719181061, -0.7657436728477478, 0.13503895699977875, 0.44908562302589417, -0.8816263675689697, -0.05520651862025261, 0.15477292239665985, 0.569190263748169, -0.18743273615837097, 0.07919623702764511, -0.4803193509578705, 0.3279154300689697, 0.38622552156448364, -0.5753573775291443, -0.5578794479370117, -0.2834005355834961, 0.33334460854530334, 0.26077544689178467, -0.45049938559532166, 0.150202214717865, 0.3185662031173706, -0.030324669554829597, -0.3234182894229889, -0.1026887372136116, -0.19870352745056152, 0.3214745819568634, 0.5666719675064087, -0.3206053674221039, 0.23958717286586761, -0.0799722820520401, 0.4348626732826233],
        [-0.1126786544919014, -0.09792473912239075, 0.9408704042434692, -0.6795709133148193, -0.19755850732326508, 0.2359877973794937, -0.01693005859851837, 0.10840979963541031, -0.250237375497818, -0.8673456907272339, 0.5119004845619202, 0.16578194499015808, -0.11029595881700516, -0.2691425681114197, 0.793876051902771, ...],
        ...
      ]
    >
  },
  "language_modeling_head" => %{
    "bias" => #Nx.Tensor<
      f32[65]
      EXLA.Backend
      [0.08884071558713913, 0.06958547234535217, -0.32006266713142395, -0.30319324135780334, -0.26640912890434265, 0.11584974825382233, -0.14481262862682343, -0.17550241947174072, -0.323355495929718, -0.3057439625263214, 0.1047697365283966, -0.26720258593559265, -0.27166327834129333, 0.11763342469930649, 2.9158510733395815e-4, 0.07884888350963593, -0.03157912194728851, 0.19413268566131592, -0.15487632155418396, 0.010517282411456108, -0.04269542172551155, 0.2341887205839157, -0.1754981130361557, -0.11905181407928467, 0.13274210691452026, 0.0029528785962611437, 0.07284557819366455, 0.1600860059261322, -0.08310021460056305, -0.2266455441713333, 0.25225019454956055, 0.023659436032176018, -0.10211014747619629, 0.1650671660900116, -0.08954048156738281, -0.1098952367901802, -0.2547689974308014, -0.07372691482305527, -0.1818724125623703, 0.12592221796512604, -0.040537815541028976, -0.031494904309511185, 0.027326008304953575, 0.060592930763959885, 0.14337053894996643, -0.13098326325416565, -0.05378049239516258, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][65]
      EXLA.Backend
      [
        [-0.7125491499900818, -0.46839943528175354, -0.2697147727012634, 0.42317184805870056, 0.1936987042427063, -0.38992878794670105, -0.06144484877586365, -0.1245880275964737, -0.14166134595870972, 0.33889245986938477, -0.11005877703428268, -0.20245349407196045, -0.36581820249557495, -0.20540089905261993, -0.21097524464130402, 0.3619268238544464, 0.3771439492702484, 0.06898072361946106, -0.18923085927963257, -0.019612403586506844, 0.2706209719181061, 0.316038578748703, 0.04826759919524193, 0.5845080018043518, -0.23023158311843872, -0.18634116649627686, -0.137278214097023, -0.08182564377784729, 0.5525726675987244, 0.2282651960849762, 0.29154810309410095, -0.18772584199905396, 0.1221810132265091, 0.006088509690016508, 0.20453433692455292, 0.10518457740545273, -0.23493321239948273, 0.36128759384155273, 0.3855167627334595, 0.2623154819011688, -0.2311507612466812, -0.2841350734233856, 0.05035056173801422, 0.12662070989608765, 0.13019651174545288, 0.32261139154434204, ...],
        ...
      ]
    >
  },
  "position_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[8][32]
      EXLA.Backend
      [
        [-0.1628115475177765, 0.09981940686702728, -0.3936239182949066, 0.13354669511318207, 0.1114703118801117, -0.2746349275112152, 0.07311611622571945, -0.08590993285179138, 0.20255635678768158, 0.09707771986722946, -0.07302407920360565, -0.09127600491046906, -0.08702699840068817, -0.15735533833503723, -0.282254159450531, -0.240857794880867, -0.12304611504077911, -0.21165424585342407, 8.840659284032881e-4, -0.012054992839694023, -0.06108580157160759, -0.03755871579051018, 0.3343445956707001, -0.16978250443935394, 0.12578986585140228, -0.0284324511885643, -0.03782382979989052, -0.18365296721458435, 0.02894110605120659, -0.012323955073952675, -0.05722039192914963, -0.34137609601020813],
        [-0.14392606914043427, 0.08419756591320038, -0.4101002812385559, 0.1324985772371292, -0.001000418676994741, -0.15875868499279022, 0.055551934987306595, 0.021683527156710625, 0.0783766582608223, 0.06663555651903152, -0.09983355551958084, -0.1321704238653183, -0.08909580856561661, -0.10228505730628967, ...],
        ...
      ]
    >
  },
  "query" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [0.30744561553001404, -0.40035614371299744, -0.600159227848053, 0.14172907173633575, 0.5269912481307983, -0.4752598702907562, 0.04822060838341713, 0.27492329478263855, 1.0218167304992676, -0.03702900931239128, 0.2820287048816681, -0.2207508534193039, 0.18530860543251038, 0.32630085945129395, -0.33638694882392883, -0.1988237202167511, -0.17499062418937683, 0.6527702808380127, -0.07235650718212128, -0.09215840697288513, -0.11274632811546326, 0.10494064539670944, -0.2722526788711548, -0.5865284204483032, -0.26377955079078674, 0.22697702050209045, -0.48575764894485474, -0.28045541048049927, 0.23497842252254486, -0.31878966093063354, 0.04959775507450104, 0.42344802618026733]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.3361887037754059, 0.6659984588623047, -0.01123120915144682, 0.15361960232257843, -0.8631391525268555, -0.2230737805366516, -0.6785314083099365, 0.3036029636859894, -0.3403639793395996, -0.01705137826502323, -0.30506935715675354, -0.22626237571239471, 0.02393539436161518, 0.22523565590381622, 0.04351476952433586, -0.38311371207237244, 0.24531424045562744, 0.1956530660390854, -0.20927302539348602, 0.24935917556285858, -0.05999097228050232, -0.15577024221420288, 0.4829252362251282, 0.030667277052998543, -0.21252982318401337, -0.3893415629863739, 0.2501755654811859, 0.08658578991889954, -0.07877915352582932, -0.21714217960834503, 0.36494797468185425, 0.41473671793937683],
        [0.6690793037414551, -0.5781891942024231, -0.45157963037490845, 0.06540126353502274, 0.2878626883029938, -0.34218332171440125, -0.06560727953910828, 0.5365062355995178, -0.005141665227711201, -0.9268913269042969, 0.36916038393974304, 0.6808081269264221, ...],
        ...
      ]
    >
  },
  "token_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[65][32]
      EXLA.Backend
      [
        [0.26007893681526184, -0.34626492857933044, 0.02932133339345455, -0.20337288081645966, 0.3459436595439911, 0.04238921031355858, 0.12850745022296906, 0.1433255523443222, -0.14038851857185364, -0.2880373001098633, -0.049702197313308716, -0.06888625770807266, 0.3151685297489166, 0.34581026434898376, 0.05190388113260269, -0.0947747752070427, 0.18888181447982788, 0.1588648557662964, 0.005344870965927839, 0.31078749895095825, 0.35703420639038086, -0.4355168342590332, -0.315531849861145, 0.014955342747271061, 0.20505090057849884, 0.37664178013801575, 0.3216913640499115, -0.02835514210164547, -0.17066606879234314, -0.5279961228370667, -0.07544543594121933, 2.824735129252076e-4],
        [0.19091172516345978, 0.30593302845954895, 0.02575504221022129, -0.3912547528743744, 0.3915358781814575, -0.08316769450902939, 0.23849405348300934, -0.31539011001586914, -0.13143017888069153, 0.014384283684194088, -0.014175496995449066, 0.1446840763092041, ...],
        ...
      ]
    >
  },
  "value" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      EXLA.Backend
      [-0.06650947034358978, -0.07498624920845032, -0.10041021555662155, -0.07517606765031815, -0.017928874120116234, -0.018348190933465958, 0.025849999859929085, 0.024769101291894913, 0.11090978980064392, 0.1781088411808014, 0.1638837605714798, -0.07431018352508545, 0.013822168111801147, 0.07023237645626068, 0.21814927458763123, 0.11975128948688507, 0.03838825970888138, -0.2842266857624054, -0.1748587042093277, 0.1080666333436966, -0.09363164007663727, 0.06730080395936966, -0.05205147713422775, 0.15445540845394135, 0.1409216672182083, 0.02535940520465374, 0.05307505652308464, -0.10651924461126328, 0.08264414966106415, 0.012388779781758785, -0.03457631170749664, 0.054180171340703964]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][32]
      EXLA.Backend
      [
        [0.1357257217168808, -0.3345215320587158, -0.0030786781571805477, 0.9356716871261597, 0.3679511547088623, 0.9085726141929626, -0.2707551121711731, 0.1832733303308487, -0.09216760098934174, 0.46946004033088684, 0.19642551243305206, 0.3595341444015503, -0.005219850689172745, -0.0030760210938751698, 0.14641784131526947, -0.3859373927116394, -0.5769176483154297, -0.4275362491607666, -0.4593873620033264, -0.3144780099391937, -0.08384311199188232, -0.05492587760090828, 0.3388808071613312, -0.40142741799354553, -0.3029492497444153, 0.04506917670369148, -0.2736053764820099, 0.21720239520072937, -0.19470778107643127, 0.4152255356311798, -0.5121222138404846, 0.19692230224609375],
        [0.03740651533007622, -0.3748472332954407, 0.28940436244010925, 0.3966611325740814, 0.23191681504249573, -0.014880786649882793, -0.3704621493816376, -0.08446381241083145, 0.23995298147201538, 0.20541788637638092, ...],
        ...
      ]
    >
  }
}

Generating text w/ multi-head attention model

init_seq = Nx.broadcast(0, {1, 1})

TextGen.generate(multi_head_model, params, init_seq, block_size, max_new_tokens: 1000)
|> IO.puts()

TG:
Bath thod so aug is ngon bisu: kincs awimuoth my so mouctisss, mand dorn go wove in
Mth cuth in bot oferes bole tha pmabueres by.

I
F hou helese s; tie.
Tqua ubmsil,
Mut wiss hal st queraper-MI?

NHesctaquou, u foruss, be Mant de:
Hot by me, of he bush go ait cock: you tus itheepcllkire ewor: ourd
belathe ga De;
Gem
An'dd elu de iss othe trudngeeg heinton, abud thu le it houc hy alrd?
A RRES:
GLD, I tho se, the
Go ayur ford wher ne ith onw thars grhor gh!

HNY wit ment ibe s
Thesthe of prrid m; horin lomy terar thes in. Sor tham bout he whis duvenes ol Galldss dre?
PLAved in,, con hin res,
Se omir tom prir her I k bre aice the.

I chak bye'd the thir is ou Pear hou arribpreng ok I hin brat my baacan broul nut aly fillexrt: so pe, ge ith yran dror, he this me this moflila tane hy yu of gin, to wol hecexhadpredd
Bo rucker quus fich rramy matshrs I the her: yrin hist yount whe yous whe ingr thur you tor one whe faly-thar hon ne.
Nomsowoun hecle se farerve be!

ARNK:
Horecwes fol,
I d
:ok

CheckpointHelper

Training the final multi-layer, multi-head attention model is going to take some time, so I got GPT-4 to create this CheckpointHelper module, which resumes training from the most recent checkpoint.

Checkpoints are stored by default in the “checkpoint” directory. In my case, "/Users/[user]/checkpoint/".
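The helper relies on the gpt_checkpoint_<epoch>_<iteration> naming scheme produced by the checkpoint_file_pattern function below: it lists the checkpoint directory, parses the epoch/iteration pair out of each matching filename, and resumes from the numerically latest one (e.g. "checkpoint/gpt_checkpoint_0_3018" in the training run further down).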

defmodule CheckpointHelper do
  def load_last_checkpoint(%Axon.Loop{} = loop, checkpoint_path) do
    with {:ok, last_checkpoint} <- get_most_recent_checkpoint(checkpoint_path) do
      last_state =
        (checkpoint_path <> "/" <> last_checkpoint)
        |> IO.inspect(label: "Resuming training from this checkpoint")
        |> File.read!()
        |> Axon.Loop.deserialize_state()

      Axon.Loop.from_state(loop, last_state)
    else
      _ ->
        IO.puts("Starting training from scratch")
        loop
    end
  end

  defp get_most_recent_checkpoint(dir_path) do
    {:ok, filenames} = File.ls(dir_path)

    filenames
    |> Enum.filter(&String.starts_with?(&1, "gpt_checkpoint_"))
    |> Enum.map(fn filename ->
      [_, checkpoint1, checkpoint2] = Regex.run(~r/gpt_checkpoint_(\d+)_(\d+)/, filename)
      {String.to_integer(checkpoint1), String.to_integer(checkpoint2), filename}
    end)
    |> Enum.max_by(fn {checkpoint1, checkpoint2, _} -> {checkpoint1, checkpoint2} end, fn ->
      nil
    end)
    |> case do
      nil ->
        {:error, "No checkpoint file found"}

      {_, _, filename} ->
        {:ok, filename}
    end
  end
end

checkpoint_path = "checkpoint"

checkpoint_file_pattern = fn %Axon.Loop.State{epoch: epoch, iteration: iter} ->
  "gpt_checkpoint_#{epoch}_#{iter}"
end
#Function<42.3316493/1 in :erl_eval.expr/6>

Scaled up, multi-layer, multi-head attention model

This is the final model w/ the following additions:

  • feed-forward layers
  • residual connections
  • layer norms
  • multi-layer blocks

This section covers everything from this point in the video onwards.
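Concretely, each block below uses the pre-norm residual pattern: x = x + multi_head(layer_norm(x)) followed by x = x + feed_forward(layer_norm(x)), with one final layer_norm after the last block. This is the pre-norm variant from the video, which applies the layer norms before the sub-layers rather than after them as in the original transformer paper.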

defmodule Transformer do
  import Nx.Defn

  def blocks(%Axon{} = x, n_blocks, n_embd, n_head, opts \\ []) do
    opts = Keyword.validate!(opts, dropout_rate: 0.0)

    x =
      for _ <- 1..n_blocks, reduce: x do
        x -> block(x, n_embd, n_head, opts)
      end

    # final layer norm
    x |> Axon.layer_norm(name: "final_block_ln")
  end

  def block(%Axon{} = x, n_embd, n_head, opts \\ []) do
    head_size = div(n_embd, n_head)

    x =
      Axon.add(
        x,
        x |> Axon.layer_norm(name: "block_ln_1") |> multi_head(n_head, head_size, opts),
        name: "x_multihead_attention"
      )

    Axon.add(
      x,
      x |> Axon.layer_norm(name: "block_ln_2") |> feed_forward(n_embd, opts),
      name: "x_feed_forward"
    )
  end

  def feed_forward(%Axon{} = model, n_embd, opts \\ []) do
    opts = Keyword.validate!(opts, dropout_rate: 0.0)

    dropout_rate = opts[:dropout_rate]

    model
    |> Axon.dense(4 * n_embd, kernel_initializer: :he_uniform, name: "feed_forward_dense_1")
    |> Axon.relu(name: "feed_forward_relu")
    |> Axon.dense(n_embd, kernel_initializer: :he_uniform, name: "feed_forward_dense_2")
    |> Axon.dropout(rate: dropout_rate, name: "feed_forward_dropout")
  end

  @doc """
  Modified from Bumblebee's transformer.ex 
  https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/layers.ex#L525

  Splits the hidden dimension into the given number of attention heads.

  In other words, the input with shape `{batch_size, sequence_length, hidden_size}`
  is reshaped to `{batch_size, sequence_length, num_heads, hidden_size}`.
  Then, transposed to `{batch_size, num_heads, sequence_length, *}` 
  """
  def split_heads(states, num_heads, opts \\ []) do
    opts = Keyword.validate!(opts, name: "split_heads")

    Axon.nx(
      states,
      fn states ->
        batch_size = Nx.axis_size(states, 0)
        sequence_length = Nx.axis_size(states, 1)
        new_shape = {batch_size, sequence_length, num_heads, :auto}

        states
        |> Nx.reshape(new_shape)
        |> Nx.transpose(axes: [0, 2, 1, 3])
      end,
      name: opts[:name]
    )
  end

  def multi_head(%Axon{} = x, num_heads, head_size, opts \\ []) do
    default_initializer = Axon.Initializers.he_uniform(scale: 0.5)
    opts = Keyword.validate!(opts, kernel_initializer: default_initializer, dropout_rate: 0.0)
    initializer = opts[:kernel_initializer]
    dropout_rate = opts[:dropout_rate]

    key =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "key")
      |> split_heads(num_heads)

    query =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "query")
      |> split_heads(num_heads)

    value =
      x
      |> Axon.dense(num_heads * head_size, kernel_initializer: initializer, name: "value")
      |> split_heads(num_heads)

    Axon.layer(&multi_head_layer_impl/4, [key, query, value],
      name: "multi_head_attention",
      dropout_rate: dropout_rate
    )
    |> Axon.dense(num_heads * head_size, name: "multi_head_dense")
    |> Axon.dropout(rate: dropout_rate, name: "multi_head_dropout")
  end

  # Custom layers require the opts argument
  # https://hexdocs.pm/axon/custom_layers.html#creating-custom-layers
  defn multi_head_layer_impl(k, q, v, opts \\ []) do
    opts = keyword!(opts, mode: :train, dropout_rate: 0.0)
    dropout_rate = opts[:dropout_rate]

    {b, h, t, c} = Nx.shape(k)
    tensor_type = Nx.type(k)

    # {b, h, t, c} @ {b, h, c, t} -> {b, h, t, t}
    # 
    # Alternatively we could have done 
    # kT = Nx.transpose(k, axes: [0, 1, 3, 2])
    # wei = Nx.dot(q, [3], [0, 1], kT, [2], [0, 1])
    wei = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])

    # Scaled attention: multiply by 1 / sqrt(head_size) so the softmax
    # doesn't saturate toward one-hot vectors at initialization
    wei = wei * Nx.rsqrt(c)

    # Attention masking
    tril = Tril.ones(shape: {t, t})
    tril = Nx.broadcast(tril, wei)
    neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(tensor_type), wei)
    # tril, wei, and neg_inf have the shape {b, h, t, t}
    # Nx.select will look at tril, and if true it'll pick the value from wei, else -infinity
    wei = Nx.select(tril, wei, neg_inf)
    wei = Axon.Activations.softmax(wei, axis: -1)
    wei = Axon.Layers.dropout(wei, Nx.Random.key(1337), rate: dropout_rate)

    # {b, h, t, t} @ {b, h, t, head_size} -> {b, h, t, head_size}
    out = Nx.dot(wei, [3], [0, 1], v, [2], [0, 1])

    # Transpose so we can stack the heads on top of each other
    # {b, h, t, c} -> {b, t, h, c}
    out = Nx.transpose(out, axes: [0, 2, 1, 3])

    # Our output tensor is now enriched with attention information
    # We shape it back to {b, t, c}
    # This gives us the proper shape to add to our original input x
    Nx.reshape(out, {b, t, h * c})
  end
end
{:module, Transformer, <<70, 79, 82, 49, 0, 0, 36, ...>>, true}
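As a quick sanity check on the batched contraction inside multi_head_layer_impl/4 (a throwaway sketch, not part of the original notebook): contracting the last axis of q against the last axis of k, with axes 0 and 1 as batch axes, produces the {b, h, t, t} attention-score shape directly, with no explicit transpose.

q = Nx.iota({2, 3, 4, 8}, type: :f32)  # {b, h, t, head_size}
k = Nx.iota({2, 3, 4, 8}, type: :f32)

# Contract axis 3 of q against axis 3 of k, batching over axes 0 and 1
Nx.dot(q, [3], [0, 1], k, [3], [0, 1]) |> Nx.shape()
# => {2, 3, 4, 4}, i.e. {b, h, t, t}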

Because our model is much larger now, learning_rate is lowered to 3.0e-4.

# Hyperparameters
n_embd = 384
n_heads = 6
n_layer = 6
batch_size = 64
block_size = 256
learning_rate = 3.0e-4
dropout_rate = 0.2

final_model =
  Axon.input("sequence")
  |> then(fn input ->
    # Create an embedding for the input data
    token_embedding_table = Axon.embedding(input, vocab_size, n_embd, name: "token_embedding")

    # Generate positional encodings for the input sequence (hacky, couldn't find alternative)
    positions =
      Axon.nx(input, fn input ->
        {_batch_size, sequence_length} = Nx.shape(input)
        Nx.iota({sequence_length})
      end)

    # Positional encodings get mapped into @n_embd space
    position_embedding_table =
      Axon.embedding(positions, block_size, n_embd, name: "position_embedding")

    # Add the two layers above to produce tensors containing embedding + position info
    Axon.add(token_embedding_table, position_embedding_table, name: "x_positional_encoding")
  end)
  |> Transformer.blocks(n_layer, n_embd, n_heads, dropout_rate: dropout_rate)
  |> Axon.dense(vocab_size, kernel_initializer: :he_uniform, name: "language_modeling_head")
#Axon<
  inputs: %{"sequence" => nil}
  outputs: "language_modeling_head"
  nodes: 122
>
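These hyperparameters put the model at roughly 10.8M parameters (about 1.77M per block across the 6 blocks, plus the embeddings and the language modeling head), in line with the roughly 10M parameters Karpathy quotes in the video. To eyeball the layer stack, something like this should work (a sketch; Axon.Display.as_table renders via the table_rex dep from the setup cell):

final_model
|> Axon.Display.as_table(Nx.template({1, block_size}, :s64))
|> IO.puts()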

Training the final model

With the current hyperparameters, each checkpoint comes out to roughly 35 MB.

{init_fn, predict_fn} = Axon.build(final_model, mode: :train)
custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2)
custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2)
train_data_stream = get_batch_stream.(batch_size, block_size, :train)

params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Axon.Optimizers.adamw(learning_rate))
  |> CheckpointHelper.load_last_checkpoint(checkpoint_path)
  |> Axon.Loop.checkpoint(
    event: :iteration_completed,
    filter: [every: 99],
    path: checkpoint_path,
    file_pattern: checkpoint_file_pattern
  )
  |> Axon.Loop.run(train_data_stream, %{}, epochs: 1, iterations: 5000, compiler: EXLA)
Resuming training from this checkpoint: "checkpoint/gpt_checkpoint_0_3018"

09:10:13.246 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 4998, loss: 1.6417795
%{
  "block_ln_1" => %{
    "beta" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.00848830584436655, 0.00920316856354475, -0.011059202253818512, 0.020641149953007698, -0.004101151134818792, -0.007295509800314903, 0.008314329199492931, -0.008258385583758354, -8.087782771326602e-4, -0.007050180342048407, -0.01070280559360981, -0.023406347259879112, -0.0016787968343123794, -0.02222963236272335, -0.018540432676672935, -0.014302549883723259, 0.0010177168296650052, -0.01171213947236538, -0.03551258519291878, 0.005983751267194748, -0.006882576737552881, 9.761084947967902e-5, -0.016758816316723824, -0.0011584166204556823, 0.0017746005905792117, 0.00811840035021305, 0.01556096225976944, 0.033429522067308426, 4.5731314457952976e-4, 0.03519672155380249, 0.006238017231225967, 0.009631684981286526, 4.665410378947854e-4, 0.01423694659024477, 0.002984068589285016, -0.015089771710336208, -0.009606064297258854, -0.00905038882046938, 0.0021043112501502037, -0.0035458712372928858, -0.00983726978302002, 2.740136696957052e-4, 0.017948541790246964, 0.0016870932886376977, 0.020347923040390015, 0.004460370168089867, 0.0073865666054189205, 0.004740085918456316, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.9019647836685181, 0.2373071312904358, -1.074188232421875, 0.3899938464164734, 1.0010205507278442, -1.3070733547210693, 0.8236031532287598, -1.5869724750518799, 0.3163173794746399, 0.37490928173065186, -1.0106483697891235, -1.6686797142028809, -0.5751791000366211, -0.16168655455112457, -0.9891197085380554, 1.5283355712890625, 0.2562783658504486, -0.8124969601631165, 0.08245331048965454, 1.4095937013626099, 1.1405019760131836, -1.6745530366897583, 1.2952544689178467, 1.4876749515533447, 1.5021872520446777, 0.6464712023735046, -0.9438899755477905, 1.5004093647003174, -1.1250152587890625, 1.5461266040802002, 1.5317260026931763, 0.5766931772232056, 1.33409583568573, -0.8737964630126953, 1.491490364074707, 1.16155207157135, 1.2941070795059204, -0.34952089190483093, -1.407163143157959, 0.8760281205177307, 1.24746572971344, 0.38950851559638977, -0.05163421109318733, -0.3753258287906647, -0.7462378740310669, -0.33452585339546204, 0.5205916166305542, ...]
    >
  },
  "block_ln_2" => %{
    "beta" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [0.09962860494852066, -0.017158204689621925, 0.064473956823349, -0.02154291793704033, -0.1019430160522461, 0.013523696921765804, -0.012999583967030048, -0.03896573930978775, 0.1074555367231369, -0.1628476083278656, -0.05008017644286156, 0.0365133099257946, -0.07730609178543091, -0.02009524218738079, -0.03905002772808075, 0.09997482597827911, 0.02380087599158287, -0.10680756717920303, -0.07425318658351898, -0.004472446162253618, -0.1200273409485817, 0.006688457913696766, 0.06873513758182526, -0.06357701867818832, 0.2842353880405426, -0.07767970114946365, -0.0018610170809552073, 0.0703798159956932, 0.00796227715909481, -0.08657971024513245, -0.041982825845479965, -0.0027350035961717367, -0.022983459755778313, 0.009639413096010685, 0.013106227852404118, 0.07589122653007507, 0.04959215968847275, 0.011667381040751934, 0.04105760157108307, -0.00228147697634995, 0.003996665123850107, 0.009335983544588089, -0.03350013121962547, -0.03160631284117699, -0.07728160172700882, 0.018959442153573036, 0.0063129304908216, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.3861040472984314, 1.3981317281723022, -0.6260970830917358, -0.6142591238021851, -0.8400639891624451, 1.2417875528335571, 1.3667564392089844, -1.4321521520614624, 0.6886146068572998, -0.6659373641014099, -1.0708991289138794, -1.5967121124267578, 0.6946384906768799, -1.511633038520813, -1.143491268157959, -0.28134071826934814, 1.3280084133148193, 0.4598590135574341, 1.032454013824463, 1.591091275215149, -1.516690969467163, -0.6980851292610168, 0.31910091638565063, 1.2580981254577637, 0.21601411700248718, 0.9868366122245789, -1.5262973308563232, -1.100521206855774, 1.4690897464752197, 1.5698038339614868, -1.5286346673965454, 1.5577421188354492, -0.9915542602539062, 1.2179222106933594, -0.8456666469573975, -1.5218002796173096, 0.7244951725006104, -1.3286036252975464, -0.7319031953811646, -0.41700422763824463, 1.5427435636520386, 0.3690219521522522, -1.1647601127624512, 0.3217754364013672, -0.486890971660614, -1.1719893217086792, ...]
    >
  },
  "feed_forward_dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1536]
      EXLA.Backend
      [-0.01969454623758793, -0.03141949325799942, -0.025403395295143127, -0.02523775026202202, -0.033035438507795334, -0.018221471458673477, -0.027275601401925087, -0.02603720873594284, -0.014015775173902512, -0.024843944236636162, -0.03239646926522255, 0.005450837314128876, -0.006688214372843504, -0.023177066817879677, -0.033121153712272644, -0.012314669787883759, -0.022113187238574028, -0.014615594409406185, -0.023663213476538658, -0.031118979677557945, -0.033788520842790604, -0.027982935309410095, -0.03504092991352081, -0.006818012334406376, -0.018050672486424446, -0.021601490676403046, -0.02082575485110283, -0.011461683548986912, -0.03900982439517975, -0.023099983111023903, -0.03569479286670685, -0.016298258677124977, -0.01670851558446884, -0.031987082213163376, -0.027163656428456306, -0.03729814663529396, -0.02999606914818287, -0.03916657716035843, -0.025162482634186745, -0.0090325390920043, -0.04438036307692528, -0.013018188066780567, -0.029345678165555, -0.020096443593502045, -0.04386697709560394, -0.03509372100234032, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][1536]
      EXLA.Backend
      [
        [-0.10871043801307678, -0.09207568317651749, -0.08655396848917007, -0.10224799066781998, 0.041993580758571625, 0.12063898146152496, -0.046707216650247574, -0.051344893872737885, -0.12153489142656326, -0.06452394276857376, 0.08251994103193283, 0.0873091071844101, -0.20266391336917877, 0.048171039670705795, -0.00336627708747983, 0.049141667783260345, -0.019730960950255394, -0.00961564015597105, -0.011603367514908314, 0.004130497574806213, 0.0031170370057225227, -0.16046537458896637, -0.08627355843782425, -0.08215445280075073, -0.006033711135387421, 0.06354783475399017, -0.046216171234846115, -0.011816021986305714, -0.12118562310934067, -0.14242969453334808, -0.1388159543275833, 0.045520272105932236, 0.014241612516343594, 0.01466137170791626, -0.1673451066017151, -0.007058283314108849, 0.06803500652313232, 0.0638517439365387, 0.08884952962398529, -0.07472492754459381, -0.14217063784599304, -0.007931041531264782, 0.0704694539308548, -0.08046635240316391, -0.07672581821680069, ...],
        ...
      ]
    >
  },
  "feed_forward_dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [0.0016378882573917508, -0.004477905575186014, -1.8500685109756887e-4, -0.0018909699283540249, -0.0018954131519421935, -9.799797553569078e-4, 0.0021089836955070496, -0.0036012891214340925, -0.005732030142098665, -0.013728820718824863, -0.00945020467042923, 2.4839743855409324e-4, 0.005067904945462942, 6.146890227682889e-4, -0.004659520462155342, -0.0016149275470525026, 0.01016254909336567, 0.01562158390879631, -0.003967766184359789, 0.0025086686946451664, -0.003208652837201953, -0.00295314472168684, 8.695174474269152e-4, 0.003091400722041726, -0.020229540765285492, 0.0010084941750392318, -9.657987975515425e-4, -4.3867313070222735e-4, 3.0992846586741507e-4, 0.008426363579928875, -0.002736120019108057, 1.1287703091511503e-4, -6.791693158447742e-4, 4.408113891258836e-4, -0.001615250133909285, 0.007265539839863777, -0.004101641941815615, 0.0012738561490550637, 0.004863182548433542, 0.005860744044184685, -0.004074892494827509, -0.002494568470865488, 0.003334584180265665, -5.73037366848439e-4, -0.0058544655330479145, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[1536][384]
      EXLA.Backend
      [
        [0.07813902199268341, 0.00755926501005888, -0.01507661771029234, -0.008303034119307995, -0.02480701357126236, -0.04982953518629074, 0.0688544288277626, -0.00950665958225727, 0.009715151973068714, 0.04819542542099953, 0.0017397699411958456, -0.027666132897138596, -1.672387879807502e-4, -0.02711641602218151, -0.02916313335299492, -0.004612234886735678, 0.04952307417988777, -0.013753768056631088, 0.0031526777893304825, -0.009431369602680206, -0.03181751072406769, 0.008109191432595253, 0.02372041903436184, -0.0030595185235142708, -0.019107308238744736, 0.035727448761463165, 0.03700413927435875, -0.020368410274386406, -0.012909585610032082, -0.004915925208479166, -0.003388757584616542, -0.03510107100009918, 0.010471213608980179, -0.030938591808080673, -0.010781565681099892, 0.05147552490234375, -0.004744974430650473, -0.05841310694813728, 0.05155558884143829, 0.02477145381271839, -0.06363201886415482, 0.06570444256067276, -0.043092936277389526, 0.0038584712892770767, ...],
        ...
      ]
    >
  },
  "feed_forward_dropout" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [722951111, 1788778939]
    >
  },
  "final_block_ln" => %{
    "beta" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.06223426014184952, 0.020014429464936256, 0.04373729228973389, -0.02258773148059845, -0.07341495156288147, 0.038936272263526917, -0.008139075711369514, 0.05396490544080734, -0.08630179613828659, -0.03686339408159256, -0.027823925018310547, -0.03715555742383003, -0.019566871225833893, -0.05063457041978836, -0.0014520023250952363, 0.003898404538631439, 0.021072369068861008, -0.07472241669893265, -0.08343469351530075, 0.06747688353061676, -0.022054431959986687, -0.02674838900566101, -0.012110492214560509, 0.016023002564907074, 0.09339968115091324, -0.07446654140949249, 0.006534852087497711, 0.024841567501425743, -0.01911127381026745, -0.09875895082950592, 0.052235424518585205, -0.04918821528553963, -0.01366699393838644, 0.023073343560099602, 0.0037924626376479864, -0.05260546877980232, 0.007962469011545181, -0.011324295774102211, -0.029835710301995277, -0.025004083290696144, -0.07888715714216232, -0.0416419580578804, -0.020500805228948593, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [0.6214079856872559, 0.7886702418327332, -1.302459478378296, 1.4649631977081299, -0.5696701407432556, 0.525750458240509, -1.1974927186965942, -1.1313458681106567, -0.8776533603668213, -0.31107765436172485, 1.6163456439971924, -0.300855427980423, 0.7801929116249084, 1.1120021343231201, -1.1623111963272095, -1.131649136543274, -1.1163313388824463, 0.38161373138427734, 0.7284241914749146, -0.48665928840637207, -0.9602446556091309, 0.40392425656318665, -1.0368990898132324, -1.0550838708877563, 0.1489740014076233, 0.7694688439369202, -1.1132980585098267, -0.1496853530406952, -1.478727102279663, 1.6444339752197266, 1.2406306266784668, -0.6676476001739502, 1.1959826946258545, -0.7487120628356934, -0.44440358877182007, 0.6873772740364075, -0.852936863899231, 1.1269590854644775, 1.5348412990570068, 1.0658866167068481, -1.0694810152053833, -1.0851964950561523, ...]
    >
  },
  "key" => %{
    "bias" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [1.1622644524322823e-4, 6.414170638890937e-5, 8.051748591242358e-5, 1.2625248928088695e-4, -2.94942146865651e-5, -1.432016579201445e-4, 9.895324183162302e-6, -1.1200238986930344e-5, -4.871409691986628e-5, -7.922281656647101e-5, 4.16673174186144e-5, 1.8111472309101373e-4, 1.1267856461927295e-4, 1.5345354040618986e-4, 2.838678192347288e-4, 4.03459052904509e-5, 1.536956369818654e-5, 1.5929706569295377e-4, -1.9008279195986688e-4, -2.74067249847576e-4, 1.6437079466413707e-5, -5.9939222410321236e-5, -1.0931974247796461e-4, 2.542962320148945e-4, -1.5804167196620256e-4, -5.212277756072581e-5, 5.324201993062161e-5, 1.9963234080933034e-4, -2.5228006416000426e-4, -8.598788554081693e-5, 4.858178726863116e-5, 1.8963949696626514e-4, -1.971950987353921e-4, 6.570507684955373e-5, 1.5816971426829696e-4, 3.819910125457682e-5, 1.2098293518647552e-4, 3.7578868796117604e-4, -3.22144478559494e-4, -1.3360384036786854e-4, -1.2937923020217568e-4, 2.4096581910271198e-4, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][384]
      EXLA.Backend
      [
        [-0.0012871787184849381, 0.05325806885957718, -0.0329965278506279, -9.909607470035553e-4, 0.010785852558910847, 0.009750726632773876, 0.04176070913672447, -0.054770056158304214, 0.008932334370911121, 0.022378653287887573, -0.04128154367208481, 0.05506803095340729, 0.0077568707056343555, -0.027852976694703102, 0.0058920662850141525, -0.0761546641588211, 0.017340153455734253, 0.022978365421295166, 0.025580445304512978, -0.018654203042387962, 0.04613232612609863, 0.04237061366438866, 0.012818633578717709, -0.011534139513969421, 0.07500981539487839, -0.05386031046509743, 0.03983011469244957, 0.06280945241451263, -0.05516345426440239, -0.007531187497079372, -0.01810508966445923, -0.05361419916152954, 0.04882905259728432, 0.06248072162270546, -0.028620455414056778, -0.03912936523556709, -0.039756715297698975, -0.004793907981365919, 0.03730688616633415, 0.031572699546813965, 0.03633453696966171, ...],
        ...
      ]
    >
  },
  "language_modeling_head" => %{
    "bias" => #Nx.Tensor<
      f32[65]
      EXLA.Backend
      [-0.002946221036836505, 0.024192655459046364, -0.05523858591914177, -0.17861443758010864, -0.15369747579097748, 0.02711259014904499, -0.007696119602769613, -0.03897085785865784, -0.04317887872457504, -0.1140795648097992, -0.018143698573112488, -0.08973734080791473, -0.08171650767326355, -0.019000938162207603, -0.0456165112555027, -0.03262511268258095, -0.05419778451323509, -0.01939394138753414, -0.056382980197668076, -0.0516597144305706, -0.041264262050390244, -0.009129352867603302, -0.03057103417813778, -0.07836691290140152, -0.012997974641621113, -0.019970541819930077, -0.05059399455785751, -0.034343983978033066, -0.017009710893034935, -0.11167945712804794, -0.015476089902222157, -0.016781035810709, -0.0243771243840456, -0.07288465648889542, -0.102816142141819, -0.04766138643026352, -0.2592124044895172, -0.05409780517220497, -0.14828692376613617, 0.03934876248240471, 0.018621230497956276, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][65]
      EXLA.Backend
      [
        [0.015758411958813667, -0.051811583340168, 0.1745842546224594, 0.03205735236406326, 0.14337486028671265, 0.06512445956468582, 0.08498262614011765, 0.01941721700131893, 0.04333372786641121, 0.11364075541496277, -0.043731216341257095, 0.09409287571907043, 0.06420766562223434, 0.0015034792013466358, 0.009646364487707615, 0.08266307413578033, 0.12922188639640808, 0.12079586833715439, -0.00773721095174551, 0.11592798680067062, 0.0960366353392601, 0.0038370585534721613, 0.2200704663991928, 0.09468023478984833, 0.15574730932712555, 0.14135101437568665, 0.05444718524813652, -0.049071311950683594, 0.19212448596954346, 0.2701287865638733, -0.014314115978777409, 0.1655369997024536, 0.1066943034529686, 0.11247802525758743, 0.1863957941532135, 0.13626587390899658, 0.15401582419872284, 0.10896278917789459, 0.12454552948474884, -0.09278301894664764, ...],
        ...
      ]
    >
  },
  "multi_head_dense" => %{
    "bias" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.013091623783111572, -0.0025277109816670418, -0.015406734310090542, -0.0055812350474298, 0.0070198820903897285, 0.013889962807297707, 0.005282912403345108, -0.005262902472168207, 0.007651640567928553, 0.0013492725556716323, -6.565082585439086e-4, -0.009683044627308846, -0.020997583866119385, -0.0025598604697734118, 0.0017333345022052526, 1.242513917532051e-6, 0.0014633577084168792, -0.0045132143422961235, -0.011305914260447025, -0.0038131410256028175, -4.433437716215849e-4, -7.106841658242047e-4, 0.013996715657413006, -0.0048048049211502075, 0.016807060688734055, 0.005172553937882185, -0.008687887340784073, 0.011892814189195633, -0.0028974039014428854, 0.0022457861341536045, 0.007068112958222628, 0.002055160701274872, -0.002763577038422227, 0.014206371270120144, -0.005901937372982502, -0.0054986197501420975, -0.007799994666129351, 0.0023439626675099134, 0.009106948971748352, -0.012797119095921516, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][384]
      EXLA.Backend
      [
        [0.07698854804039001, 0.009023042395710945, -0.026208246126770973, -0.006611100863665342, 0.036692507565021515, 0.05781245604157448, 0.04712959751486778, 0.024427007883787155, 0.05523138493299484, 0.01334020122885704, -0.014303312636911869, -0.029595771804451942, 0.032298460602760315, 0.036912381649017334, -0.09651049226522446, 0.061845824122428894, -0.05959019064903259, -0.02616739273071289, -0.07068630307912827, -0.06196698546409607, -0.05527180805802345, -0.05376365780830383, -0.06438268721103668, -0.06433121114969254, -0.017974428832530975, -0.008260016329586506, 0.03823176771402359, 0.09679777175188065, -0.03601941838860512, 0.017709065228700638, -0.07428305596113205, -0.047941479831933975, -0.027466345578432083, -0.042206089943647385, 0.010596836917102337, 0.048028383404016495, 0.06414726376533508, -0.028881270438432693, 0.06266971677541733, ...],
        ...
      ]
    >
  },
  "multi_head_dropout" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [3215365877, 2835071777]
    >
  },
  "position_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[256][384]
      EXLA.Backend
      [
        [0.032874882221221924, 0.017140092328190804, 0.003328746184706688, -0.005748175550252199, 0.017624981701374054, -0.01301295030862093, -0.006106921937316656, -0.006924390327185392, -0.0665484294295311, -0.008325970731675625, -0.01804683916270733, 0.010912414640188217, 0.03563772514462471, 0.01872558705508709, -0.004042694810777903, 0.025369824841618538, 0.0023279483430087566, -0.035076919943094254, -0.005552679765969515, -0.004224944394081831, 0.017052853479981422, -0.009900875389575958, -0.0076743196696043015, -0.023836085572838783, 0.014323941431939602, -0.0012540258467197418, 0.00770062068477273, -0.05086039751768112, -0.0807521864771843, -0.024993926286697388, -0.011912805959582329, 0.00384755851700902, 0.03233299031853676, -0.015088425949215889, 0.030430546030402184, 0.0033960388973355293, 0.02266317792236805, -0.02247624099254608, ...],
        ...
      ]
    >
  },
  "query" => %{
    "bias" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.019059835001826286, 0.007212303578853607, -0.0253163930028677, 0.02480863407254219, 0.030279699712991714, 0.044939201325178146, -0.010317720472812653, -0.03053799644112587, 0.01236018631607294, 0.01944728195667267, -0.05433334782719612, -0.1049325242638588, 0.021181730553507805, -0.026806436479091644, 0.004234259016811848, 0.00791642814874649, 0.0014239175943657756, 0.024933788925409317, 0.012788806110620499, 0.050823960453271866, 0.0026838139165192842, 2.7709786081686616e-4, -6.114842108217999e-7, -0.049711503088474274, -0.04625536501407623, 0.018391326069831848, -0.03626689687371254, -0.07920512557029724, 0.10758330672979355, -0.012113719247281551, -0.04403815045952797, -0.05166914314031601, 0.012584244832396507, -0.019165927544236183, 0.0462692454457283, 0.021622730419039726, 0.0011790632270276546, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][384]
      EXLA.Backend
      [
        [0.0058362302370369434, 0.018301762640476227, -0.018645629286766052, -0.038221556693315506, 0.06795700639486313, 0.027598747983574867, -0.03910763934254646, -0.02521691471338272, -0.011773718520998955, 0.0824030414223671, -0.0028445073403418064, 0.006800362840294838, -0.02214733138680458, -0.09010455012321472, -0.032811108976602554, 0.08154575526714325, 0.025752056390047073, 0.08471326529979706, -0.008311114273965359, 0.052244335412979126, 0.036650706082582474, 0.04471902549266815, -0.023855609819293022, -0.004607589449733496, -0.011002322658896446, 0.0587620735168457, 0.005509072449058294, -0.01659712754189968, 0.026653604581952095, 0.08233808726072311, -9.325456921942532e-4, 0.03881856054067612, -0.035048775374889374, 0.032658956944942474, -0.035093218088150024, 0.0045005446299910545, ...],
        ...
      ]
    >
  },
  "token_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[65][384]
      EXLA.Backend
      [
        [-0.020142383873462677, -0.014539672993123531, 0.0122062424197793, 0.039099209010601044, 0.03575901687145233, -0.00782727263867855, -0.01719738356769085, 0.07246894389390945, 0.02138376235961914, 0.032816097140312195, 0.012480519711971283, -0.03317803516983986, 0.027605293318629265, -0.017775360494852066, 4.3115095468237996e-4, 0.0056412131525576115, 0.036179959774017334, -0.010553312487900257, 5.860661040060222e-4, 0.03476720303297043, -0.00231738667935133, 0.005250687710940838, -0.014498109929263592, 0.010408789850771427, -0.016012923792004585, -0.012880049645900726, 0.02018360234797001, 0.007029877044260502, 0.00606964435428381, -0.0016668542521074414, 0.007597022689878941, -0.012955783866345882, 0.015751812607049942, 0.002011312637478113, -0.0200498066842556, 0.0019235415384173393, ...],
        ...
      ]
    >
  },
  "value" => %{
    "bias" => #Nx.Tensor<
      f32[384]
      EXLA.Backend
      [-0.0030991090461611748, -0.002844809088855982, -0.009626881219446659, 4.500289505813271e-4, 0.00581051129847765, 1.5979257295839489e-4, 0.01220391783863306, -0.00541482400149107, 0.003090923186391592, -8.587435586377978e-4, 0.006976135075092316, 0.009952674619853497, -0.007886230014264584, 0.006909910123795271, 0.0038270035292953253, -0.003139507956802845, -0.0021739653311669827, 0.00220508617348969, -0.0015421948628500104, -0.004189274273812771, -0.007480615749955177, 0.004374688025563955, 0.0025047780945897102, -0.004456990864127874, -0.0062192585319280624, -0.009940357878804207, 0.007896310649812222, -0.008275106549263, 0.011483363807201385, 0.012841667048633099, -0.00428217276930809, 0.0053891753777861595, 0.0035085384733974934, 0.0021641673520207405, -5.1330698624951765e-6, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[384][384]
      EXLA.Backend
      [
        [0.009906159713864326, 0.019326014444231987, 0.013990432024002075, 0.010629404336214066, -0.004779709968715906, -0.04365978389978409, 0.02479981817305088, 0.03628024086356163, -0.010248202830553055, 0.03433370590209961, -0.06440325081348419, 0.01333934348076582, 0.05220233276486397, -0.07441692799329758, 0.07505417615175247, -0.08211889117956161, -0.015163464471697807, -0.04528651386499405, -0.026584822684526443, 0.005662917625159025, 0.004619475454092026, 0.049791473895311356, 0.011024191975593567, -0.03539152815937996, 0.0017584519227966666, -0.016710760071873665, 0.07067379355430603, 0.006168636493384838, 0.041405051946640015, 0.024044036865234375, 0.04775714501738548, -0.04292962700128555, 0.010696981102228165, 0.015617704018950462, ...],
        ...
      ]
    >
  }
}

Generating text with the final model

init_seq = Nx.broadcast(0, {1, 1})

TextGen.generate(final_model, params, init_seq, block_size, max_new_tokens: 10000)
|> IO.puts()

Second Servingman:
Come, for a pull, and as as to my presence;
Vouchle?

POMPEY:
Why, lady.

ROMEO:
Come, sir?

MERCUTIO:
Have speak, Make my considerer: therefore.

ROMEO:
Why, the gate?

MERCUTIO:
And to the pretty took'd doable.

MERCUTIO:
Nay, be more, and after and you true, villain.

ROMEO:
Why, Come, that Camillo,
As come valianted to my through in my power;
Nor, if our to God beggar, ever, call arm;
And is idle autor to loss.

MERCUTIO:
What, do you have no with not make I nothing?

ROMEO:
'Tis that honour let honour? what do good he?

ROMEO:
Thou ha?

PETRUCHIO:
Have me lord?

MERCUTIO:
I pray you pretty what bear be give so?

MERCUTIO:
My lord?

MERCUTIO:
All I pray you, poor lord, I clond a be done.

ROMEO:
He throw with insife that I am speak and she consul.

ROMEO:
I'll cut see not to you have: if your death soon.
If me sent my like you brother.

TYRREL:
Why, my son your will my continue, to-morrow,
And you do my grace? and, what so do mightly be him
In he rests warm jame a sacre, his with his eEven being;
As blow's it a lack mine ere the court in he sill.

KING RICHARD III:

Keeper:
Troth, will and safety, if 'gain.

QUEEN ELIZABETH:
Then and his comes myself, resign of the hands
In honour sit, have me othing
To six this at I am no beats, made and from the
In may betterfly tarry this the one: set himself
To look'd it of the great he dear:
Where, but a gates beseech none stain'd,
For shine is death very and come in honour'd;
I love as a welcomed your vain a play'd.
For, I claim to not stander the betwixt thousands?

BUCKINGHAM:
Because by those good my true old I should
To with this my prince and have bite:
Thou wilt noble immed, I wouldering and itself
To have been in pence for young young by should be offend?
To the your a title be shall and drink you,
Too your being you thus next upon the gods wretch!
It is tell despatch you do you to look,
I pay forsake hope, my lord. Come, and for this you?

ANGELO:
When this is is your amorous dry edies you:
I leave no banished my kind his with the purpose.

ISABELLA:
The should you not a quarry command expose:
You not he when I need.

ISABELLA:
Ay, sir, that have worship, so you, that for King
To tooch will perpeted passister him,
And just shall be king it for your fly;
No sorrow yours are of my business are weat
Is not attend in the king. You must need
To prove that is a garmy some a little
I am in thee.

This prayers to they head; but ell to be
A be so pretty could d be roccursed it.

TRANIO:
As elder amongsy:
Pardon father, sir; shall never hither's fack
Thus are head as that does the war from of he
in me possessioner's your love.

MERCUTIO:
A monger, intent for this protector a knew me
too his sincely and fadies their
luke earn: end to you to coldier thanks 'll.

ROMEO:
Come, and thy about my come violent.

MERCUTIO:
Why, too: thou art that my why charge,
as I am date, by my destraight take with decrew.

ROMEO:
Such I sprepare is the is dead! I dare me devision
solegiance the wish the live hence.

ERCHIOLO:
Thou art shalt ne'er reason bosom of a stone.

POMPEY:
He shalt be the gods him go. Do you villain.
Did I shall had ere insmies what many,
It is prince pursuits I have my namel.

PETRUCHIO:
The so to my lord.
Your grace man arrow it now.

EDWARD:
Not hate.

BUCKINGHAM:
My lords, Lord Bohemia: he much have gone.

GLOUCESTER:

CLARENCE:
Do now, to him be flyman Salisbury,
And to comfort you this daughter lords your be not
How your kindtily to be caasters of this:
It warrance this like yours so bring:
Grow break like unfant to your as England?

KING EDWARD IV:
Because to stroke a father.

GLOUCESTER:
Now, so much as you? what salls. As we's lady?

LADY GREY:
'Tis was a trouble of thy good?

KING EDWARD IV:
Shall thy calls? thou have drops thy for her?

GLOUCESTER:
Whose come would Christ I be longs, thy lord!

QUEEN MARGARET:
Why, thou wolf, what forgive rich me?

CLARENCE:
Why this in Gaunt at being thy mother?

QUEEN MARGARET:
And dispite revenge thou canst thou, royalty
Is thou both these plant thy hand, thy news:
I thoughts of thy hearts of lovest thy tongue
The from wounds purpose of and enough;
For thy blood this goble and again a prove?

QUEEN ELIZABETH:
Therefore that I thanks, the were did o'er port,
Hath colour had made mightly humble pierceed
That I several this hath plagued to dispair,
Of this dust thou art thou wilt be matter.
But not speak what thou hast he truth rend
That horses Brottom'd!
Why, not I know in thou redeem, and what to thy love?

CATESBY:
We is't?

CATESBY:
He Pray most thou wast thou dost so know thy return'd?

PRINCE EDWARD:
And I cannot ashe bring did thy counter.
Thou thy succester, not what my souls,
MI'll blood without to heart thy purposes agest by
To late come, thy scorn thy be giving me will
A kein a better, whose man's of thy wants-pert,
Who dares up thy with the and sugger things mights thee same.
Kill's a deceives often that bear, to friend,
To country's with slain my hand, 'Bring subject between in
That mistress heir put draws.

Shepherd:
And thou not is the babe'st grace in yours.

POMPEY:
I would be no encountenance your be no shall
question; sit out knows false to the greate 't.

PERDITA:
Sir, your hated any of your dreamd.

PETUS:
What, the Tush, pity
What, ortly thou may shalt her; I meal, for your son?
The slip wert thou this prince wish, who but jest?

MENENIUS:
He's stay there?

COMINIUS:
Sir, you speak tone spirit of this; I thoughts
Where not goest the should been
Be though welcome too sancture of Coriolanus,
Exposition, and discover number: your she common
Make done put you men counts prophecy be from father.

OXFORD:
I think foul this beggar-conquer where against I am to
Lonf, but 'tiss to't: look'd way' humble.

Second Servingman:
Why, withy cannot is goodly 'tis with the he,
daughter: you must entreat hath use no pale.

LEONTES:
How?

More harse! I'll men together carved be his guilty!

PRINCE EDWARD:
O, why give made this it? what they are than's one?
Come is is country's you will tars, hast staint hire,
That have broughts for that would know dost forew.

GLOUCESTER:
My heart, Decliff, my father some, and in three.

ANGELO:
She hath all they lord; for I know my move not friend.

LADY ANNE:
No, It will to thinks me to not for no love.

LADY ANNE:

GLOUCESTER:
An enever me with the king?

GLOUCESTER:
I did, I mother? methough ofury, we by my heart
Whom the no more of much well as sendings;
It cannot live thinks by my with a peace,
And the right wings fair that we wilt the true.'

LEONTES:
O Pray you,
To reclaim break, thou hast love. I was flesh all
Accompanion: many three lord and only to
As think.

MENENIUS:
Take run'd to throne with is truth, to him.

MENENIUS:
Methy knees:
Shall have him how walkly, that chequality one
And let of your his sondition, to I come go:
My breath-perform your for your good soul's to us.

MENENIUS:
Must hold, therefore issues! you have destrance.

MARCIUS:
The clamations' you, you stand we wit.

VIMILIA:
Speak abhavours? pray none when I came gople.

VOLUMNIA:
He hath bed me the srunds pluck'd the cause nothing?

VOLUMNIA:
O my lord.

PERDITA:
I pray, and that say you know not, you must dieath
shall forth my brother.

MENENIUS:
Nay, now notha second from that die?

MENENIUS:
The kill our say 'I' the word, I am,'tis
Of your tiding rest yours, I could scalutchy's my
I will stone.

CORIOLANUS:
Hath senator possible.

First Marcius;
will'd the to in't will not the and lawful give
Second all reputy the marting shall thine for die.

COMINIUS:
Nay, not not he still unto this calarench a good.

MARCIUS:
Villain, brail, with sin, blow the thee;
But name, let hollow, my lords; and will you have
I'll resign, it is the still I saw young him
some forswear of children: but for well be
My tedious should not to too, sistake too a.
What was not? thy tongue heart: sit but a France?

CLIFFORD:
Mightier to be thverence!

FRIAR LAURENCE:
Bid, he should that nor many book'd not in,
Who confesserved chamber'd my scient his gurden me?

ROMEO:
A good More under of my lordship me,
For what she scratch will out lord words?

BENVOLIO:
O, crying on, then set thanks being; not of, attrous
Of you are of you weaky, wherefore you to them?

ROMEO:
Thou art shighness is your eye; not they soest approve
Of thy me in thy done sad done's a word interch,
Liewis is like a happy 'Twixt downos under,
Lest the mine enemy the head thy mock make thee.

ROMEO:
Do there in the present daughter:
Madam, good night adversity; he's it my lord.

HERMIONE:
Marry, when thou know me abstard, my lie?

MERCUTIO:
To stay, never king, sir, by me, and that
Is love, my life.

CAPULET:
My lord.

ROMEO:
Auth hast is too, and served my lord.

PRINCE EDWARD:
But hear it me harm.

GREMIO:
Not very that forsaken thy mother,
my name not me thy cannot thy me,
Thy blood in that one sound not to theirs.

DUKE OF AUMERLE:
Well, or did to be man; but thou love,
Tower it mistory, sdisconcil thou did his most
And more piercharge a tewell, therefore wile
Of so so. Both prison! what thy sir, but thou,
Does that; but ere which is is need, we be
At thou bears wert pergination a poor such to thy friends
So infactorse, here bearts and here is should before?
On plotting but thou see thy fair objected.
This is own confess upon-thou sift; but thy king,
To heard rocker had body I believer yet remember
As if thy war stand traitors: that death blush,
I would me consul! straight, let thee news.

POLIXENES:
Why, I am so slain that?
Thy hear me, do my life ten'd to thy duty.

HERMIONE:
The should I reple king!
No through is is a serves for think; to the world:
My lordship to thy king, now heart, Do not cheek nor hand,
And what thou diest thy this sound to pay thee have rank'd;
And their as herefore arms are match'd bring Warwick
O' the nuptain the executions with this thee:
Take candh prophecy me sun schange thou speech fraughters.
And in he on God's draction the kings,
And who can a wword the of king? not lacks,
God deed's daughter the king to dry grave:
H
:ok