
Chapter 4: Implementing a GPT model from scratch to generate text


Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:tiktoken, "~> 0.3.2"},
  {:table_rex, "~> 3.1.1"},
  {:kino_vega_lite, "~> 0.1.11"}
])

4.1 Coding an LLM architecture

LLMs, such as GPT (which stands for generative pretrained transformer), are large deep neural network architectures designed to generate new text one word (or token) at a time.

In the context of deep learning and LLMs like GPT, the term “parameters” refers to the trainable weights of the model. These weights are essentially the internal variables of the model that are adjusted and optimized during the training process to minimize a specific loss function. This optimization allows the model to learn from the training data.

gpt_config_124m = [
  attn_name: "attention_0",
  vocab_size: 50257,
  context_length: 1024,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0.1,
  qkv_bias: false
]
[
  attn_name: "attention_0",
  vocab_size: 50257,
  context_length: 1024,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0.1,
  qkv_bias: false
]
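
To make the idea of "parameters" concrete, the short sketch below counts the trainable weights held by the two embedding tables implied by this configuration. The `ParamCount` module is a hypothetical helper introduced only for this illustration; it uses plain Elixir and is not part of the notebook or of Axon.

```elixir
# Hypothetical helper for illustration only; not part of the notebook or Axon.
defmodule ParamCount do
  # An embedding table stores one emb_dim-sized vector per row.
  def embedding(rows, emb_dim), do: rows * emb_dim
end

config = [vocab_size: 50257, context_length: 1024, emb_dim: 768]

# One row per token in the vocabulary.
token_emb = ParamCount.embedding(config[:vocab_size], config[:emb_dim])
# One row per position in the context window.
pos_emb = ParamCount.embedding(config[:context_length], config[:emb_dim])

IO.inspect(token_emb, label: "token embedding weights")
IO.inspect(pos_emb, label: "positional embedding weights")
```

With these settings the token embedding table alone holds roughly 38.6 million weights, which is why the vocabulary size and embedding dimension dominate the size of small GPT models.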
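defmodule DummyGPTModel do
  def model(config \\ []) do
    input = Axon.input("sequence")
    # Token embedding: one vector per token ID.
    tok_emb = Axon.embedding(input, config[:vocab_size], config[:emb_dim])

    # Positional embedding: looked up by position (0..seq_len - 1), not by token ID.
    pos_emb =
      input
      |> Axon.nx(fn ids -> Nx.iota(Nx.shape(ids), axis: 1) end)
      |> Axon.embedding(config[:context_length], config[:emb_dim])

    Axon.add(tok_emb, pos_emb)
    |> Axon.dropout(rate: config[:drop_rate])
  end
end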
{:module, DummyGPTModel, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:model, 1}}
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

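# "gpt2" is not supported by this Tiktoken package, so the gpt-3.5-turbo
# tokenizer is used instead; its token IDs differ from the GPT-2 IDs.
{:ok, ids1} = Tiktoken.encode("gpt-3.5-turbo", txt1, [])
{:ok, ids2} = Tiktoken.encode("gpt-3.5-turbo", txt2, [])

tensors = Enum.map([ids1, ids2], &Nx.tensor/1)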
Nx.stack(tensors)
#Nx.Tensor<
  s64[2][4]
  [
    [11769, 5149, 11031, 499],
    [11769, 1938, 10187, 264]
  ]
>
batch = Nx.tensor([[6109, 3629, 6100, 345], [6109, 1110, 6622, 257]])
#Nx.Tensor<
  s64[2][4]
  [
    [6109, 3629, 6100, 345],
    [6109, 1110, 6622, 257]
  ]
>

The model's outputs are commonly referred to as logits.

4.2 Normalizing activations with layer normalization

Training deep neural networks with many layers can sometimes prove challenging due to problems like vanishing or exploding gradients. These problems lead to unstable training dynamics and make it difficult for the network to effectively adjust its weights, which means the learning process struggles to find a set of parameters (weights) for the neural network that minimizes the loss function.

key = Nx.Random.key(123)
Nx.Random.normal(key, shape: {2, 5}) |> IO.inspect()

batch_example =
  Nx.tensor([
    [-0.1115, 0.1204, -0.3696, -0.2404, -1.1969],
    [0.2093, -0.9724, -0.7550, 0.3239, -0.1085]
  ])
{#Nx.Tensor<
   f32[2][5]
   [
     [-0.5154414772987366, -0.8975640535354614, 1.9826834201812744, -1.9789758920669556, -2.8818085193634033],
     [-0.6626349687576294, -0.03326578065752983, -0.1879543960094452, -0.6107876896858215, 0.16164331138134003]
   ]
 >,
 #Nx.Tensor<
   u32[2]
   [1896456402, 17229315]
 >}
#Nx.Tensor<
  f32[2][5]
  [
    [-0.11150000244379044, 0.12039999663829803, -0.36959999799728394, -0.24040000140666962, -1.1969000101089478],
    [0.2092999964952469, -0.9724000096321106, -0.7549999952316284, 0.3239000141620636, -0.10849999636411667]
  ]
>
model =
  Axon.input("input", shape: {nil, 5})
  |> Axon.dense(6)
  |> Axon.activation(:relu)

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 5}, :f32)
params = init_fn.(template, %{})
result = predict_fn.(params, batch_example) 
#Nx.Tensor<
  f32[2][6]
  [
    [0.0, 0.0, 1.0354856252670288, 0.6478695869445801, 0.5209736824035645, 0.0],
    [0.0, 0.44917839765548706, 0.07561643421649933, 0.0, 0.0, 0.440552294254303]
  ]
>
Nx.mean(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Mean")
Nx.variance(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Variance")
Mean: #Nx.Tensor<
  f32[2][1]
  [
    [0.36738815903663635],
    [0.16089119017124176]
  ]
>
Variance: #Nx.Tensor<
  f32[2][1]
  [
    [0.1589224636554718],
    [0.04104159399867058]
  ]
>
#Nx.Tensor<
  f32[2][1]
  [
    [0.1589224636554718],
    [0.04104159399867058]
  ]
>
Nx.add(1, Nx.tensor([1,2]))
#Nx.Tensor<
  s64[2]
  [2, 3]
>
defmodule TransformerLayers.Norm do
  import Nx.Defn

  def normalization(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name, :eps, :emb_dim])
    eps = Keyword.get(opts, :eps, 1.00e-5)
    scale = Axon.param("scale", {opts[:emb_dim]}, initializer: &ones(&1, type: &2))
    shift = Axon.param("shift", {opts[:emb_dim]}, initializer: &zeros(&1, type: &2))

    Axon.layer(
      &normalization_impl/4,
      [input, scale, shift],
      name: opts[:name],
      op_name: :normalization,
      eps: eps
    )
  end

  defp ones(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.iota(shape, type: opts[:type]) |> Nx.fill(1)
  end

  defp zeros(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.iota(shape, type: opts[:type]) |> Nx.fill(0)
  end

  defnp normalization_impl(input, scale, shift, opts \\ []) do
    mean = Nx.mean(input, axes: [-1], keep_axes: true)
    variance = Nx.variance(input, axes: [-1], keep_axes: true)
    denominator = variance |> Nx.add(opts[:eps]) |> Nx.sqrt()

    input
    |> Nx.subtract(mean)
    |> Nx.divide(denominator)
    |> Nx.multiply(scale)
    |> Nx.add(shift)
  end
end
{:module, TransformerLayers.Norm, <<70, 79, 82, 49, 0, 0, 18, ...>>, true}

The variable eps is a small constant (epsilon) added to the variance to prevent division by zero during normalization. The scale and shift are two trainable parameters (of the same dimension as the input) that the LLM automatically adjusts during training if it is determined that doing so would improve the model’s performance on its training task. This allows the model to learn appropriate scaling and shifting that best suit the data it is processing.
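
The arithmetic behind this layer can be traced by hand on a single row. The sketch below is a plain-Elixir illustration (no Nx or Axon) that assumes scale is all ones and shift is all zeros, which matches their initial values; `LayerNormSketch` is a hypothetical name used only here.

```elixir
# Illustrative sketch only: layer normalization of one row with plain lists,
# assuming scale = 1 and shift = 0 (their initial values).
defmodule LayerNormSketch do
  def normalize(row, eps \\ 1.0e-5) do
    n = length(row)
    mean = Enum.sum(row) / n

    # Population variance (divide by n), the default behaviour of Nx.variance.
    variance = Enum.sum(Enum.map(row, fn x -> (x - mean) * (x - mean) end)) / n

    # eps keeps the denominator away from zero when the variance is tiny.
    denominator = :math.sqrt(variance + eps)
    Enum.map(row, fn x -> (x - mean) / denominator end)
  end
end

# First row of the notebook's batch_example.
row = [-0.1115, 0.1204, -0.3696, -0.2404, -1.1969]
normed = LayerNormSketch.normalize(row)

mean = Enum.sum(normed) / length(normed)
variance = Enum.sum(Enum.map(normed, fn x -> (x - mean) * (x - mean) end)) / length(normed)

IO.inspect(normed, label: "normalized row")
IO.inspect(mean, label: "mean")
IO.inspect(variance, label: "variance")
```

Because eps is added inside the square root, the resulting variance lands slightly below 1 rather than exactly on it, which is consistent with the values near 0.9999 in the notebook outputs.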

model =
  Axon.input("input", shape: {nil, 5})
  |> Axon.dense(6)
  |> Axon.activation(:relu)
  |> TransformerLayers.Norm.normalization(emb_dim: 6, name: "norm")

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 5}, :f32)
params = init_fn.(template, %{})
result = predict_fn.(params, batch_example) 
#Nx.Tensor<
  f32[2][6]
  [
    [-0.9361081719398499, -0.9361081719398499, 0.28018441796302795, 0.631147563457489, 1.7666065692901611, -0.8057222366333008],
    [2.2003324031829834, -0.5372314453125, -0.5372314453125, -0.051406439393758774, -0.5372314453125, -0.5372314453125]
  ]
>
Nx.mean(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Mean")
Nx.variance(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Variance")
Mean: #Nx.Tensor<
  f32[2][1]
  [
    [-4.967053879312289e-9],
    [3.0423205288343524e-8]
  ]
>
Variance: #Nx.Tensor<
  f32[2][1]
  [
    [0.9999224543571472],
    [0.9997627139091492]
  ]
>
#Nx.Tensor<
  f32[2][1]
  [
    [0.9999224543571472],
    [0.9997627139091492]
  ]
>
model =
  Axon.input("input", shape: {nil, 5})
  |> TransformerLayers.Norm.normalization(emb_dim: 5, name: "norm")

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 5}, :f32)
params = init_fn.(template, %{})
result = predict_fn.(params, batch_example) 
#Nx.Tensor<
  f32[2][5]
  [
    [0.5527316331863403, 1.0693719387054443, -0.02227856032550335, 0.26556071639060974, -1.86538565158844],
    [0.9086875319480896, -1.3767629861831665, -0.9563034772872925, 1.1303280591964722, 0.2940508723258972]
  ]
>
Nx.mean(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Mean")
Nx.variance(result, axes: [-1], keep_axes: true) |> IO.inspect(label: "Variance")
Mean: #Nx.Tensor<
  f32[2][1]
  [
    [1.527368986842248e-8],
    [0.0]
  ]
>
Variance: #Nx.Tensor<
  f32[2][1]
  [
    [0.9999502301216125],
    [0.9999626278877258]
  ]
>
#Nx.Tensor<
  f32[2][1]
  [
    [0.9999502301216125],
    [0.9999626278877258]
  ]
>

4.3 Implementing a feed forward network with GELU activations

Historically, the ReLU activation function has been commonly used in deep learning due to its simplicity and effectiveness across various neural network architectures. However, in LLMs, several other activation functions are employed beyond the traditional ReLU. Two notable examples are GELU (Gaussian error linear unit) and SwiGLU (Swish-gated linear unit).

GELU and SwiGLU are smoother, more complex activation functions that incorporate Gaussian and sigmoid-gated linear units, respectively. Unlike the simpler ReLU, they often yield better performance in deep learning models.

The GELU activation function can be implemented in several ways; the exact version is defined as GELU(x) = x⋅Φ(x), where Φ(x) is the cumulative distribution function of the standard Gaussian distribution.
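
As a rough illustration of the exact definition, the sketch below evaluates GELU with plain Elixir, writing Φ(x) in terms of the error function as Φ(x) = 0.5 · (1 + erf(x / √2)). The `GeluSketch` module is a hypothetical name used only for this example; in the notebook itself the values come from `Axon.Activations.gelu/1`.

```elixir
# Illustrative sketch only: exact GELU via the standard normal CDF.
defmodule GeluSketch do
  # Cumulative distribution function of the standard Gaussian distribution.
  def phi(x), do: 0.5 * (1.0 + :math.erf(x / :math.sqrt(2.0)))

  # GELU(x) = x * phi(x)
  def gelu(x), do: x * phi(x)

  # ReLU for comparison: negative inputs are clamped to zero.
  def relu(x), do: max(x, 0.0)
end

for x <- [-3.0, -1.0, 0.0, 1.0, 3.0] do
  IO.inspect({x, GeluSketch.relu(x), GeluSketch.gelu(x)}, label: "{x, relu, gelu}")
end
```

For negative inputs ReLU returns exactly zero, while GELU returns a small negative value; for large positive inputs the two functions are nearly identical.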

x = Nx.linspace(-3, 3, n: 100) 
y_gelu = Axon.Activations.gelu(x) |> Nx.to_flat_list
y_relu = Axon.Activations.relu(x) |> Nx.to_flat_list
x = Nx.to_flat_list(x)
[-3.0, -2.939393997192383, -2.8787879943847656, -2.8181817531585693, -2.757575750350952,
 -2.696969747543335, -2.6363635063171387, -2.5757575035095215, -2.5151515007019043,
 -2.454545497894287, -2.39393949508667, -2.3333332538604736, -2.2727272510528564,
 -2.2121212482452393, -2.151515007019043, -2.090909004211426, -2.0303030014038086,
 -1.9696969985961914, -1.9090908765792847, -1.848484754562378, -1.7878787517547607,
 -1.7272727489471436, -1.6666666269302368, -1.60606050491333, -1.545454502105713,
 -1.4848484992980957, -1.424242377281189, -1.3636362552642822, -1.303030252456665,
 -1.2424242496490479, -1.1818181276321411, -1.1212120056152344, -1.0606060028076172, -1.0,
 -0.9393939971923828, -0.8787877559661865, -0.8181817531585693, -0.7575757503509521,
 -0.6969695091247559, -0.6363635063171387, -0.5757575035095215, -0.5151515007019043,
 -0.4545454978942871, -0.3939392566680908, -0.33333325386047363, -0.27272725105285645,
 -0.21212100982666016, -0.15151500701904297, -0.09090900421142578, -0.030303001403808594, ...]
alias VegaLite, as: Vl
Vl.new(title: "GELU vs ReLU", width: 400, height: 400)
|> Vl.data_from_values(x: x, relu: y_relu, gelu: y_gelu)
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "relu", type: :quantitative),
  Vl.new()
  |> Vl.mark(:line)
  |> Vl.encode(:color, value: "#db646f")
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "gelu", type: :quantitative)
])

The smoothness of GELU can lead to better optimization properties during training, as it allows for more nuanced adjustments to the model’s parameters. In contrast, ReLU has a sharp corner at zero, which can sometimes make optimization harder, especially in networks that are very deep or have complex architectures. Moreover, unlike ReLU, which outputs zero for any negative input, GELU allows for a small, non-zero output for negative values. This characteristic means that during the training process, neurons that receive negative input can still contribute to the learning process, albeit to a lesser extent than positive inputs.

emb_dim = 768

model =
  Axon.input("sequence", shape: {nil, 3, emb_dim})
  |> Axon.dense(4*emb_dim)
  |> Axon.activation(:gelu)
  |> Axon.dense(emb_dim)

template = Nx.template({2, 3, emb_dim}, :f32)
Axon.Display.as_graph(model, template)
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
params = init_fn.(template, %{})
#result = predict_fn.(params, batch_example)
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.003304421901702881, -0.016809184104204178, -0.03325856477022171, -0.0021312350872904062, 3.2139767426997423e-4, 0.010116472840309143, 0.006222399417310953, 0.012478764168918133, 0.03872491046786308, -3.6608779919333756e-4, -0.02619709074497223, -0.038599178194999695, 0.025230957195162773, 0.02970854751765728, -0.024508705362677574, -0.026699841022491455, 0.025043629109859467, 0.0072991191409528255, 0.03803646191954613, -0.0064300550147891045, 0.017563346773386, -0.013632065616548061, -0.028153730556368828, 0.0033491686917841434, -0.008048352785408497, -0.03783676028251648, -0.02570010907948017, 0.013254857622087002, 0.007934308610856533, 0.014547789469361305, 0.004512054845690727, -0.0023731763940304518, -0.03366536647081375, 0.013989351689815521, 0.01460570190101862, -0.010568341240286827, -0.02975991927087307, 0.03243562579154968, -0.01234765350818634, 0.02312454581260681, 0.028792832046747208, -0.019612740725278854, 0.03310718387365341, 0.0019981071818619967, 0.02614656835794449, -0.012234391644597054, 0.024715153500437737, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.034608639776706696, 0.010364653542637825, -0.030205633491277695, 0.02597096376121044, 0.01090461015701294, -0.025229506194591522, 0.012717727571725845, -0.003160559805110097, 0.016355035826563835, -0.003989749122411013, -0.038452357053756714, -0.015247460454702377, -0.03277323395013809, 0.021599048748612404, -6.608995026908815e-4, -0.03222370147705078, 0.030917847529053688, 0.032683778554201126, -0.03478296846151352, -0.020255206152796745, 0.001228243694640696, 0.034070707857608795, -0.0030467514880001545, 0.015283499844372272, 0.010005312971770763, -0.03079899773001671, -0.03043278679251671, -0.039294034242630005, -0.03863074257969856, -0.02043052762746811, 0.007639158051460981, 0.007180702406913042, 0.023951256647706032, 0.0038388094399124384, 0.018874043598771095, -0.01711714267730713, 0.030629916116595268, -0.039307378232479095, 0.028075218200683594, 0.013258863240480423, -0.028496738523244858, 0.03813563287258148, -0.007864776067435741, -0.015036817640066147, -0.010842532850801945, 0.029854604974389076, ...],
        ...
      ]
    >
  }
}
{x, _new_key} = Nx.Random.normal(key, shape: {2,3,768})
result = predict_fn.(params, x)
#Nx.Tensor<
  f32[2][3][768]
  EXLA.Backend
  [
    [
      [-0.3588794767856598, 0.4945526123046875, 1.1270363330841064, 0.47813189029693604, 0.11433834582567215, 0.3250342607498169, -0.18536220490932465, -0.4268503189086914, 0.0366511344909668, -0.5006465315818787, -0.3246268033981323, 0.469343364238739, -0.20338745415210724, 0.262503981590271, -0.2284332811832428, -0.39896360039711, -0.6442403793334961, 0.5278472304344177, 0.3479968309402466, 0.4037986397743225, 0.4400331974029541, -1.0023019313812256, -0.7732177376747131, 0.4238455891609192, 0.08402413129806519, -0.24891048669815063, 0.10037177801132202, -0.3519536852836609, -0.14434993267059326, 0.5883971452713013, 0.27222681045532227, -0.27717745304107666, -0.4193039536476135, 0.06716635823249817, -0.2755497097969055, 0.38321346044540405, -0.9211329221725464, 0.050623536109924316, 0.6111773252487183, 0.6410372853279114, -0.7234801054000854, 0.3030005395412445, 0.13389870524406433, 0.6142141222953796, -0.47253233194351196, -0.4482526183128357, -0.346852570772171, 0.5142409801483154, -0.5943482518196106, -0.781970739364624, ...],
      ...
    ],
    ...
  ]
>
Nx.shape(result)
{2, 3, 768}

The FeedForward module plays a crucial role in enhancing the model’s ability to learn from and generalize the data. Although the input and output dimensions of this module are the same, it internally expands the embedding dimension into a higher dimensional space through the first linear layer.

This expansion is followed by a nonlinear GELU activation and then a contraction back to the original dimension with the second linear transformation. Such a design allows for the exploration of a richer representation space.
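
The expansion and contraction can be read directly off the kernel shapes. The sketch below is a plain-Elixir illustration that assumes both dense layers carry a bias, as in the parameter map printed above; `dense_params` is a hypothetical helper, and the layer names mirror the keys in that map.

```elixir
# Illustrative sketch only: shapes and parameter counts of the feed forward module.
emb_dim = 768
hidden_dim = 4 * emb_dim

# A dense layer with bias holds a {d_in, d_out} kernel plus a {d_out} bias.
dense_params = fn d_in, d_out -> d_in * d_out + d_out end

layers = [
  {"dense_0 (expand)", emb_dim, hidden_dim},
  {"dense_1 (contract)", hidden_dim, emb_dim}
]

counts =
  for {name, d_in, d_out} <- layers do
    count = dense_params.(d_in, d_out)
    IO.puts("#{name}: kernel {#{d_in}, #{d_out}}, #{count} parameters")
    count
  end

total = Enum.sum(counts)
IO.puts("total: #{total} parameters")
```

The last axis of the input goes from 768 to 3072 and back to 768, so the module can be stacked repeatedly without changing the tensor shape, while holding about 4.7 million parameters per block.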

Next, we look at shortcut connections, which we insert between different layers of a neural network and which are important for improving training performance in deep neural network architectures.

4.4 Adding shortcut connections

Shortcut connections (also called skip or residual connections) were originally proposed for deep networks in computer vision (specifically, in residual networks) to mitigate the challenge of vanishing gradients. The vanishing gradient problem refers to the issue where gradients (which guide weight updates during training) become progressively smaller as they propagate backward through the layers, making it difficult to effectively train earlier layers.
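
One way to see why a shortcut helps is to write the gradient down. In the sketch below, F stands for the wrapped layer and L for the loss; both symbols are introduced only for this illustration and assume F is differentiable.

```latex
y = x + F(x)
\qquad\Longrightarrow\qquad
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}
    \left( I + \frac{\partial F}{\partial x} \right)
```

The identity term I gives the gradient a direct path back to x, so even when the layer's own Jacobian is very small, the gradient reaching earlier layers does not shrink to zero.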

defmodule LLM.Layer do
  def shortcut(x, layer_impl, opts \\ []) when is_function(layer_impl) do
    {:arity, arity} = Function.info(layer_impl, :arity)
    layer_output = execute_layer(x, layer_impl, opts, arity)
    shortcut_impl(x, layer_output, opts)
  end

  defp execute_layer(x, layer_impl, _opts, 1), do: layer_impl.(x)
  defp execute_layer(x, layer_impl, opts, _arity), do: layer_impl.(x, opts)
  defp shortcut_impl(x, layer_output, opts) do
    use_shortcut? = Keyword.get(opts, :use_shortcut, false)
    if use_shortcut?, 
      do: Axon.add(x, layer_output),
      else: layer_output
  end
end
{:module, LLM.Layer, <<70, 79, 82, 49, 0, 0, 10, ...>>, {:shortcut_impl, 3}}
model =
  Axon.input("input", shape: {nil, 2})
  |> LLM.Layer.shortcut(&Axon.dense(&1, 2, use_bias: false))

model_with_shortcut =
  Axon.input("input", shape: {nil, 2})
  |> LLM.Layer.shortcut(
    &Axon.dense(&1, 2, use_bias: false),
    use_shortcut: true
  )
{_init_fn, predict_with_short_fn} = Axon.build(model_with_shortcut)
{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 2}, :f32)

params = init_fn.(template, %{})
%{
  "dense_0" => %{
    "kernel" => #Nx.Tensor<
      f32[2][2]
      [
        [0.339205265045166, 0.5165625810623169],
        [-1.1745550632476807, -0.8592101335525513]
      ]
    >
  }
}
predict_fn.(params, Nx.tensor([[103, 103]]))
#Nx.Tensor<
  f32[1][2]
  [
    [-86.04103088378906, -35.29269790649414]
  ]
>
predict_with_short_fn.(params, Nx.tensor([[103, 103]]))
#Nx.Tensor<
  f32[1][2]
  [
    [16.958969116210938, 67.70730590820312]
  ]
>

In conclusion, shortcut connections are important for overcoming the limitations posed by the vanishing gradient problem in deep neural networks. Shortcut connections are a core building block of very large models such as LLMs.

4.5 Connecting attention and linear layers in a transformer block

defmodule Transformer.Layers do
  import Nx.Defn
  
  def attention(%Axon{} = input, opts \\ []) do
    #opts = Keyword.validate!(opts, [:name, :d_in, :d_out, :num_heads])
    head_dim = div(opts[:d_out], opts[:num_heads])
    w_query = Axon.param("w_query", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_key = Axon.param("w_key", fn _ -> {opts[:d_in], opts[:d_out]} end)
    w_value = Axon.param("w_value", fn _ -> {opts[:d_in], opts[:d_out]} end)
    out_proj = Axon.param("out_proj", fn _ -> {opts[:d_out], opts[:d_out]} end)

    Axon.layer(
      &attention_impl/6,
      [input, w_query, w_key, w_value, out_proj],
      name: opts[:name],
      op_name: :causal_attention,
      head_dim: head_dim,
      num_heads: opts[:num_heads]
    )
  end

  #defnp attention_impl(input, w_query, w_key, w_value, head_dim, num_heads, _opts \\ []) do
  defnp attention_impl(input, w_query, w_key, w_value, out_proj, opts \\ []) do
    {b, num_tokens, _d_in} = Nx.shape(input)
    keys = Nx.dot(input, w_key)
    queries = Nx.dot(input, w_query)
    values = Nx.dot(input, w_value)
    # Scale by the per-head key dimension (head_dim), not the full d_out.
    d_k = opts[:head_dim]

    keys_reshaped = 
      keys
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])
    
    queries_reshaped = 
      queries
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])
    
    values_reshaped = 
      values
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]}) 
      |> Nx.transpose(axes: [0, 2, 1, 3])

    attn_score =
      keys_reshaped
      |> Nx.transpose(axes: [0, 1, 3, 2])
      |> then(&Nx.dot(queries_reshaped, [3], [0, 1], &1, [2], [0, 1]))

    simple_mask =
      attn_score
      |> then(&Nx.broadcast(Nx.Constants.infinity(), &1))
      |> Nx.triu(k: 1)

    masked = Nx.multiply(simple_mask, -1) |> Nx.add(attn_score)

    attn_weights =
      masked
      |> Nx.divide(Nx.pow(d_k, 0.5))
      |> Axon.Activations.softmax(axis: -1)
    
    context_vec =
      attn_weights
      |> Nx.dot([3], [0, 1], values_reshaped, [2], [0, 1])
      |> Nx.transpose(axes: [0, 2, 1, 3])

    context_vec
    |> Nx.reshape({b, num_tokens, opts[:num_heads] * opts[:head_dim]})
    |> Nx.dot(out_proj)
  end

  def shortcut(x, layer_impl, opts \\ []) when is_function(layer_impl) do
    {:arity, arity} = Function.info(layer_impl, :arity)
    layer_output = execute_layer(x, layer_impl, opts, arity)
    shortcut_impl(x, layer_output, opts)
  end

  defp execute_layer(x, layer_impl, _opts, 1), do: layer_impl.(x)
  defp execute_layer(x, layer_impl, opts, _arity), do: layer_impl.(x, opts)
  defp shortcut_impl(x, layer_output, opts) do
    use_shortcut? = Keyword.get(opts, :use_shortcut, false)
    if use_shortcut?, 
      do: Axon.add(x, layer_output),
      else: layer_output
  end

  def normalization(%Axon{} = input, opts \\ []) do
    #opts = Keyword.validate!(opts, [:name, :eps, :emb_dim])
    eps = Keyword.get(opts, :eps, 1.00e-5)
    scale = Axon.param("scale", {opts[:emb_dim]}, initializer: &ones(&1, type: &2))
    shift = Axon.param("shift", {opts[:emb_dim]}, initializer: &zeros(&1, type: &2))

    Axon.layer(
      &normalization_impl/4,
      [input, scale, shift],
      name: opts[:name],
      op_name: :normalization,
      eps: eps
    )
  end

  defp ones(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.iota(shape, type: opts[:type]) |> Nx.fill(1)
  end

  defp zeros(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.iota(shape, type: opts[:type]) |> Nx.fill(0)
  end

  defnp normalization_impl(input, scale, shift, opts \\ []) do
    mean = Nx.mean(input, axes: [-1], keep_axes: true)
    variance = Nx.variance(input, axes: [-1], keep_axes: true)
    denominator = variance |> Nx.add(opts[:eps]) |> Nx.sqrt()

    input
    |> Nx.subtract(mean)
    |> Nx.divide(denominator)
    |> Nx.multiply(scale)
    |> Nx.add(shift)
  end

  def pos_embedding(%Axon{} = x, vocab_size, embedding_size, opts \\ []) do
    opts = Keyword.validate!(opts, [:name, kernel_initializer: :uniform])

    kernel_shape = &Axon.Shape.embedding_kernel(&1, vocab_size, embedding_size)

    kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

    Axon.layer(&pos_embedding_impl/3, [x, kernel], name: opts[:name], op_name: :pos_embedding)
  end

  defnp pos_embedding_impl(x, kernel, _opts \\ []) do
    {_batch_size, sequence_size} = Nx.shape(x)
    input = Nx.iota({1, sequence_size})
    Nx.take(kernel, input, axis: 0)
  end

  def feedforward(input, emb_dim) do
    input
    |> Axon.dense(4*emb_dim)
    |> Axon.activation(:gelu)
    |> Axon.dense(emb_dim)
  end

  def feedforward_block(input, opts) do
    input
    |> normalization(opts)
    |> feedforward(opts[:emb_dim])
    |> Axon.dropout(rate: opts[:drop_rate])
  end

  def attention_block(input, opts) do
    input
    |> normalization(opts)
    |> attention(d_in: opts[:emb_dim], d_out: opts[:emb_dim], num_heads: opts[:n_heads])
    |> Axon.dropout(rate: opts[:drop_rate])
  end

  def block(input, opts \\ []) do
    input
    |> shortcut(&attention_block(&1, opts), use_shortcut: true)
    |> shortcut(&feedforward_block(&1, opts), use_shortcut: true)
  end
end
{:module, Transformer.Layers, <<70, 79, 82, 49, 0, 0, 51, ...>>, {:block, 2}}
gpt_config_124m
[
  attn_name: "attention_0",
  vocab_size: 50257,
  context_length: 1024,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0.1,
  qkv_bias: false
]
model = 
  Axon.input("sequence", shape: {2,4,768})
  |> Transformer.Layers.block(gpt_config_124m)

{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
template = Nx.template({2,4,768}, :f32)

params = init_fn.(template, %{})
%{
  "causal_attention_0" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04731859266757965, -0.014130741357803345, -0.009112060070037842, 0.03826461732387543, -0.010805293917655945, -0.05711685121059418, 0.01091797649860382, -0.031060680747032166, 0.04046545922756195, -0.04703208804130554, -0.04698623716831207, 0.0055214762687683105, -0.034206077456474304, -0.056184321641922, 0.04720066487789154, -0.043082237243652344, 0.03360249102115631, 0.04625195264816284, 0.03473791480064392, -0.036436647176742554, 0.03620511293411255, 0.045895546674728394, 0.041524723172187805, 0.029449328780174255, 0.03259524703025818, -0.02926628291606903, -0.056395918130874634, -0.05944550037384033, 0.002063557505607605, 0.02704685926437378, -0.03394484519958496, -0.026301205158233643, -0.02732875943183899, 0.05359862744808197, -0.0076325684785842896, -0.0017324090003967285, -0.015294089913368225, -0.0017408281564712524, -0.04911385476589203, 0.02864709496498108, -0.03101527690887451, -0.02927972376346588, -0.05796726047992706, 0.04391254484653473, 0.0523596853017807, 0.05369104444980621, 0.05769248306751251, -0.011591032147407532, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.044453561305999756, -0.05252973735332489, 0.035728275775909424, 0.06194005906581879, 0.036992162466049194, 0.057376354932785034, -0.043053656816482544, 0.05265139043331146, -0.049041494727134705, 0.018571794033050537, -0.010026589035987854, 0.011633649468421936, -0.030856207013130188, -0.03214313089847565, -0.009155750274658203, -1.2068450450897217e-4, 0.03872282803058624, -0.05158029496669769, -0.05190703272819519, 0.030057981610298157, 0.028335705399513245, 0.03558604419231415, -0.01684112846851349, -0.0020744353532791138, -0.016270503401756287, -0.017514541745185852, -0.020688414573669434, -0.048657163977622986, 0.01983901858329773, -0.02100050449371338, 0.017134085297584534, 0.051143184304237366, -0.05603599548339844, -0.025614887475967407, -0.04567280411720276, -2.9605627059936523e-4, -0.010439634323120117, 0.0014750510454177856, -6.760656833648682e-5, 0.002023860812187195, 0.058323830366134644, 0.057505831122398376, -0.008760780096054077, -0.03130887448787689, -0.06186647713184357, -0.02674366533756256, -0.054854631423950195, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.032608285546302795, -0.006642535328865051, 0.05297805368900299, 0.051614969968795776, -0.04018060863018036, -0.048893630504608154, 0.05452093482017517, -0.01947285234928131, 0.012484252452850342, 0.02486307919025421, 0.04290860891342163, 0.03901347517967224, 0.03237451612949371, -0.04737253487110138, 0.013066917657852173, 0.02978375554084778, -0.006310060620307922, -0.03749847412109375, -0.01044514775276184, -0.02960534393787384, -0.0010516047477722168, 0.0018389075994491577, 0.02699996531009674, 0.06097060441970825, 0.0030221790075302124, 0.01736435294151306, 0.04897160828113556, 0.0030396729707717896, -0.04924766719341278, 0.0389840304851532, -0.01835605502128601, 1.5456974506378174e-4, -0.014370813965797424, -0.06103244423866272, 0.05079524219036102, 0.020794183015823364, 0.012638792395591736, -0.011560514569282532, 0.03376096487045288, -0.04083840548992157, -0.00391581654548645, 0.04412230849266052, 0.04251180589199066, 0.01743115484714508, 0.05216251313686371, -0.017336532473564148, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04812884330749512, -0.038934752345085144, 0.010806962847709656, 0.032767683267593384, 0.003754749894142151, 0.025674015283584595, -0.027109459042549133, 0.016650453209877014, 0.00821453332901001, 0.04841548204421997, 0.05328567326068878, 0.044812679290771484, 0.015143215656280518, 0.015147581696510315, 0.03984843194484711, 0.009698674082756042, 0.017130225896835327, 0.04548853635787964, 0.0295943021774292, -0.023897260427474976, -0.02607196569442749, -0.021637991070747375, -0.05602104961872101, -0.03560064733028412, 3.744959831237793e-4, -0.014491453766822815, 0.057820677757263184, 0.008085653185844421, -0.007160007953643799, 0.0012309104204177856, 0.00835922360420227, 0.04568144679069519, -0.0435107946395874, -0.052208930253982544, -0.04845665395259857, -0.018066316843032837, -0.0022654682397842407, -0.032691359519958496, 0.04065173864364624, -0.020436882972717285, -0.01853884756565094, -0.05772070586681366, -0.023399829864501953, -0.02452811598777771, -0.025088071823120117, ...],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.004091805312782526, -0.007560521364212036, 0.026906177401542664, 0.030246015638113022, 0.012771832756698132, -0.008682779036462307, -0.02972627431154251, 0.026836737990379333, -0.03631436452269554, 0.0018882290460169315, 0.006806651130318642, -0.037564776837825775, 0.037733763456344604, -0.03865443542599678, -0.0193281639367342, 0.02252115309238434, 0.015335503034293652, -0.03226977586746216, -0.015374001115560532, 0.03285115212202072, 0.02087392285466194, -0.008560998365283012, 0.029197804629802704, -0.012814251706004143, 0.017765413969755173, -0.037728700786828995, 0.035684000700712204, 0.014469661749899387, -0.021675923839211464, 0.022656315937638283, 0.035955410450696945, -0.018395589664578438, -0.036058731377124786, 0.03221835568547249, 0.0063669877126812935, -0.013371963985264301, -0.00865225400775671, -0.026589293032884598, 0.00307067041285336, 0.03417544811964035, -0.017079776152968407, 0.032005008310079575, -0.03716343268752098, -0.003272369969636202, -0.022651387378573418, -0.002475807210430503, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.017011485993862152, -0.02012261562049389, -0.0344015397131443, 0.03524329885840416, -0.01521710492670536, 0.02654646895825863, 0.017354531213641167, -0.017773989588022232, -0.017466548830270767, -0.011121360585093498, 0.03750540316104889, 0.026669740676879883, -0.00589218083769083, 0.013468384742736816, 0.01918012648820877, 0.0056791347451508045, 0.03339534252882004, 0.013178341090679169, -0.023255571722984314, 8.087482419796288e-4, 0.03004595637321472, 0.008573560975492, -0.0373084619641304, 0.0368356816470623, -0.022242382168769836, 0.02989266999065876, 0.01815837062895298, 0.001188341062515974, -0.01859203912317753, 0.024238387122750282, 0.004938825499266386, -0.01812833547592163, 0.014285039156675339, 0.002843261696398258, -0.008637165650725365, -0.006106555927544832, 0.02334817498922348, -0.003194204531610012, -0.0011421996168792248, -0.01179812103509903, 0.01505579799413681, 0.0239674374461174, -0.019915422424674034, -0.0017682573525235057, -0.02840084582567215, ...],
        ...
      ]
    >
  },
  "normalization_0" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_1" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  }
}
Axon.Display.as_graph(model, template)
key = Nx.Random.key(103)
{input, _new_key} = Nx.Random.normal(key, shape: {2, 4, 768})
result = predict_fn.(params, input)
#Nx.Tensor<
  f32[2][4][768]
  EXLA.Backend
  [
    [
      [-0.716768205165863, -0.4204854965209961, -1.8959286212921143, 0.37171489000320435, -0.20300930738449097, 2.724903106689453, -0.4133290648460388, -0.30961093306541443, 0.24687054753303528, 1.410226583480835, -1.4018677473068237, -1.758493423461914, -1.3413493633270264, -0.46528956294059753, 0.24200332164764404, 0.9852171540260315, -0.42596209049224854, 1.9925718307495117, 0.5250499248504639, 0.35458439588546753, 1.0019404888153076, 0.18823081254959106, -0.13523443043231964, -1.5909684896469116, 1.1885206699371338, 1.9854960441589355, 2.420884370803833, -2.179537534713745, 2.465571403503418, 0.9325289726257324, -1.3916890621185303, 1.860776662826538, 0.06037481874227524, 1.2924396991729736, 1.666346788406372, -0.714210033416748, 0.5654738545417786, 1.7855122089385986, -1.270380973815918, 0.11856108903884888, 2.654468297958374, 0.7143635749816895, 1.0451276302337646, 0.23774230480194092, -2.40054988861084, 1.6369502544403076, 0.9113302230834961, -1.1185961961746216, 0.11834041774272919, 1.1322273015975952, ...],
      ...
    ],
    ...
  ]
>

Transformer.Layers.block combines a multi-head attention mechanism and a feed-forward network, both configured from a keyword list of options such as gpt_config_124m.

Layer normalization (LayerNorm) is applied before each of these two components, and dropout is applied after them to regularize the model and prevent overfitting. This is also known as Pre-LayerNorm. Older architectures, such as the original transformer model, applied layer normalization after the self-attention and feed forward networks instead, known as Post-LayerNorm, which often leads to worse training dynamics.

The module also implements the forward pass, where each component is followed by a shortcut connection that adds the input of the block to its output. This critical feature helps gradients flow through the network during training and improves the learning of deep models.
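The ordering described above (normalize first, apply the sublayer, then add the shortcut) can be sketched in plain Elixir. This is only an illustration, not the notebook's Transformer.Layers.block: the module name `PreLNSketch` is hypothetical, the sublayers are passed in as ordinary functions, a single feature vector stands in for the tensor, and dropout and the trainable scale and shift are left out.

```elixir
defmodule PreLNSketch do
  @eps 1.0e-5

  # Normalize one feature vector to zero mean and (approximately) unit variance.
  def layer_norm(vec) do
    n = length(vec)
    mean = Enum.sum(vec) / n
    var = Enum.sum(Enum.map(vec, fn v -> (v - mean) * (v - mean) end)) / n
    Enum.map(vec, fn v -> (v - mean) / :math.sqrt(var + @eps) end)
  end

  # Pre-LayerNorm residual step: x + sublayer(layer_norm(x)).
  def residual(vec, sublayer) do
    out = vec |> layer_norm() |> sublayer.()
    Enum.zip_with(vec, out, fn a, b -> a + b end)
  end

  # One block: an attention step followed by a feed-forward step,
  # each wrapped in its own shortcut connection.
  def block(vec, attention, feed_forward) do
    vec
    |> residual(attention)
    |> residual(feed_forward)
  end
end

# Hypothetical stand-in for both sublayers: it just halves every value.
scale = fn vec -> Enum.map(vec, fn v -> v * 0.5 end) end

x = [1.0, 2.0, 3.0, 4.0]
y = PreLNSketch.block(x, scale, scale)

# The block keeps the feature size unchanged.
length(y) == length(x)
```

Because the shortcut adds the untouched input back in, a sublayer that outputs only zeros leaves the vector exactly as it was, which is one way to see why the shortcut gives gradients a direct path through the block.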

The transformer block maintains the input dimensions in its output, indicating that the transformer architecture processes sequences of data without altering their shape throughout the network.

The preservation of shape throughout the transformer block architecture is not incidental but a crucial aspect of its design. This design enables its effective application across a wide range of sequence-to-sequence tasks, where each output vector directly corresponds to an input vector, maintaining a one-to-one relationship. However, the output is a context vector that encapsulates information from the entire input sequence. This means that while the physical dimensions of the sequence (length and feature size) remain unchanged as it passes through the transformer block, the content of each output vector is re-encoded to integrate contextual information from across the entire input sequence.
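A small plain-Elixir sketch can make the "same shape, new content" point concrete. It is a simplification under stated assumptions: the hypothetical `ContextSketch.attend/1` is unmasked self-attention with no trainable weights, whereas the notebook's causal attention layers use learned projections and restrict each position to earlier tokens.

```elixir
defmodule ContextSketch do
  # Dot product of two equal-length vectors.
  def dot(a, b), do: a |> Enum.zip_with(b, fn x, y -> x * y end) |> Enum.sum()

  # Numerically stable softmax over a list of scores.
  def softmax(scores) do
    max = Enum.max(scores)
    exps = Enum.map(scores, fn s -> :math.exp(s - max) end)
    sum = Enum.sum(exps)
    Enum.map(exps, fn e -> e / sum end)
  end

  # Every output row is a weighted sum of all input rows, so the result
  # has the same number of rows and columns as the input.
  def attend(rows) do
    Enum.map(rows, fn query ->
      weights = rows |> Enum.map(fn row -> dot(query, row) end) |> softmax()

      rows
      |> Enum.zip(weights)
      |> Enum.map(fn {row, w} -> Enum.map(row, fn v -> v * w end) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end
end

input = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
output = ContextSketch.attend(input)

# Sequence length and feature size are unchanged.
{length(output), length(hd(output))} == {length(input), length(hd(input))}
```

Changing any single input row changes every output row in this sketch, since each output mixes information from the whole sequence.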

4.6 Coding the GPT model

The output from the final transformer block then goes through a final layer normalization step before reaching the linear output layer. This layer maps the transformer’s output to a high-dimensional space (in this case, 50,257 dimensions, corresponding to the model’s vocabulary size) to predict the next token in the sequence.
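As a rough illustration of what the output layer does, the sketch below projects one hidden vector to one score (logit) per vocabulary entry using a bias-free kernel. The module name `OutputHeadSketch` and the toy sizes are assumptions made for readability; the real layer uses emb_dim 768 and vocab_size 50257.

```elixir
defmodule OutputHeadSketch do
  # Multiply a hidden vector of length emb_dim by a kernel of shape
  # {emb_dim, vocab_size}, without a bias, giving one logit per vocabulary entry.
  def logits(hidden, kernel) do
    kernel
    |> Enum.zip(hidden)
    |> Enum.map(fn {row, h} -> Enum.map(row, fn w -> w * h end) end)
    |> Enum.zip_with(&Enum.sum/1)
  end
end

# Toy sizes standing in for emb_dim: 768 and vocab_size: 50257.
emb_dim = 3
vocab_size = 5

# Arbitrary illustrative weights, not trained values.
kernel = for i <- 1..emb_dim, do: for(j <- 1..vocab_size, do: i * 0.1 + j * 0.01)
hidden = [1.0, 0.0, -1.0]

logits = OutputHeadSketch.logits(hidden, kernel)

# One logit per vocabulary entry.
length(logits) == vocab_size

# Number of weights in the full-size, bias-free output kernel.
768 * 50_257
# => 38597376
```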

# Test stacking layers dynamically with a `for ... reduce` comprehension.
model =
  Axon.input("sequence", shape: {1, 2})

new_model =
  for _i <- 1..3, reduce: model do
    model_acc ->
      Axon.dense(model_acc, 6)
  end

template = Nx.template({1, 2}, :f32)
Axon.Display.as_graph(new_model, template)
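The `reduce:` option used in the cell above threads an accumulator through each iteration, which is what lets the same layer constructor be applied repeatedly to a growing model. A dependency-free sketch of the same pattern, with a list of hypothetical layer names standing in for the Axon graph:

```elixir
# Start from an accumulator and extend it once per iteration.
layers =
  for i <- 1..3, reduce: ["input"] do
    acc -> acc ++ ["layer_#{i}"]
  end

layers
# => ["input", "layer_1", "layer_2", "layer_3"]
```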
defmodule MyGPT do
  # Capture the notebook-level config at compile time as the default options.
  @gpt_config_124m gpt_config_124m

  # Full GPT graph: embeddings, dropout, transformer blocks, final norm, output head.
  def model(input_shape \\ {2, 4, 768}, opts \\ @gpt_config_124m) do
    Axon.input("sequence", shape: input_shape)
    |> embedding_block(opts)
    |> Axon.dropout(rate: opts[:drop_rate])
    |> transformer_blocks(opts[:n_layers], opts)
    |> Transformer.Layers.normalization(opts)
    # Project to vocabulary-sized logits; the output head has no bias.
    |> Axon.dense(opts[:vocab_size], use_bias: false)
  end

  # Sum the token embeddings and the positional embeddings.
  def embedding_block(input, opts) do
    token_emb = Axon.embedding(input, opts[:vocab_size], opts[:emb_dim])
    pos_emb = Transformer.Layers.pos_embedding(input, opts[:context_length], opts[:emb_dim])

    Axon.add(token_emb, pos_emb)
  end

  # Stack n_blocks transformer blocks by threading the graph through reduce.
  def transformer_blocks(input, n_blocks, transformer_opts) do
    for _n_block <- 1..n_blocks, reduce: input do
      model_acc ->
        Transformer.Layers.block(model_acc, transformer_opts)
    end
  end
end
{:module, MyGPT, <<70, 79, 82, 49, 0, 0, 15, ...>>, {:transformer_blocks, 3}}
batch
#Nx.Tensor<
  s64[2][4]
  [
    [6109, 3629, 6100, 345],
    [6109, 1110, 6622, 257]
  ]
>
template = Nx.template({1, 4}, :s64)
model = MyGPT.model()
#Axon<
  inputs: %{"sequence" => {2, 4, 768}}
  outputs: "dense_24"
  nodes: 152
>
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
params = init_fn.(template, %{})
%{
  "pos_embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[1024][768]
      EXLA.Backend
      [
        [0.00820946879684925, 0.009188241325318813, -0.0034671758767217398, -0.007448150776326656, -0.007996559143066406, -0.00552819250151515, -0.008458909578621387, 0.0015659808414056897, -0.009515397250652313, -0.0023298191372305155, -0.0063433838076889515, 0.007155895233154297, -0.007909640669822693, 0.009121849201619625, -0.0038319635204970837, 0.00977078452706337, 0.0013030886184424162, -9.236144833266735e-4, -0.0017825793474912643, 0.008658253587782383, 0.004096589051187038, -0.009104898199439049, 0.00883461907505989, 2.1013259538449347e-4, -0.0039931414648890495, 0.0012662267545238137, -0.00589345907792449, -1.7762422794476151e-4, 0.007725028786808252, 0.00836909469217062, -0.0021765350829809904, 0.0031562065705657005, 0.0010756277479231358, -0.0020445871632546186, -8.003711627679877e-6, 0.008890127763152122, 0.007651355117559433, 0.0061145662330091, -0.004479820840060711, 0.00866932887583971, -0.008535603992640972, 0.005303759593516588, -0.0029913424514234066, -0.008174783550202847, 0.006251899991184473, 0.008821689523756504, -0.00166055909357965, -0.005631952080875635, ...],
        ...
      ]
    >
  },
  "normalization_6" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_4" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.007951725274324417, -0.008591024205088615, 0.03084629774093628, 0.03919801861047745, -0.006465104408562183, 0.03587324172258377, -0.029273860156536102, 0.0013427963713184, 2.444763667881489e-4, -0.024390175938606262, -0.034954559057950974, -0.0323372446000576, 0.0014753305586054921, 0.007874681614339352, -0.031473711133003235, -0.030932558700442314, -0.037745844572782516, 0.01325751468539238, -0.0017783696530386806, 0.027555551379919052, -0.025321099907159805, -0.038646772503852844, 0.005787608679383993, 0.0187541376799345, 0.02014927752315998, -0.03415210545063019, -0.013058982789516449, 0.0022036044392734766, 5.003655678592622e-4, -0.0221744142472744, -0.002541174413636327, -0.03399861976504326, 0.016714544966816902, -0.011524014174938202, -0.012131176888942719, -0.038383882492780685, 0.016222171485424042, 0.03854641318321228, -0.03233107179403305, -0.036432698369026184, -0.03398247808218002, -0.025048643350601196, 0.02434637024998665, 0.03237614035606384, -0.004860387183725834, ...],
        ...
      ]
    >
  },
  "normalization_5" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_20" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.008643423207104206, -0.03808677941560745, 0.03856589272618294, 0.0289661455899477, -0.039159614592790604, 0.009961207397282124, -0.03921467438340187, 0.00720088928937912, 0.023042816668748856, 0.0010641661938279867, -0.019540956243872643, -0.039042893797159195, -0.0366806797683239, 0.0011096857488155365, 0.003740136744454503, -0.030178753659129143, -0.028409432619810104, -0.014573565684258938, -0.03134012222290039, 0.01033230870962143, -0.038222573697566986, -0.02392405830323696, -0.029808038845658302, 0.02350245974957943, 0.004543343558907509, -0.020840145647525787, -0.0012748375302180648, 0.036794018000364304, -0.010731796734035015, -0.03439362347126007, -0.03502519428730011, 0.0040095215663313866, -0.02168966457247734, -0.026715749874711037, 0.004904756788164377, -0.028651166707277298, 0.011725619435310364, 0.0063535673543810844, 0.005658721551299095, -0.0079327542334795, 0.016678035259246826, -0.012191633693873882, -0.0018477326957508922, ...],
        ...
      ]
    >
  },
  "normalization_1" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_9" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_22" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.005695730913430452, 0.01940457709133625, -0.016799259930849075, -0.009635455906391144, 0.004453614354133606, 0.02737174741923809, 0.007082218304276466, 0.0016170063754543662, 0.03210758417844772, 0.028956137597560883, -0.02683628723025322, -0.014833912253379822, -0.0140641238540411, 0.01708773896098137, -0.03735227510333061, -0.028188431635499, 0.014741497114300728, -0.027338745072484016, 0.007070861756801605, 0.03213026002049446, -0.030906114727258682, -0.034542836248874664, -0.037251945585012436, -0.0022951713763177395, -0.03173783794045448, 0.0034601683728396893, 0.0025511831045150757, 0.023164816200733185, -0.0019491383573040366, -0.02549983188509941, -0.016073890030384064, 0.026647498831152916, 0.0013607025612145662, 0.039376188069581985, 0.025203324854373932, 0.014966342598199844, 0.036406517028808594, -0.02063322626054287, -0.02118971385061741, 0.03581114485859871, ...],
        ...
      ]
    >
  },
  "causal_attention_0" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05489400029182434, -0.039155200123786926, -0.009170278906822205, -0.021611839532852173, -0.01464468240737915, 0.021818622946739197, 0.05251871049404144, -0.05430804193019867, 0.0021059811115264893, -0.017595037817955017, -0.05499380826950073, 0.033547788858413696, -0.044628411531448364, -0.0522761195898056, 0.03932121396064758, -0.01695135235786438, -0.020284101366996765, -0.0028049200773239136, -0.013862058520317078, 0.0012287497520446777, 0.05986718833446503, -0.010002121329307556, 0.029883012175559998, 0.03109125792980194, -0.029560640454292297, -0.04171241819858551, -0.019969791173934937, 0.024976268410682678, -0.005772814154624939, 0.05854707956314087, 0.04195404052734375, 0.039050742983818054, 0.05545467138290405, -0.05645453929901123, 0.01080729067325592, -0.05919128656387329, 0.06138846278190613, 0.04838235676288605, 0.020340412855148315, -0.013968273997306824, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04687102138996124, -0.03536757826805115, -0.007010817527770996, 0.042143821716308594, 0.06117507815361023, 0.009219691157341003, 0.050558462738990784, 0.007298022508621216, -9.268522262573242e-4, -0.003967806696891785, 0.015224486589431763, 0.03371243178844452, 0.05220791697502136, -0.03269970417022705, 0.05720898509025574, 0.04330742359161377, 0.0027159452438354492, -0.007316946983337402, -0.036986514925956726, 0.03676225244998932, 0.0434393435716629, 0.032708942890167236, -0.019815340638160706, -0.010256662964820862, -0.027735784649848938, -0.0401027649641037, 0.021295860409736633, 0.05721485614776611, 0.013380691409111023, -0.04448729753494263, 0.031398698687553406, 0.019680440425872803, 0.014919206500053406, 0.05950406193733215, -0.05872026085853577, 0.03292544186115265, 0.015416577458381653, 0.052167996764183044, 0.006425082683563232, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.022899866104125977, -0.033817097544670105, 0.059301987290382385, -0.04439030587673187, 0.012292608618736267, 0.039279863238334656, -0.05003109574317932, -0.02336028218269348, 0.04098793864250183, 0.03009355068206787, 0.012847885489463806, 0.042464762926101685, 9.96425747871399e-4, -0.05092477798461914, -0.045868054032325745, -0.012940824031829834, -0.03968289494514465, -0.036278724670410156, -0.03340412676334381, 0.024561360478401184, -0.013194680213928223, 0.03875914216041565, 0.059637561440467834, -0.03777608275413513, -0.017550170421600342, -0.055209681391716, 0.021982118487358093, 0.036332473158836365, 0.020841673016548157, 0.0036048144102096558, -0.026743724942207336, 0.029593050479888916, -0.03466331958770752, 0.05739566683769226, 0.03717690706253052, 0.05089336633682251, -0.002735033631324768, 0.0447123646736145, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.02619999647140503, 0.048979252576828, 3.127157688140869e-4, -0.028897106647491455, 0.03259924054145813, -0.038410648703575134, 0.029967039823532104, 5.020499229431152e-4, 0.05025339126586914, 0.001332402229309082, 0.014239713549613953, -0.06118769943714142, -0.005514770746231079, 0.06032727658748627, 0.014304712414741516, -0.029770493507385254, -0.03880971670150757, 0.06235247850418091, 0.054254576563835144, -0.042912185192108154, -0.042275190353393555, 0.04345338046550751, 0.04867301881313324, -0.014384850859642029, 0.006365090608596802, -0.039435938000679016, -0.0261828750371933, 0.00685197114944458, -0.014525637030601501, 0.025676921010017395, 0.05117411911487579, -0.03469167649745941, -0.026622265577316284, -0.04466487467288971, 0.04095759987831116, -0.05895838141441345, -0.024648264050483704, ...],
        ...
      ]
    >
  },
  "normalization_14" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_7" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_11" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_2" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_12" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.030212607234716415, 0.01017819344997406, 0.023311730474233627, 0.0391693115234375, -0.009699258022010326, -0.0331733338534832, 0.024538900703191757, 0.007583545055240393, 0.01462976261973381, -0.014602328650653362, -0.0029376179445534945, -0.007452735677361488, -0.013505741953849792, 0.027952220290899277, -0.009736993350088596, 0.03601047024130821, 0.023169782012701035, 0.01591755822300911, -0.015054394491016865, 0.015612625516951084, -0.03216945379972458, 0.02113960310816765, -0.006786813028156757, -0.03232434391975403, -0.008756335824728012, -0.007568060886114836, 0.03484826907515526, -0.0016728172777220607, 0.03853771463036537, 0.017834974452853203, 0.03739023581147194, 0.006592803634703159, -0.012812159024178982, -0.026464100927114487, ...],
        ...
      ]
    >
  },
  "dense_23" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.020444191992282867, -0.027403226122260094, 0.009465120732784271, 0.011801626533269882, -0.03225001320242882, -0.020963340997695923, -0.03514442965388298, -0.007892417721450329, 0.02807697094976902, -0.01127086766064167, 0.028419196605682373, 0.021006353199481964, -0.005180484149605036, 0.004865542054176331, -0.03541543707251549, 0.02198111079633236, 0.016260679811239243, 0.01586957834661007, 0.033852458000183105, -0.013914522714912891, -0.0057561215944588184, 0.017434582114219666, -0.03442095220088959, 0.03288452327251434, 0.03143324330449104, 0.0015610731206834316, 0.03612896054983139, 0.03294341638684273, 0.006717195268720388, 0.02621821127831936, -0.03397725522518158, 0.028073681518435478, -0.03619467839598656, ...],
        ...
      ]
    >
  },
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[50257][768]
      EXLA.Backend
      [
        [-0.0016090749995782971, -0.003907220438122749, -0.00360292661935091, -0.00231545208953321, -0.005603322759270668, 0.0032405112870037556, -0.0034749649930745363, -0.008328196592628956, -0.008334300480782986, 0.002380151767283678, 0.0019280958222225308, -0.006717686541378498, -0.009055950678884983, 0.008918508887290955, -0.007226414512842894, -0.004065058194100857, 0.009506313130259514, -0.008293678984045982, 8.306717500090599e-4, 0.006336223799735308, 0.004062929190695286, 0.004070937633514404, -2.814316831063479e-4, 9.565543732605875e-4, -0.00485666748136282, 0.001047842437401414, -0.00563220027834177, 0.003163895569741726, -0.007921133190393448, -0.008668815717101097, -0.00760646304115653, -0.006582675036042929, -0.004447436425834894, ...],
        ...
      ]
    >
  },
  "normalization_19" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_6" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.025701820850372314, -0.054504990577697754, 0.029997289180755615, -0.03686901926994324, 0.016721084713935852, -0.03375677764415741, 0.059069350361824036, 7.250607013702393e-4, -0.02984979748725891, 4.918873310089111e-4, -7.946044206619263e-4, 0.056285277009010315, 0.010786309838294983, 0.05521267652511597, 0.04236045479774475, -0.016745954751968384, -0.058286041021347046, -0.05879490077495575, 0.023886770009994507, -0.016313806176185608, 0.017728060483932495, 0.0613899827003479, 0.0061877816915512085, 0.005800962448120117, 0.01997123658657074, -0.03392757475376129, 0.02042210102081299, 0.019898176193237305, -0.048953741788864136, 0.030590444803237915, 0.01393859088420868, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.010096222162246704, 0.0012035369873046875, 0.017482370138168335, -0.054907336831092834, -0.034396275877952576, -0.031409040093421936, 0.03767158091068268, 0.00809420645236969, 0.047381192445755005, 0.02798573672771454, -0.040443405508995056, -0.05414016544818878, 0.005056798458099365, 0.03312613070011139, -0.030661195516586304, 0.007245868444442749, 0.0082969069480896, 0.018908783793449402, -0.021926403045654297, -0.0109795480966568, 0.031210750341415405, -0.033631905913352966, 0.019452661275863647, 0.009734466671943665, 0.04139360785484314, 0.008643284440040588, 0.03843680024147034, 0.030701294541358948, 0.027932196855545044, 0.054621800780296326, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.03707754611968994, -0.030656009912490845, 0.03520679473876953, -0.04625687003135681, -0.0189407616853714, 0.02809925377368927, 0.03452543914318085, -0.04839430749416351, 0.05323837697505951, 0.0021147727966308594, 0.02971041202545166, -0.025895893573760986, 0.027443334460258484, -0.03392206132411957, -0.04972764849662781, -0.0307915061712265, 0.04374489188194275, 0.01891949772834778, -0.00454828143119812, -0.04370458424091339, 0.012373551726341248, -0.01682738959789276, 0.005775421857833862, -0.0018092989921569824, -0.05827032029628754, -0.013664931058883667, 0.0270087867975235, -0.034229978919029236, -0.0029068589210510254, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.052588433027267456, 0.05754421651363373, -0.042503178119659424, -0.04482130706310272, 0.029898464679718018, -0.048212021589279175, 0.05257728695869446, 0.0042043328285217285, 0.04159720242023468, 0.062173157930374146, -0.025318071246147156, 0.024036988615989685, -0.051172688603401184, -0.03796112537384033, -0.016400858759880066, -0.03468772768974304, -0.045705899596214294, 0.0405057817697525, 0.059305816888809204, -0.04508274793624878, -0.05889318883419037, 0.004644140601158142, -0.06178523600101471, -0.05705609917640686, 0.0390387624502182, -0.03353504836559296, -0.026936382055282593, 0.047242820262908936, ...],
        ...
      ]
    >
  },
  "causal_attention_7" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.032045722007751465, 0.017456412315368652, -0.0242689847946167, 0.04675541818141937, 0.029787346720695496, -0.03577202558517456, -0.0022034794092178345, 0.03651021420955658, 0.02027672529220581, -0.02710777521133423, -0.00933249294757843, -0.03069135546684265, -0.057961657643318176, 0.02531592547893524, 0.022363126277923584, 0.04881356656551361, 0.04138445854187012, 0.010208293795585632, -0.008330345153808594, -0.041544973850250244, 0.015900418162345886, 0.006786465644836426, -0.020606979727745056, 0.0022400468587875366, 0.02052748203277588, -0.011774599552154541, -0.025713816285133362, -0.023647218942642212, 0.011612921953201294, -0.026715517044067383, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05895945429801941, -0.024802401661872864, 0.05039659142494202, 0.036231085658073425, 0.022191762924194336, -0.02228279411792755, -0.03547416627407074, -0.049903035163879395, 0.04348662495613098, -0.03179320693016052, -0.05164504051208496, 0.0015797168016433716, 0.06085902452468872, 0.05436211824417114, -0.022526323795318604, -0.038063883781433105, 0.027981683611869812, -0.010301768779754639, 0.022196829319000244, -0.0431160032749176, 0.024083688855171204, 0.04146072268486023, -0.009297430515289307, 0.026913568377494812, 0.040827974677085876, -0.01640072464942932, -0.027775421738624573, 0.01280316710472107, -0.015676349401474, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.003850281238555908, 0.030075520277023315, -0.0247572660446167, -0.03283834457397461, -0.0576925128698349, 0.007954820990562439, 0.05890782177448273, -0.01030498743057251, 0.025150269269943237, -0.04736118018627167, 0.009444251656532288, -0.026661813259124756, 0.03734239935874939, -0.03330172598361969, -0.012048572301864624, 0.01416216790676117, 0.0496334433555603, 0.009033679962158203, 0.04660935699939728, -0.027820080518722534, 0.03103475272655487, 0.018427371978759766, -0.02513173222541809, 0.031223684549331665, -0.035366833209991455, -3.833472728729248e-4, -0.005494505167007446, 0.022209420800209045, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.02350783348083496, -0.02003483474254608, 0.04679706692695618, -0.04162248969078064, 0.03920908272266388, 0.006552442908287048, -0.04930463433265686, 0.024353966116905212, -0.024234965443611145, -0.016075626015663147, 0.017792314291000366, -0.040183261036872864, -0.05149658024311066, 0.03500549495220184, -0.0168725848197937, -0.05759905278682709, -0.036677464842796326, 0.039839357137680054, 0.056501612067222595, 0.01898530125617981, -0.015023767948150635, -0.008932814002037048, 0.04588139057159424, -0.05752365291118622, -0.047592490911483765, 0.012481942772865295, -0.05152970552444458, ...],
        ...
      ]
    >
  },
  "causal_attention_1" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.012421980500221252, -0.04269014298915863, -0.04550692439079285, -0.016208991408348083, 0.020465701818466187, 0.05221274495124817, 0.005325481295585632, -0.057037726044654846, 0.036669909954071045, 0.008155956864356995, -0.04294180870056152, 0.02440929412841797, 0.011539056897163391, -0.05979380011558533, -0.01832081377506256, 0.016922026872634888, -0.034844666719436646, 0.03192244470119476, -0.0049813538789749146, -0.04656209051609039, -0.03621731698513031, 0.05114838480949402, -0.023377642035484314, 0.041796475648880005, -0.04053215682506561, -0.017181262373924255, -0.03711695969104767, 0.03578770160675049, -0.04483543336391449, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.027674600481987, -0.057753339409828186, 0.05609412491321564, 0.04927493631839752, 0.037359803915023804, -0.0138387531042099, 0.03773191571235657, 0.04090070724487305, -0.03581881523132324, 0.04968537390232086, 0.03941449522972107, 0.008177310228347778, -0.04164418578147888, -0.017723917961120605, -0.032029956579208374, -0.034827858209609985, 0.015351071953773499, 0.05072504281997681, 0.012002795934677124, -0.04956841468811035, -0.05421432852745056, -0.007948070764541626, 0.057788386940956116, -0.01360715925693512, 0.04243507981300354, -0.038228392601013184, -0.02971363067626953, -0.055006951093673706, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.02957746386528015, 0.03289474546909332, 0.05949847400188446, 0.010120987892150879, -0.01310354471206665, 0.004510313272476196, -0.015955954790115356, 0.06145139038562775, -0.014455825090408325, -0.05895638465881348, -0.026723697781562805, -0.04537688195705414, -0.04693858325481415, 0.007102593779563904, 0.012268990278244019, 0.02478422224521637, -0.0333857387304306, 0.0011792182922363281, -0.051901236176490784, 0.05027095973491669, -0.040482714772224426, 0.012476295232772827, 0.021930694580078125, -0.026617586612701416, 0.043517738580703735, -0.010764479637145996, 0.057726114988327026, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.019613072276115417, 0.051455721259117126, 0.03897149860858917, 0.007469534873962402, -0.0030021369457244873, 0.023965954780578613, -0.04122491180896759, 0.003679707646369934, 0.03143559396266937, 0.05365791916847229, -0.04962342977523804, 0.05203866958618164, -0.04326619207859039, -0.04479421675205231, 0.014022588729858398, -0.03217512369155884, 0.022234678268432617, -0.0034312456846237183, 0.0030553340911865234, -0.025083303451538086, -0.010746285319328308, 0.017215639352798462, 0.04503507912158966, -0.005856752395629883, -0.014508351683616638, 0.044021353125572205, ...],
        ...
      ]
    >
  },
  "dense_15" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.005395660176873207, 0.011839776299893856, -0.02832644246518612, 0.01453049574047327, -0.03146753087639809, -0.02135009691119194, 0.01870754361152649, -8.676031575305387e-5, -0.010343081317842007, 0.01302742026746273, 0.02643774077296257, 0.006407936103641987, -0.037305522710084915, -0.019868772476911545, -0.03678732365369797, -0.0359368734061718, 6.109505775384605e-4, 0.015419756062328815, 0.006048841401934624, 0.0058729080483317375, -0.02619294449687004, 0.012516112998127937, 0.013984186574816704, 0.035320863127708435, -0.010913054458796978, -0.006062120199203491, 0.015807265415787697, ...],
        ...
      ]
    >
  },
  "dense_9" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.012368207797408104, -0.03800829499959946, -0.0035753806587308645, 0.0033571417443454266, -0.014881703071296215, -0.009259303100407124, 0.01123113464564085, 0.018255770206451416, -0.03940468654036522, 0.017897212877869606, -0.003306599101051688, -0.0011907254811376333, -0.02017735317349434, -0.03826594352722168, 0.01460936851799488, -0.01519190426915884, 0.03060215152800083, -0.0058674984611570835, 0.03689923137426376, 0.0227354709059, 0.0329238623380661, 0.015255508944392204, 0.0027411491610109806, -0.032676443457603455, -0.01861805096268654, -0.03613306209445, ...],
        ...
      ]
    >
  },
  "normalization_8" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_10" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.022798985242843628, -0.045337170362472534, 0.025168761610984802, -0.04957088828086853, -8.123666048049927e-4, 0.051955416798591614, 0.05026403069496155, -0.0477260947227478, -0.03661969304084778, -0.032279565930366516, 0.02437390387058258, 0.026948392391204834, -0.023699268698692322, -0.022593483328819275, 0.03326416015625, 0.023770049214363098, 0.028966262936592102, 0.06071507930755615, 0.04856109619140625, 0.030310213565826416, 0.0391174852848053, 0.03669564425945282, 0.02858544886112213, -0.013668075203895569, 0.04215233027935028, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.03181988000869751, -0.04755681753158569, 0.03437581658363342, 0.053778424859046936, 0.05072203278541565, -0.04093801975250244, 0.058276742696762085, 0.02703551948070526, -0.02373218536376953, 0.05170321464538574, -0.020467787981033325, -0.012876242399215698, 0.038928136229515076, 0.028580963611602783, -0.035283222794532776, 0.03956948220729828, 0.020930737257003784, -0.0026922523975372314, 0.044234707951545715, 0.05327966809272766, -0.029353812336921692, 0.03557972609996796, 0.05551312863826752, -0.007851481437683105, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.0056618452072143555, 0.03151094913482666, 0.03827774524688721, 0.046215325593948364, 0.05178867280483246, 0.027159005403518677, 0.05923083424568176, -0.014196738600730896, -0.0031314492225646973, -0.04665765166282654, 0.042712077498435974, 0.03776612877845764, 0.030311763286590576, -0.053914621472358704, -0.050126731395721436, -0.0330914705991745, 0.003892362117767334, -0.004208564758300781, -0.056755974888801575, 0.018028929829597473, 0.014910101890563965, 0.050630778074264526, 0.0035226047039031982, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.002143353223800659, 0.06062677502632141, -0.006879165768623352, -0.0480877161026001, 0.039068520069122314, 0.011195480823516846, -0.00621400773525238, 0.028505608439445496, -0.01588299870491028, 0.0023775845766067505, -0.011070683598518372, 0.027731314301490784, -0.027953684329986572, 0.05115070939064026, -0.059734269976615906, -0.04799060523509979, 0.04361407458782196, 0.012062788009643555, 0.027975618839263916, -0.0035041719675064087, -0.04900294542312622, 0.05563957989215851, ...],
        ...
      ]
    >
  },
  "causal_attention_8" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.025582104921340942, 0.023829326033592224, 0.00909799337387085, -0.011847391724586487, -0.020098283886909485, 0.025455951690673828, -0.034765854477882385, -0.060432299971580505, -0.04105794429779053, 0.006512001156806946, -0.043706014752388, -0.05092586576938629, 0.023198038339614868, 0.02756880223751068, -0.033042341470718384, 0.015856102108955383, 0.008182674646377563, 0.03012464940547943, -0.03707456588745117, -0.01469673216342926, -0.03442901372909546, -0.0204780250787735, 0.014245688915252686, -0.01791292428970337, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05954828858375549, -0.018334031105041504, -0.053740084171295166, -0.04106515645980835, 0.04359300434589386, -0.052301377058029175, -0.05622832477092743, -0.011903241276741028, -0.020184949040412903, 0.030262917280197144, 0.016308650374412537, -0.041071027517318726, -0.009465128183364868, 0.043169617652893066, 0.052690088748931885, -0.02101065218448639, -0.02739103138446808, 0.033964574337005615, 0.011630713939666748, -0.012691468000411987, -0.026288464665412903, 0.06245419383049011, -0.06054665148258209, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05136667191982269, 0.037590816617012024, -0.04500292241573334, -0.034925028681755066, 0.008359581232070923, -0.011601179838180542, 0.046158358454704285, -0.03811800479888916, -0.03293466567993164, 0.047059565782547, 0.0218227356672287, -0.009044930338859558, 0.05121052265167236, -0.05056823790073395, 0.03798666596412659, 0.03153999149799347, 0.026914402842521667, 0.03440725803375244, -0.005188554525375366, 0.04151977598667145, -0.049881115555763245, 0.03422388434410095, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.02544592320919037, -0.03634065389633179, -0.018084004521369934, -0.04243233799934387, 0.04002797603607178, 2.098381519317627e-4, -2.989470958709717e-4, -0.057203203439712524, -0.03128752112388611, 0.038946688175201416, 0.060461074113845825, 0.035974353551864624, 0.003053322434425354, -0.0013038963079452515, -0.013969793915748596, -0.028209075331687927, -0.008393153548240662, 0.061665311455726624, 0.030390232801437378, 0.019174695014953613, 0.024785682559013367, ...],
        ...
      ]
    >
  },
  "normalization_12" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.0366412028670311, -0.03651485964655876, 0.02650594525039196, -0.016402289271354675, -0.025654854252934456, 0.015297504141926765, -0.019247323274612427, 0.012388564646244049, -0.0240774005651474, -0.01869748905301094, -0.005119819659739733, 0.02205471508204937, 0.012353327125310898, 0.03934800624847412, 0.03082224726676941, 0.029741965234279633, -0.010925635695457458, 0.03131837025284767, 0.03803667798638344, 0.01707005873322487, 0.029323609545826912, ...],
        ...
      ]
    >
  },
  "dense_18" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.022697916254401207, -0.01417359709739685, -0.03336319699883461, 0.015115143731236458, 0.015952371060848236, -0.01764185167849064, -0.021568890661001205, -0.01160652469843626, 0.007532644551247358, 0.013523111119866371, 0.018907152116298676, -0.023324472829699516, -0.025158587843179703, 0.036196835339069366, 0.02822321653366089, 0.0060785748064517975, -0.03328672796487808, 0.020602766424417496, -0.03298238664865494, -0.030606817454099655, ...],
        ...
      ]
    >
  },
  "dense_5" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.025756023824214935, -0.03596629574894905, -0.02284686639904976, 0.019689587876200676, 0.023602772504091263, 0.029920782893896103, -0.03038092516362667, -0.010648720897734165, -0.01590622030198574, 0.011320646852254868, 0.0022809687070548534, -0.02498069405555725, 0.022690394893288612, -0.03188413381576538, 0.03524188697338104, -0.01881878823041916, -0.03762465715408325, 0.03225862607359886, 0.02837536297738552, ...],
        ...
      ]
    >
  },
  "dense_11" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.0037141256034374237, 0.034073393791913986, 0.013978814706206322, 0.0030279310885816813, -0.010822675190865993, -0.009056679904460907, 0.007703478913754225, 0.02177872322499752, -0.02680893801152706, 0.010488017462193966, 0.028833188116550446, -0.03287756070494652, -0.03475680947303772, 0.008495103567838669, 0.007366380654275417, 0.017278751358389854, -0.007913537323474884, -0.015227839350700378, ...],
        ...
      ]
    >
  },
  "normalization_18" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_6" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.02989180199801922, 0.030130190774798393, -0.022942213341593742, 0.03949263319373131, -0.030400875955820084, -0.019850393757224083, 0.030135562643408775, -0.031825728714466095, -0.03247269243001938, 0.026158630847930908, -0.0020208009518682957, 0.0016321230214089155, -0.009290647692978382, -0.004441353492438793, -0.004950558766722679, 0.023847494274377823, ...],
        ...
      ]
    >
  },
  "normalization_15" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_23" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_10" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.0249734278768301, 0.0143804419785738, 0.009133638814091682, -0.013027703389525414, 0.02458631433546543, 0.026299260556697845, -1.9839141168631613e-4, -0.011157785542309284, -0.0375479981303215, -0.033345937728881836, 0.004981668666005135, 0.03158210217952728, -0.02271103486418724, ...],
        ...
      ]
    >
  },
  "dense_13" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.013052083551883698, -0.03876694291830063, -0.03504530340433121, 0.022804325446486473, -0.03920937702059746, -0.007452443242073059, 0.013561619445681572, -0.03840300440788269, 0.023222096264362335, -0.02473028004169464, 0.017051653936505318, 0.01958180032670498, ...],
        ...
      ]
    >
  },
  "dense_21" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.014589134603738785, 0.024606434628367424, 0.03452145308256149, -0.020177079364657402, -0.030781080946326256, -0.0051908413879573345, -0.022325005382299423, -0.02352965995669365, 0.021398339420557022, 0.007469284813851118, -0.031778354197740555, ...],
        ...
      ]
    >
  },
  "dense_8" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.03523549810051918, 0.022907031700015068, 0.008606254123151302, 0.033233392983675, -0.0037005168851464987, -0.016019266098737717, 0.03216709941625595, -0.03141503781080246, -0.0033981846645474434, -0.019669579342007637, ...],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.033898260444402695, 0.03424730896949768, -0.003590261796489358, -0.0370299257338047, -0.03892354667186737, -0.012673800811171532, -0.015667568892240524, -0.027449103072285652, 0.02314266748726368, ...],
        ...
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.021719539538025856, -0.012648873962461948, 0.018876560032367706, 0.024806635454297066, 0.019822290167212486, 0.01990167237818241, 0.017527807503938675, 0.00470435805618763, ...],
        ...
      ]
    >
  },
  "causal_attention_3" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.03307117521762848, 0.052388325333595276, -0.052469655871391296, 0.04993458092212677, -0.008502721786499023, 0.04863150417804718, -0.04397425055503845, 0.026126205921173096, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [3.547370433807373e-4, 0.012353941798210144, 0.020879730582237244, 2.934783697128296e-4, -0.03231169283390045, 0.04469521343708038, -0.01846630871295929, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.00862903892993927, 0.047613129019737244, 0.01889754831790924, -0.043374285101890564, -0.0512593537569046, 0.01298445463180542, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0029682815074920654, -0.031321972608566284, -0.003879234194755554, 0.05914776027202606, -0.035743147134780884, ...],
        ...
      ]
    >
  },
  "normalization_21" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_5" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.027845636010169983, 0.04923565685749054, 0.013799235224723816, -0.004870876669883728, 0.05626256763935089, 0.045031070709228516, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.034384652972221375, -0.044549256563186646, 3.282874822616577e-4, 0.01728612184524536, -0.03978864848613739, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.008495509624481201, -0.025550737977027893, -0.012447208166122437, -0.015201836824417114, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.006098702549934387, -0.058861151337623596, -0.05091048777103424, ...],
        ...
      ]
    >
  },
  "dense_24" => %{
    "kernel" => #Nx.Tensor<
      f32[768][50257]
      EXLA.Backend
      [
        [0.0011196397244930267, -0.006019069813191891, -0.0032777772285044193, 0.0026513785123825073, -0.010327205993235111, ...],
        ...
      ]
    >
  },
  "dense_14" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.002042062347754836, 0.007567929103970528, -0.02847537398338318, ...],
        ...
      ]
    >
  },
  "causal_attention_4" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.02169060707092285, 0.025045320391654968, 0.03465403616428375, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0044701844453811646, 0.012541219592094421, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.02636946737766266, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_10" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >
  },
  "dense_7" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_22" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [...]
    >,
    ...
  },
  "causal_attention_2" => %{...},
  ...
}
logit = predict_fn.(params, batch)
#Nx.Tensor<
  f32[2][4][50257]
  EXLA.Backend
  [
    [
      [0.2803703248500824, -0.07315270602703094, 0.3577398955821991, -0.07260004431009293, 0.1903536319732666, 0.07483126223087311, 0.2513326406478882, -0.11101536452770233, 0.17445151507854462, 0.10261347889900208, -0.12078019231557846, -0.049824073910713196, 0.18545812368392944, 0.013964533805847168, -0.13906230032444, 0.0896649956703186, 0.03434973955154419, -0.24493391811847687, 0.4062414765357971, 0.0967893898487091, 0.44506508111953735, 0.3618185222148895, 0.4185892939567566, -0.1302228718996048, 0.3377799987792969, 1.1286139488220215e-4, -0.1995621919631958, 0.10705672204494476, -0.06998956948518753, 0.43495500087738037, 0.06465739756822586, -0.044191308319568634, 0.042275916785001755, 0.34331750869750977, -0.015286875888705254, 0.3468380570411682, 0.007815178483724594, -0.2892349660396576, 0.10641936957836151, 0.15051062405109406, 0.27116501331329346, -0.07599429786205292, 0.16037455201148987, 0.023972883820533752, -0.22748763859272003, -0.011132504791021347, -0.06328065693378448, -0.09093485027551651, -0.1781986951828003, 0.19906549155712128, ...],
      ...
    ],
    ...
  ]
>

The forward pass takes a batch of input token indices, computes their embeddings, applies the positional embeddings, passes the sequence through the transformer blocks, normalizes the final output, and then computes the logits, which represent the next token’s unnormalized probabilities.

Weight tying reduces the overall memory footprint and computational complexity of the model. However, in my experience, using separate token embedding and output layers results in better training and model performance; hence, we use separate layers in our GPTModel implementation.
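To make the memory argument concrete, here is a small back-of-the-envelope calculation. It assumes the dimensions from `gpt_config_124m` and the bias-free output kernel `dense_24` of shape `f32[768][50257]` shown in the parameter dump; the variable names are only illustrative.

```elixir
# Dimensions taken from gpt_config_124m.
vocab_size = 50_257
emb_dim = 768

# The token embedding table and the output projection kernel hold the
# same number of weights, just with transposed shapes.
token_embedding_params = vocab_size * emb_dim
output_layer_params = emb_dim * vocab_size

# With weight tying both layers would share a single matrix, so the
# saving equals the size of one of them.
saved_by_tying = output_layer_params

# Each f32 weight occupies 4 bytes.
saved_bytes = saved_by_tying * 4

IO.puts("token embedding weights: #{token_embedding_params}")
IO.puts("output layer weights:    #{output_layer_params}")
IO.puts("weights saved by tying:  #{saved_by_tying}")
IO.puts("bytes saved (f32):       #{saved_bytes}")
```

Under these assumptions, tying would remove 38,597,376 weights, roughly 147 MiB in 32-bit floats.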

4.7 Generating text

An LLM generates text step by step, one token at a time. Starting with an initial input context (“Hello, I am”), the model predicts a subsequent token during each iteration and appends it to the input context for the next round of prediction. The first iteration adds “a,” the second “model,” and the third “ready,” progressively building the sentence.

In each step, the model outputs a matrix with vectors representing potential next tokens. The vector corresponding to the next token is extracted and converted into a probability distribution via the softmax function. Within the vector containing the resulting probability scores, the index of the highest value is located, which translates to the token ID. This token ID is then decoded back into text, producing the next token in the sequence. Finally, this token is appended to the previous inputs, forming a new input sequence for the subsequent iteration. This step-by-step process enables the model to generate text sequentially, building coherent phrases and sentences from the initial input context.
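The softmax-then-argmax step can be illustrated on a toy vector without any tensor library. The sketch below is not the notebook's implementation: the `GreedyStep` module and the four-element logits list are made up for illustration, standing in for one row of the model's `vocab_size`-wide output.

```elixir
defmodule GreedyStep do
  # Numerically stable softmax over a plain list of floats:
  # subtracting the maximum avoids overflow in :math.exp/1.
  def softmax(logits) do
    max = Enum.max(logits)
    exps = Enum.map(logits, &:math.exp(&1 - max))
    sum = Enum.sum(exps)
    Enum.map(exps, &(&1 / sum))
  end

  # Index of the largest element (the first one on ties).
  def argmax(values) do
    values
    |> Enum.with_index()
    |> Enum.max_by(fn {value, _index} -> value end)
    |> elem(1)
  end
end

# Toy logits for a hypothetical 4-token vocabulary.
logits = [0.5, 2.0, -1.0, 1.5]

probs = GreedyStep.softmax(logits)
next_token_id = GreedyStep.argmax(probs)

IO.inspect(probs, label: "probabilities")
IO.inspect(next_token_id, label: "next token id")
```

Because softmax preserves the ordering of its inputs, `GreedyStep.argmax(logits)` selects the same index; the softmax is only needed when the probabilities themselves are of interest.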

The process begins by encoding the input text into token IDs, which are then fed into the GPT model. The outputs of the model are then converted back into text and appended to the original input text.

# Compute the next token id to be concatenated to the input.
predicted_new_token = 
  logit[[0..1, -1]] # Logits at the last position of each sequence in the batch.
  |> Axon.Layers.softmax(axis: -1)
  |> Nx.argmax(axis: -1)
  |> Nx.reshape({2, 1})
#Nx.Tensor<
  s64[2][1]
  EXLA.Backend
  [
    [26145],
    [20618]
  ]
>
token_ids_tensor = Nx.concatenate([batch, predicted_new_token], axis: 1)
token_ids_list = Nx.to_list(token_ids_tensor)
[[6109, 3629, 6100, 345, 26145], [6109, 1110, 6622, 257, 20618]]
for token_ids <- token_ids_list do
  {:ok, text} = Tiktoken.decode("gpt-3.5-turbo", token_ids)
  text
end
["Web oftenpite,\n%',\n", "WebCom seconds    535"]
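The single prediction step above can be repeated to generate several tokens in a row. The following is a minimal sketch under a few assumptions: Nx is already installed by the notebook's setup cell, `GreedyDecoding` is a hypothetical module name, and `predict_fn` here is a one-argument function mapping token IDs of shape `{batch, n_tokens}` to logits of shape `{batch, n_tokens, vocab_size}`. The softmax is skipped because it does not change which index is largest.

```elixir
defmodule GreedyDecoding do
  # Appends `max_new_tokens` greedily chosen token IDs to `token_ids`.
  # `context_size` is the maximum number of tokens the model accepts.
  def generate(predict_fn, token_ids, max_new_tokens, context_size) do
    Enum.reduce(1..max_new_tokens//1, token_ids, fn _step, ids ->
      {_batch_size, n_tokens} = Nx.shape(ids)

      # Keep only the most recent `context_size` tokens as model input.
      start = max(n_tokens - context_size, 0)
      context = Nx.slice_along_axis(ids, start, n_tokens - start, axis: 1)

      logits = predict_fn.(context)
      {_batch, n_ctx, _vocab} = Nx.shape(logits)

      next_ids =
        logits
        # Logits at the last position only: {batch, 1, vocab_size}
        |> Nx.slice_along_axis(n_ctx - 1, 1, axis: 1)
        # Highest-scoring token ID per sequence: {batch, 1}
        |> Nx.argmax(axis: -1)

      # Append the new token IDs to the running sequences.
      Nx.concatenate([ids, next_ids], axis: 1)
    end)
  end
end

# Possible usage with the notebook's variables (not executed here):
# generated =
#   GreedyDecoding.generate(
#     &predict_fn.(params, &1),
#     batch,
#     6,
#     gpt_config_124m[:context_length]
#   )
```

With the untrained weights used in this chapter, the extra tokens would be just as incoherent as the single token decoded above.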

Summary

  • Layer normalization stabilizes training by ensuring that each layer’s outputs have a consistent mean and variance.
  • Shortcut connections are connections that skip one or more layers by feeding the output of one layer directly to a deeper layer, which helps mitigate the vanishing gradient problem when training deep neural networks, such as LLMs.
  • Transformer blocks are a core structural component of GPT models, combining masked multi-head attention modules with fully connected feed forward networks that use the GELU activation function.
  • GPT models are LLMs with many repeated transformer blocks that have millions to billions of parameters.
  • GPT models come in various sizes, for example, 124, 345, 762, and 1,542 million parameters, which we can implement with the same GPTModel implementation.
  • The text-generation capability of a GPT-like LLM involves decoding output tensors into human-readable text by sequentially predicting one token at a time based on a given input context.
  • Without training, a GPT model generates incoherent text, which underscores the importance of model training for coherent text generation.