Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Chapter 5: Pretraining on unlabeled data

ch5.livemd

Chapter 5: Pretraining on unlabeled data

# Notebook dependencies, installed when this cell is evaluated.
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:tiktoken, "~> 0.3.2"},
  {:table_rex, "~> 3.1.1"},
  {:bumblebee, "~> 0.6.0"},
  {:kino_vega_lite, "~> 0.1.11"}
])

# Make EXLA the default backend for tensors created in this notebook.
Nx.global_default_backend(EXLA.Backend)

Introduction

This section covers:

  • Pretraining the LLM.
  • Implementing the training code.
  • Evaluating the model's performance during training.
  • Saving and loading model weights.

In the context of LLMs and other deep learning models, weights refer to the trainable parameters that the learning process adjusts. These weights are also known as weight parameters or simply parameters.

# Model name handed to the tokenizer helpers.
# NOTE(review): the encoding behind "gpt-3.5-turbo" presumably has a larger vocabulary than
# the vocab_size of 50257 configured below, so some token ids could fall outside the
# embedding table — confirm against the Tiktoken package.
tokenizer = "gpt-3.5-turbo"
# Hyper-parameters for the model built by MyGPT.
gpt_config_124m = [
  # Rows of the token embedding table and width of the output logits.
  vocab_size: 50257,
  # Rows of the positional embedding table.
  context_length: 1024,
  # Width of every token/position embedding vector.
  emb_dim: 768,
  # Attention heads per transformer block.
  n_heads: 12,
  # Transformer blocks stacked in the model.
  n_layers: 12,
  # Dropout rate used after the embeddings and inside each block.
  drop_rate: 0.1,
  # Presumably toggles a bias on the Q/K/V projections; the attention layer defined in this
  # notebook does not read it — TODO confirm.
  qkv_bias: false
]
[
  vocab_size: 50257,
  context_length: 1024,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0.1,
  qkv_bias: false
]
defmodule Transformer.Layers do
  @moduledoc """
  Custom Axon layers used to assemble a GPT-style transformer block:
  causal multi-head attention, layer normalization, positional embedding,
  a feed-forward network, and shortcut (residual) connections.
  """

  import Nx.Defn

  @doc """
  Adds a causal multi-head self-attention layer on top of `input`.

  ## Options

    * `:d_in` (required) - size of the last axis of the input.
    * `:d_out` (required) - size of the projected output; must be divisible by `:num_heads`.
    * `:num_heads` (required) - number of attention heads.
    * `:name` - optional layer name.

  Raises `KeyError` when a required option is missing.
  """
  def attention(%Axon{} = input, opts \\ []) do
    d_in = Keyword.fetch!(opts, :d_in)
    d_out = Keyword.fetch!(opts, :d_out)
    num_heads = Keyword.fetch!(opts, :num_heads)
    head_dim = div(d_out, num_heads)

    w_query = Axon.param("w_query", fn _ -> {d_in, d_out} end)
    w_key = Axon.param("w_key", fn _ -> {d_in, d_out} end)
    w_value = Axon.param("w_value", fn _ -> {d_in, d_out} end)
    out_proj = Axon.param("out_proj", fn _ -> {d_out, d_out} end)

    Axon.layer(
      &attention_impl/6,
      [input, w_query, w_key, w_value, out_proj],
      name: opts[:name],
      op_name: :causal_attention,
      head_dim: head_dim,
      num_heads: num_heads
    )
  end

  # Numerical body of the attention layer. `input` is destructured as
  # {batch, num_tokens, d_in}; `opts` carries :num_heads and :head_dim.
  defnp attention_impl(input, w_query, w_key, w_value, out_proj, opts \\ []) do
    {b, num_tokens, _d_in} = Nx.shape(input)

    keys = Nx.dot(input, w_key)
    queries = Nx.dot(input, w_query)
    values = Nx.dot(input, w_value)

    # Split the projection into heads: {b, tokens, d_out} -> {b, heads, tokens, head_dim}.
    keys_reshaped =
      keys
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]})
      |> Nx.transpose(axes: [0, 2, 1, 3])

    queries_reshaped =
      queries
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]})
      |> Nx.transpose(axes: [0, 2, 1, 3])

    values_reshaped =
      values
      |> Nx.reshape({b, num_tokens, opts[:num_heads], opts[:head_dim]})
      |> Nx.transpose(axes: [0, 2, 1, 3])

    # Per-head dot product of queries and keys, batched over {b, heads}:
    # result is {b, heads, tokens, tokens}.
    keys_transposed = Nx.transpose(keys_reshaped, axes: [0, 1, 3, 2])
    attn_score = Nx.dot(queries_reshaped, [3], [0, 1], keys_transposed, [2], [0, 1])

    # Causal mask: +infinity strictly above the diagonal, zero elsewhere.
    # Negating and adding it sends future positions to -infinity before softmax.
    simple_mask =
      Nx.Constants.infinity()
      |> Nx.broadcast(Nx.shape(attn_score))
      |> Nx.triu(k: 1)

    masked = Nx.multiply(simple_mask, -1) |> Nx.add(attn_score)

    # Scale by sqrt(head_dim): the dot products above are taken per head, so
    # the key dimension is the per-head size, not the full projection width.
    attn_weights =
      masked
      |> Nx.divide(Nx.pow(opts[:head_dim], 0.5))
      |> Axon.Activations.softmax(axis: -1)

    # Weighted sum of values, then move heads back next to head_dim:
    # {b, heads, tokens, head_dim} -> {b, tokens, heads, head_dim}.
    context_vec =
      attn_weights
      |> Nx.dot([3], [0, 1], values_reshaped, [2], [0, 1])
      |> Nx.transpose(axes: [0, 2, 1, 3])

    # Merge the heads and apply the output projection.
    context_vec
    |> Nx.reshape({b, num_tokens, opts[:num_heads] * opts[:head_dim]})
    |> Nx.dot(out_proj)
  end

  @doc """
  Applies `layer_impl` to `x`, optionally adding `x` back to the result.

  `layer_impl` is called as `layer_impl.(x)` when it has arity 1 and as
  `layer_impl.(x, opts)` otherwise. With `use_shortcut: true` the result is
  `Axon.add(x, layer_output)`; by default the layer output is returned as is.
  """
  def shortcut(x, layer_impl, opts \\ []) when is_function(layer_impl) do
    {:arity, arity} = Function.info(layer_impl, :arity)
    layer_output = execute_layer(x, layer_impl, opts, arity)
    shortcut_impl(x, layer_output, opts)
  end

  defp execute_layer(x, layer_impl, _opts, 1), do: layer_impl.(x)
  defp execute_layer(x, layer_impl, opts, _arity), do: layer_impl.(x, opts)

  defp shortcut_impl(x, layer_output, opts) do
    if Keyword.get(opts, :use_shortcut, false) do
      Axon.add(x, layer_output)
    else
      layer_output
    end
  end

  @doc """
  Adds a layer-normalization layer with trainable `scale` and `shift`.

  ## Options

    * `:emb_dim` - size of the normalized (last) axis.
    * `:eps` - value added to the variance before the square root. Defaults to `1.0e-5`.
    * `:name` - optional layer name.

  Other keys in `opts` are ignored, so a full model config can be passed.
  """
  def normalization(%Axon{} = input, opts \\ []) do
    eps = Keyword.get(opts, :eps, 1.00e-5)
    scale = Axon.param("scale", {opts[:emb_dim]}, initializer: &ones(&1, type: &2))
    shift = Axon.param("shift", {opts[:emb_dim]}, initializer: &zeros(&1, type: &2))

    Axon.layer(
      &normalization_impl/4,
      [input, scale, shift],
      name: opts[:name],
      op_name: :normalization,
      eps: eps
    )
  end

  # Initializer producing a tensor of ones with the given shape and type.
  defp ones(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.broadcast(Nx.tensor(1, type: opts[:type]), shape)
  end

  # Initializer producing a tensor of zeros with the given shape and type.
  defp zeros(shape, opts) do
    opts = Keyword.validate!(opts, [:type])
    Nx.broadcast(Nx.tensor(0, type: opts[:type]), shape)
  end

  # Normalizes the last axis to zero mean and unit variance, then applies
  # the trainable scale and shift.
  defnp normalization_impl(input, scale, shift, opts \\ []) do
    mean = Nx.mean(input, axes: [-1], keep_axes: true)
    variance = Nx.variance(input, axes: [-1], keep_axes: true)
    denominator = variance |> Nx.add(opts[:eps]) |> Nx.sqrt()

    input
    |> Nx.subtract(mean)
    |> Nx.divide(denominator)
    |> Nx.multiply(scale)
    |> Nx.add(shift)
  end

  @doc """
  Adds a positional-embedding layer.

  The layer owns a `{vocab_size, embedding_size}` kernel and returns its first
  `sequence_size` rows, where `sequence_size` is the second axis of `x`.

  ## Options

    * `:name` - optional layer name.
    * `:kernel_initializer` - initializer for the kernel. Defaults to `:uniform`.
  """
  def pos_embedding(%Axon{} = x, vocab_size, embedding_size, opts \\ []) do
    opts = Keyword.validate!(opts, [:name, kernel_initializer: :uniform])

    kernel_shape = &Axon.Shape.embedding_kernel(&1, vocab_size, embedding_size)

    kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

    Axon.layer(&pos_embedding_impl/3, [x, kernel], name: opts[:name], op_name: :pos_embedding)
  end

  # Looks up kernel rows 0..sequence_size-1; the values in `x` are not used,
  # only its shape. The result has a leading axis of size 1.
  defnp pos_embedding_impl(x, kernel, _opts \\ []) do
    {_batch_size, sequence_size} = Nx.shape(x)
    input = Nx.iota({1, sequence_size})
    Nx.take(kernel, input, axis: 0)
  end

  @doc """
  Feed-forward network: dense to `4 * emb_dim`, GELU, dense back to `emb_dim`.
  """
  def feedforward(input, emb_dim) do
    input
    |> Axon.dense(4 * emb_dim)
    |> Axon.activation(:gelu)
    |> Axon.dense(emb_dim)
  end

  @doc """
  Normalization, feed-forward network and dropout, configured from `opts`
  (`:emb_dim`, `:drop_rate`).
  """
  def feedforward_block(input, opts) do
    input
    |> normalization(opts)
    |> feedforward(opts[:emb_dim])
    |> Axon.dropout(rate: opts[:drop_rate])
  end

  @doc """
  Normalization, causal attention and dropout, configured from `opts`
  (`:emb_dim`, `:n_heads`, `:drop_rate`).
  """
  def attention_block(input, opts) do
    input
    |> normalization(opts)
    |> attention(d_in: opts[:emb_dim], d_out: opts[:emb_dim], num_heads: opts[:n_heads])
    |> Axon.dropout(rate: opts[:drop_rate])
  end

  @doc """
  One transformer block: the attention block followed by the feed-forward
  block, each wrapped in a shortcut connection.
  """
  def block(input, opts \\ []) do
    input
    |> shortcut(&attention_block(&1, opts), use_shortcut: true)
    |> shortcut(&feedforward_block(&1, opts), use_shortcut: true)
  end
end
{:module, Transformer.Layers, <<70, 79, 82, 49, 0, 0, 51, ...>>, {:block, 2}}
defmodule MyGPT do
  @moduledoc """
  GPT-style model assembled from `Transformer.Layers`, plus helpers for
  tokenizing text and greedy token generation.
  """

  # Captured once, when the module is compiled; rebinding the notebook
  # variable afterwards does not change this default.
  @gpt_config_124m gpt_config_124m

  @doc """
  Builds the Axon model.

  `opts` is the model config (`:vocab_size`, `:context_length`, `:emb_dim`,
  `:n_heads`, `:n_layers`, `:drop_rate`). The number of transformer blocks is
  taken from `:n_layers` and defaults to 12 when the key is absent.
  """
  # NOTE(review): the default input shape is rank 3, while the positional
  # embedding destructures a rank-2 {batch, sequence} input - confirm whether
  # the default should be a rank-2 shape.
  def model(input_shape \\ {2, 4, 768}, opts \\ @gpt_config_124m) do
    n_layers = Keyword.get(opts, :n_layers, 12)

    Axon.input("sequence", shape: input_shape)
    |> embedding_block(opts)
    |> Axon.dropout(rate: opts[:drop_rate])
    |> transformer_blocks(n_layers, opts)
    |> Transformer.Layers.normalization(opts)
    |> Axon.dense(opts[:vocab_size], use_bias: false)
  end

  @doc """
  Sums the token embedding and the positional embedding of `input`.
  """
  def embedding_block(input, opts) do
    token_emb = Axon.embedding(input, opts[:vocab_size], opts[:emb_dim])
    pos_emb = Transformer.Layers.pos_embedding(input, opts[:context_length], opts[:emb_dim])

    Axon.add(token_emb, pos_emb)
  end

  @doc """
  Stacks `n_blocks` transformer blocks on top of `input`.

  With `n_blocks` equal to 0 the input is returned unchanged.
  """
  def transformer_blocks(input, n_blocks, transformer_opts) do
    # `//1` keeps the range ascending so that 0 blocks means no iterations.
    for _n_block <- 1..n_blocks//1, reduce: input do
      model_acc ->
        Transformer.Layers.block(model_acc, transformer_opts)
    end
  end

  @doc """
  Encodes text into token ids.

  Given a single string, returns `{:ok, tensor}` where the tensor has shape
  `{1, n_tokens}` and type `:s64`.

  Given a list of strings, returns a bare tensor: the per-text tensors are
  stacked and every axis of size 1 is squeezed away. All texts must encode to
  the same number of tokens.
  """
  def text_to_token_ids(tokenizer, texts) when is_list(texts) do
    token_ids_list =
      for text <- texts do
        {:ok, token_ids} = text_to_token_ids(tokenizer, text)
        token_ids
      end

    Nx.stack(token_ids_list, axis: 1) |> Nx.squeeze()
  end

  def text_to_token_ids(tokenizer, text) do
    {:ok, tokens} = Tiktoken.encode(tokenizer, text)
    {:ok, Nx.tensor(tokens, type: :s64) |> Nx.new_axis(0)}
  end

  @doc """
  Decodes a tensor of token ids back into text; the tensor is flattened first.
  Returns whatever `Tiktoken.decode/2` returns.
  """
  def token_ids_to_text(tokenizer, token_ids) do
    tokens_ids = Nx.to_flat_list(token_ids)
    Tiktoken.decode(tokenizer, tokens_ids)
  end

  @doc """
  Greedily appends `max_new_token` tokens to `input` using an already built
  `predict_fn` and its parameters.
  """
  def generate_tokens(predict_fn, model_params, input, max_new_token)
      when is_function(predict_fn) do
    generate_tokens_impl(predict_fn, model_params, input, max_new_token)
  end

  @doc """
  Builds `model` with EXLA and greedily appends `max_new_token` tokens to `input`.

  When `model_params` is an empty map the parameters are initialized from a
  template with the shape of `input`; otherwise the given parameters are used.
  """
  def generate_tokens_with_model(model, model_params, input, max_new_token)
      when model_params == %{} do
    {init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
    template = Nx.template(Nx.shape(input), :s64)
    init_model_params = init_fn.(template, model_params)
    generate_tokens_impl(predict_fn, init_model_params, input, max_new_token)
  end

  def generate_tokens_with_model(model, model_params, input, max_new_token) do
    {_init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
    generate_tokens_impl(predict_fn, model_params, input, max_new_token)
  end

  # Greedy decoding loop: each step predicts from the sequence so far and
  # appends the most likely next token to every row.
  defp generate_tokens_impl(predict_fn, model_params, input, max_new_token) do
    # `//1` keeps the range ascending so that 0 new tokens means no iterations.
    for _new_token_index <- 1..max_new_token//1, reduce: input do
      input_acc ->
        logit = predict_fn.(model_params, input_acc)

        # Keep only the logits of the last position, pick the best token per
        # row, and give it a trailing axis so it is {batch, 1} and lines up
        # with the {batch, sequence} accumulator for any batch size.
        predicted_new_token =
          logit[[.., -1]]
          |> Axon.Layers.softmax(axis: -1)
          |> Nx.argmax(axis: -1)
          |> Nx.new_axis(-1)

        Nx.concatenate([input_acc, predicted_new_token], axis: 1)
    end
  end
end
{:module, MyGPT, <<70, 79, 82, 49, 0, 0, 25, ...>>, {:generate_tokens_impl, 4}}

5.1 Evaluating generative text models

# Config for this chapter: compared with the earlier binding, the context length is
# shortened from 1024 to 256 and dropout is disabled (drop_rate: 0).
# Rebinding this variable does not change the default options of MyGPT.model, which were
# captured when the module was compiled; pass this list explicitly to use it.
gpt_config_124m = [
  vocab_size: 50257,
  context_length: 256,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0,
  qkv_bias: false
]
[
  vocab_size: 50257,
  context_length: 256,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0,
  qkv_bias: false
]
# Pass the current config explicitly: MyGPT.model/0 would fall back to the
# options captured when the module was compiled, not the list bound above.
# The input shape matches the {1, 4} template used for initialization.
model = MyGPT.model({1, 4}, gpt_config_124m)
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

# Initialize fresh parameters from a {1, 4} template of :s64 token ids.
template = Nx.template({1, 4}, :s64)
params = init_fn.(template, %{})
%{
  "pos_embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[1024][768]
      EXLA.Backend
      [
        [-0.004166834056377411, 5.558252087212168e-5, 0.006918067578226328, 0.008123989216983318, 0.0018085908377543092, -0.009919502772390842, -0.0039427136071026325, -0.0012765335850417614, 0.008090266957879066, 0.007765047252178192, 0.008992955088615417, 0.006326744332909584, 0.006782951299101114, 9.583806968294084e-4, -0.0012600517366081476, 0.007900390774011612, -0.009029812179505825, 0.00775456428527832, 0.0036909293849021196, 0.009727797470986843, 0.0020847702398896217, -0.0038497995119541883, 0.005077249836176634, -0.006431164685636759, 0.006501109339296818, 0.0025014327839016914, 0.007404351141303778, 0.005517244338989258, 0.0047699641436338425, 0.006847111973911524, -0.006225829012691975, 0.0030291485600173473, -0.008682658895850182, 0.009832290932536125, 0.008064300753176212, -0.008270682767033577, 0.004674255847930908, 0.0036342358216643333, 0.008963997475802898, 0.003732154378667474, 0.008757982403039932, -2.5937557802535594e-4, -0.009371106512844563, 0.004248545039445162, -0.00806444138288498, -0.0038590501062572002, -0.008312546648085117, -0.00812917947769165, ...],
        ...
      ]
    >
  },
  "normalization_6" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_4" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.007316610775887966, 0.015355171635746956, -0.024541445076465607, -0.01522262766957283, 0.03305758163332939, -0.002012224867939949, 0.01904001645743847, -0.0016197205986827612, 0.03809060528874397, -0.021260753273963928, -0.022825341671705246, -0.008677077479660511, -0.006486063823103905, -0.00958315096795559, -0.019128048792481422, -0.014574884437024593, -0.0013467168901115656, -0.004243602976202965, 0.014682010747492313, -0.01667947694659233, -0.036754999309778214, 0.026166953146457672, 0.021602800115942955, -0.01387953944504261, 0.02438930794596672, -0.011118466965854168, -0.016139831393957138, -0.009837701916694641, 0.011002189479768276, -0.033176518976688385, -0.021040309220552444, -0.026683131232857704, -0.03488384932279587, -0.020158663392066956, -0.03940330818295479, 0.0015812128549441695, -0.007541559636592865, 0.011250671930611134, 0.007062220014631748, 0.022234296426177025, 0.02331043966114521, -0.014372176490724087, 8.054308709688485e-4, 0.01346514280885458, 0.012399421073496342, ...],
        ...
      ]
    >
  },
  "normalization_5" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_20" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.03622298687696457, 0.02234366536140442, -0.008425580337643623, -0.004621932748705149, -0.01705941930413246, -0.002757453126832843, 0.010120686143636703, -0.022151663899421692, -0.029551802203059196, 0.0015321027021855116, -0.019663434475660324, -0.038496822118759155, 0.01596236228942871, -0.02615923434495926, 0.03454967215657234, -0.0033944619353860617, 0.004495373461395502, 0.0062217493541538715, 0.007031496614217758, -6.067379144951701e-4, -0.0036599356681108475, 0.014366408810019493, 0.03911951556801796, -0.03523853048682213, -0.004185822326689959, -0.022292423993349075, 0.019917815923690796, -0.03139081597328186, 0.03846295177936554, 7.955447654239833e-4, -0.019002949818968773, -2.882617700379342e-4, 0.01781204529106617, -0.01457664743065834, 0.037084680050611496, 0.02064569480717182, -0.01873495988547802, 0.026645952835679054, -0.014967040158808231, -0.01808040216565132, -0.0158998966217041, -4.6918989391997457e-4, -0.0072336201556026936, ...],
        ...
      ]
    >
  },
  "normalization_1" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_9" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_22" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.013211779296398163, -0.038686223328113556, -0.03646089509129524, -0.03547561168670654, -0.03136007487773895, 0.03803395479917526, 0.034124672412872314, 0.0019030723487958312, 0.039203938096761703, 0.02367444522678852, 0.007645085919648409, 0.03256858512759209, -0.03359496593475342, 7.508546113967896e-4, -0.034337423741817474, -0.023720962926745415, -0.03365522623062134, 0.019042164087295532, -0.020887484773993492, -0.0030847976449877024, 0.03787440061569214, 0.013972415588796139, -0.03296968340873718, -0.004996474366635084, 0.016596797853708267, -0.004487014375627041, -0.006767191458493471, 0.004086772911250591, 0.032751210033893585, -0.009971262887120247, -0.02156519703567028, 0.0024161040782928467, -0.019541341811418533, -0.01818450354039669, 0.0018818487878888845, 0.03243156522512436, 0.004533287603408098, -0.022080792114138603, 0.026794254779815674, -0.025121869519352913, ...],
        ...
      ]
    >
  },
  "causal_attention_0" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.03389421105384827, 0.04651099443435669, -0.004578709602355957, 0.05322122573852539, -0.006341084837913513, -0.04544401168823242, 0.0211334228515625, 0.011194080114364624, 0.05034850537776947, -0.04450160264968872, -0.06052732467651367, 0.0410953015089035, 0.004480987787246704, -0.014307916164398193, -0.06064000725746155, -0.06154410541057587, -0.008748799562454224, -0.020162060856819153, 0.03290221095085144, 0.02467593550682068, -0.006243795156478882, -0.0019895434379577637, -0.029925718903541565, 0.0034829825162887573, -0.014508232474327087, 0.04588325321674347, -0.02697893977165222, 0.06067804992198944, -0.01538117229938507, -0.055520057678222656, 0.04659220576286316, 0.0068100690841674805, -0.002129107713699341, 0.012684687972068787, 0.036468952894210815, -0.005811750888824463, 0.023060843348503113, -0.021330252289772034, 0.02877815067768097, 0.034971222281455994, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.026694804430007935, -0.0156431645154953, 0.005123689770698547, -0.051212653517723083, 0.006300434470176697, -0.010585159063339233, 0.06003814935684204, -0.03139650821685791, 0.026933014392852783, 0.014955848455429077, 0.014135494828224182, -0.0593930184841156, 0.03183335065841675, -0.020880088210105896, 0.006442844867706299, -0.03610377013683319, -0.04488801956176758, 0.00456276535987854, 0.05108395218849182, -0.029768452048301697, 0.005766123533248901, -0.010115623474121094, -0.022595539689064026, -0.054568707942962646, 0.008954733610153198, 0.02461846172809601, 0.028902962803840637, 0.0251777321100235, -0.0038338154554367065, -0.0594489723443985, 0.022434815764427185, 0.04677131772041321, -0.05091351270675659, 0.04277119040489197, -0.06151154637336731, -0.021195203065872192, 5.386769771575928e-5, -0.03812660276889801, -0.0039939284324646, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.02401629090309143, -0.02956591546535492, 0.061295285820961, -0.021912038326263428, -0.053771957755088806, -0.047283291816711426, -0.021504893898963928, -0.03588941693305969, -0.05604234337806702, 0.04097875952720642, 0.051357969641685486, 0.04518897831439972, -0.04817293584346771, 0.05209130048751831, 0.057313308119773865, -0.03741718828678131, 0.03765517473220825, 0.06140288710594177, 0.05412471294403076, -0.04883106052875519, 0.016423314809799194, 0.056240975856781006, -0.027981281280517578, -0.017649903893470764, 0.005771517753601074, -0.05313637852668762, -0.030975595116615295, -5.39928674697876e-4, 0.012884467840194702, 0.058664143085479736, -0.029691308736801147, 0.019715651869773865, 0.06167466938495636, -0.03277812898159027, 0.021899685263633728, -0.004622027277946472, -0.0038959532976150513, 0.04854552447795868, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.061722636222839355, 0.014092415571212769, 0.014214977622032166, 0.025194033980369568, 0.005165129899978638, 0.025518834590911865, -0.05198174715042114, -0.014645203948020935, 0.030731558799743652, 0.012704804539680481, 0.025482699275016785, 0.035750120878219604, -0.006144717335700989, 0.0290277898311615, -0.027165263891220093, -0.02803950011730194, 0.054379433393478394, -0.06212638318538666, 0.04287844896316528, -0.04524946212768555, -0.04514627158641815, 0.04273821413516998, 0.009859591722488403, 0.00796559453010559, 0.06041954457759857, -0.042293041944503784, 0.05158552527427673, 0.049761489033699036, -0.06142106652259827, 0.02010442316532135, 0.014893487095832825, -0.029858902096748352, 0.048858001828193665, 0.04792974889278412, -0.05165189504623413, 0.03463757038116455, 0.057726532220840454, ...],
        ...
      ]
    >
  },
  "normalization_14" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_7" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_11" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_2" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_12" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.008988947607576847, -0.0061833830550313, 0.004445867612957954, 0.035713717341423035, -0.02435637079179287, -0.0050144558772444725, -0.02177242748439312, -0.038602836430072784, 0.037907566875219345, -0.004283269867300987, 0.031885113567113876, 0.020809074863791466, -0.01685102842748165, 0.013179557397961617, -0.005057223606854677, -0.038501184433698654, -0.00664452463388443, -0.015255838632583618, 0.02304474078118801, -0.026754803955554962, 0.014159554615616798, -0.030819524079561234, -0.00975214783102274, -0.027543487027287483, 0.021747123450040817, 0.025858314707875252, -0.03038638085126877, 0.012028036639094353, -0.028919778764247894, 6.43228879198432e-4, 0.016414616256952286, -0.02885948121547699, -0.03627284988760948, 0.013899820856750011, ...],
        ...
      ]
    >
  },
  "dense_23" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.01389195118099451, 0.028496315702795982, 0.019149629399180412, -0.015082176774740219, 0.019566090777516365, -0.013020898215472698, 0.019521258771419525, -0.004951020702719688, -0.017445288598537445, -0.03372013196349144, 0.03417432680726051, -0.009944036602973938, -0.00781666487455368, 0.03892479091882706, 0.0027626364026218653, -0.020471522584557533, 0.03167443349957466, -0.012267885729670525, -0.018630603328347206, -0.012284058146178722, -0.01911494880914688, -0.0036643368657678366, 0.026073435321450233, 0.013562231324613094, -0.033684246242046356, 0.03922832012176514, 0.01838323473930359, 0.030265383422374725, -0.02810743823647499, 0.03593805059790611, 0.012999542988836765, 0.018777500838041306, 0.014124138280749321, ...],
        ...
      ]
    >
  },
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[50257][768]
      EXLA.Backend
      [
        [0.009341144002974033, 0.007365426979959011, 0.00825655460357666, -2.8293608920648694e-4, 0.004749185871332884, -0.0023499035742133856, -0.002149598440155387, 0.0024735117331147194, -0.004771864507347345, 0.0032234524842351675, -0.007733702659606934, -0.003493587952107191, -0.0013423752970993519, -0.001936259213835001, -0.009511223062872887, 0.0011167096672579646, 0.003883607219904661, 0.005537152290344238, -0.008645045571029186, 0.00897135492414236, 0.008197367191314697, 0.002842607442289591, -0.006284489296376705, 0.0016422390472143888, 0.002079141093418002, 0.005127396434545517, 0.0013609599554911256, -0.004524736199527979, -0.0052507612854242325, 0.005974554922431707, 0.009530248120427132, 0.009378919377923012, -0.004144019912928343, ...],
        ...
      ]
    >
  },
  "normalization_19" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_6" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.038281723856925964, 0.03060060739517212, 0.055243924260139465, -0.0010612159967422485, 0.010932952165603638, 0.0380278080701828, 0.039625704288482666, 0.012449786067008972, 0.04163214564323425, 0.011243581771850586, -0.017265647649765015, -0.029650554060935974, -0.026119768619537354, 0.011231780052185059, 0.06180793046951294, 0.05046388506889343, 0.0018288791179656982, -0.011008426547050476, -0.05981588363647461, 0.032344549894332886, -0.05203762650489807, -0.014526709914207458, 0.03726394474506378, -0.016059458255767822, -0.025209084153175354, -0.004407376050949097, -0.042520150542259216, -0.014836296439170837, -0.05615831911563873, 0.012479528784751892, -0.04188920557498932, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.014105960726737976, 0.025272339582443237, 0.039366647601127625, 0.04212130606174469, -0.05741038918495178, -0.010394155979156494, 0.05349820852279663, -0.05640338361263275, 0.05356724560260773, -0.030641287565231323, 0.023347437381744385, -2.7091801166534424e-4, 0.04832763969898224, -0.015682756900787354, 0.003194168210029602, -0.022340387105941772, 0.040028661489486694, -0.017657384276390076, 0.012668401002883911, 0.029825359582901, 0.016716867685317993, -0.028231605887413025, -0.061023563146591187, 0.040525972843170166, 0.049099400639534, 0.04155528545379639, -0.052004411816596985, 0.047050416469573975, -0.049679338932037354, 0.040252625942230225, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.01860181987285614, -0.027480706572532654, 0.031518712639808655, -0.04549586772918701, 0.03295879065990448, 0.03180180490016937, 0.007855981588363647, 0.052125051617622375, -0.010133251547813416, 0.05646832287311554, 0.02763456106185913, 0.014156967401504517, 0.060103341937065125, -0.03327015042304993, -0.021565958857536316, -0.036019325256347656, -0.007117718458175659, -0.05001211166381836, -0.008163869380950928, -0.06163521111011505, 0.04946871101856232, 0.03733907639980316, 0.009067490696907043, 0.026343733072280884, 0.03723730146884918, -0.018264561891555786, 0.02422654628753662, 0.008374512195587158, 0.04460504651069641, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.046486884355545044, -0.05079798400402069, 0.03470088541507721, 0.04372638463973999, 0.061747580766677856, -0.024993568658828735, -0.02915513515472412, 0.058661088347435, -0.00888688862323761, -0.01159195601940155, -0.007967159152030945, -0.027336061000823975, 0.010703086853027344, 0.013705432415008545, 0.014023050665855408, -0.04765051603317261, -0.020531728863716125, 0.05245104432106018, 0.01385180652141571, -0.0020856261253356934, 0.01170627772808075, -0.018050864338874817, -0.0194767564535141, -0.013664931058883667, -0.06227518618106842, -0.016330137848854065, -0.060487210750579834, 0.01637732982635498, ...],
        ...
      ]
    >
  },
  "causal_attention_7" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.055874019861221313, -0.009761840105056763, 0.0310947448015213, 0.00846041738986969, 0.05848193168640137, -0.0280439555644989, 0.002997070550918579, 0.04026903212070465, 0.035716861486434937, -0.06076526641845703, -0.020011886954307556, 0.0373322069644928, 0.045154377818107605, 0.013299152255058289, 0.057207703590393066, -0.020929694175720215, 0.02732427418231964, 0.04890526831150055, 0.052000924944877625, -0.043057799339294434, -0.010325685143470764, -0.008965983986854553, -0.024845615029335022, 0.019196465611457825, -0.023155972361564636, 0.05365733802318573, 0.01386803388595581, -0.04518139362335205, 0.02958647906780243, 0.02864210307598114, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.040406033396720886, 0.016526684165000916, -0.0033710896968841553, -0.043799445033073425, 0.02486005425453186, 0.05860051512718201, 0.018847361207008362, 0.01113341748714447, -0.0016693472862243652, 0.05720651149749756, 0.00262393057346344, -0.06137925386428833, 0.004038423299789429, 0.04122963547706604, -0.01048918068408966, 0.04223752021789551, 0.06142869591712952, -0.034814223647117615, -0.014857932925224304, -0.041082412004470825, -0.05591981112957001, -0.05653798580169678, 0.06095759570598602, -0.05417846143245697, 2.2234022617340088e-4, -0.02513866126537323, -0.024502307176589966, 0.04697492718696594, 0.053062379360198975, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.06247323751449585, 0.03643867373466492, -0.05104561150074005, -0.04741416871547699, -0.053138941526412964, -0.03560763597488403, -0.05929890275001526, -0.062352314591407776, 0.025926724076271057, -0.02425217628479004, -0.0013734400272369385, 0.03181181848049164, 0.027512595057487488, 0.02301667630672455, -7.220059633255005e-4, -0.06135407090187073, 0.048133403062820435, -0.018433496356010437, 0.027993008494377136, -0.04901476204395294, 0.05222088098526001, 0.0039632320404052734, 0.044420480728149414, -0.023455306887626648, -0.010476246476173401, 0.004272446036338806, -0.05783988535404205, 0.012503117322921753, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.01894761621952057, -0.03178654611110687, -0.011216983199119568, 0.009847953915596008, 0.04243092238903046, -0.05269542336463928, -0.03829042613506317, -0.016415312886238098, 0.03241594135761261, 0.03255505859851837, -0.05531121790409088, 0.043683573603630066, 0.0023234039545059204, -0.04956963658332825, 0.056855812668800354, 0.017909914255142212, -0.039806872606277466, -0.03650254011154175, -0.003383934497833252, 0.06122206151485443, -0.045029670000076294, -0.04313218593597412, -0.03592805564403534, -0.05727629363536835, 0.014998644590377808, 0.006697550415992737, 0.022688373923301697, ...],
        ...
      ]
    >
  },
  "causal_attention_1" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0015059411525726318, 0.061351388692855835, 0.02020566165447235, 0.008800521492958069, -0.03287772834300995, 0.0281343013048172, -0.033173367381095886, 0.03154589235782623, -0.020690634846687317, 0.06248064339160919, 0.058673471212387085, 0.058091938495635986, 0.04341328144073486, 0.023070335388183594, -0.05406598746776581, -0.012940526008605957, -0.033596232533454895, 0.029249414801597595, -0.04499354958534241, -0.03973957896232605, 0.03586547076702118, 0.03717055916786194, -0.023941397666931152, -0.0174674391746521, 0.02705080807209015, 0.06142842769622803, 0.03275701403617859, 0.03966856002807617, 0.003020450472831726, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.044230520725250244, 0.045184120535850525, -0.029058068990707397, 0.001558348536491394, 0.05115991830825806, 0.02305695414543152, 0.06219954788684845, 0.03394980728626251, -0.03425164520740509, -0.02188214659690857, -0.0031602829694747925, -0.009219080209732056, 0.018819406628608704, -0.05301015079021454, -0.03463184833526611, -0.054110825061798096, 0.019130408763885498, 0.03492939472198486, 0.03678998351097107, 0.017699211835861206, 0.05030147731304169, -0.04183673858642578, 0.03185127675533295, -0.05556485056877136, 0.006304115056991577, 0.03466491401195526, 0.05986148118972778, -0.003671795129776001, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.054660335183143616, 0.029184728860855103, -0.045949786901474, -0.0016283243894577026, -0.022533699870109558, -0.03693334758281708, 0.014323234558105469, 0.009834527969360352, 0.02217020094394684, 0.05951140820980072, -0.03895184397697449, 0.05900442600250244, -0.008548364043235779, 0.020830363035202026, 0.010709181427955627, 0.02067597210407257, 0.05800652503967285, 0.03209450840950012, -0.01825849711894989, 0.009139835834503174, 0.031897738575935364, 0.04402792453765869, -0.04096783697605133, -0.03625558316707611, -0.004964768886566162, 0.04366588592529297, 0.05359916388988495, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.026089757680892944, -0.040446147322654724, 0.016625910997390747, -0.06044696271419525, 0.05026254057884216, -0.021860748529434204, -0.009792506694793701, 0.06130346655845642, -0.05351586639881134, -0.05926525592803955, -0.02451053261756897, 0.009171128273010254, 0.008893996477127075, 0.011868655681610107, 0.04063501954078674, -0.033051103353500366, 0.03971649706363678, -0.003199264407157898, 0.05400268733501434, -0.018274664878845215, 0.017529502511024475, 0.04050494730472565, -0.0197765976190567, 0.0191190242767334, -0.059699878096580505, -0.013290226459503174, ...],
        ...
      ]
    >
  },
  "dense_15" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.032438941299915314, 0.03649812936782837, -0.003923939075320959, 0.036131374537944794, -2.232150873169303e-4, -0.02680134028196335, -0.022355681285262108, -0.020713115110993385, -0.023222265765070915, -0.02502395212650299, -0.037389472126960754, 0.021023863926529884, -0.026890335604548454, 0.01821730099618435, -0.0014836805639788508, 0.019815966486930847, -0.019417893141508102, -0.023115478456020355, -0.03618929535150528, 0.029499731957912445, -0.023682503029704094, -0.02980787865817547, 0.007050505373626947, 0.011878293938934803, 0.03082136996090412, -0.035970304161310196, 3.776703088078648e-4, ...],
        ...
      ]
    >
  },
  "dense_9" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.020436709746718407, -0.0379643589258194, -0.019004514440894127, 0.03646769002079964, 0.03305423632264137, -0.02296075038611889, -0.03733103349804878, -0.03447183594107628, -0.027706434950232506, -0.002958945231512189, -0.028149612247943878, -0.0035864447709172964, -0.012770174071192741, 0.03061867319047451, -0.031384896486997604, -0.0070895785465836525, -0.006281273439526558, 0.023142922669649124, 0.03576534241437912, 0.03355271741747856, 0.024814363569021225, 0.024204298853874207, -0.02483474835753441, 0.026397593319416046, 0.014722544699907303, 0.027882225811481476, ...],
        ...
      ]
    >
  },
  "normalization_8" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_10" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.037321850657463074, 0.0027015358209609985, -0.05850476026535034, 0.018639564514160156, -0.060859352350234985, -0.053807973861694336, 0.004476219415664673, -0.04803529381752014, -0.03940926492214203, 8.370578289031982e-4, -0.022368773818016052, -0.06079511344432831, -0.005477592349052429, -0.01383201777935028, 0.00915960967540741, -0.010526120662689209, -0.022700220346450806, 0.0581853985786438, 0.0140390545129776, -0.022336676716804504, 0.05976895987987518, -0.027719229459762573, 0.0035603344440460205, 0.053518712520599365, -0.015479490160942078, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.040289491415023804, 0.057741954922676086, -0.011750414967536926, -0.046655043959617615, 0.012294203042984009, -0.022581711411476135, -0.048759788274765015, 0.014177143573760986, 0.007975131273269653, -0.007280498743057251, -0.04563623666763306, -2.446174621582031e-4, 0.0577559769153595, 0.006646156311035156, -0.044609323143959045, -0.04842953383922577, -0.036534637212753296, -0.0376274436712265, 0.036296069622039795, 0.0012816935777664185, 0.01672542095184326, -0.04868960380554199, -0.022868692874908447, -0.012826845049858093, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.029218211770057678, 0.0539281964302063, -0.04343624413013458, -0.008074179291725159, 0.03951875865459442, -0.019812554121017456, -0.03159089386463165, -0.04887785017490387, 0.002741783857345581, 0.05355218052864075, 0.033223241567611694, 0.022561028599739075, -0.014021962881088257, 0.04141530394554138, 0.051442548632621765, -0.036560460925102234, 0.005666598677635193, 0.056872159242630005, -0.014608994126319885, -0.04869993031024933, -0.009834975004196167, -0.05478323996067047, 0.02610105276107788, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.013788864016532898, 0.018146023154258728, 0.051189959049224854, -0.002169981598854065, -0.048365235328674316, -0.02318693697452545, -0.005364611744880676, -0.020872995257377625, -0.04682932794094086, 0.024223119020462036, -0.001954376697540283, 0.0307769775390625, -0.04116518795490265, -0.01797768473625183, 0.057260170578956604, -0.04468195140361786, 0.0021699517965316772, -0.01002289354801178, -0.004810646176338196, 0.0040158480405807495, -0.0344107449054718, -0.046755969524383545, ...],
        ...
      ]
    >
  },
  "causal_attention_8" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.03787374496459961, -0.04748013615608215, 0.009660586714744568, 0.039020463824272156, -0.036205992102622986, -0.015598490834236145, -0.036834731698036194, 0.05476512014865875, -0.018453747034072876, 0.030438482761383057, -0.025715991854667664, 0.010151520371437073, -0.0354621559381485, 0.011766821146011353, -0.037548065185546875, 0.03791961073875427, 0.003928542137145996, 0.027588486671447754, -0.022520676255226135, 0.05944724380970001, 0.009360909461975098, -0.0020271986722946167, -0.002756863832473755, 0.05104468762874603, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.0446426123380661, 0.06190474331378937, -0.010013341903686523, 0.013226494193077087, -0.03333577513694763, 0.057821452617645264, -0.005026504397392273, -0.024243950843811035, -0.029694199562072754, -0.05315878987312317, 0.016660556197166443, -0.055919796228408813, -0.022972270846366882, 0.038437098264694214, 0.019830524921417236, -0.006191760301589966, -0.05791124701499939, -0.056481972336769104, -0.01375626027584076, -0.044888511300086975, -0.04319790005683899, 0.029341578483581543, 0.04664739966392517, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.034531787037849426, 0.0594770610332489, -0.057946234941482544, 0.03379034996032715, -0.06057041883468628, -0.020668447017669678, 0.025117143988609314, -0.04090839624404907, 0.025518164038658142, -6.037354469299316e-4, 0.05322335660457611, 0.02600027620792389, -0.058153584599494934, -0.06190285086631775, 0.02267526090145111, -0.013153016567230225, -0.04895910620689392, 0.030276909470558167, 0.006028473377227783, -0.02028815448284149, 0.011408403515815735, 0.042751654982566833, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.003915905952453613, -0.004672065377235413, -0.006892219185829163, -0.0030771642923355103, 0.03323589265346527, -0.057723790407180786, -0.02301366627216339, 0.05571252107620239, -0.012173712253570557, -9.615719318389893e-5, 0.04869668185710907, -0.04045607149600983, -0.004608318209648132, 0.04267892241477966, 0.04858766496181488, -0.03616122901439667, -0.006824299693107605, 0.033615320920944214, 0.03607034683227539, -0.005225852131843567, 0.014217853546142578, ...],
        ...
      ]
    >
  },
  "normalization_12" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.032709091901779175, 0.03732253238558769, -0.025894438847899437, -0.03413146734237671, -0.01966562122106552, -0.007542341947555542, 0.03710388019680977, -8.609495707787573e-4, -0.022763961926102638, 0.028750084340572357, 0.008001541718840599, 0.01947963237762451, 0.018123358488082886, 0.012139761820435524, 0.0361517034471035, 0.027289606630802155, 0.03473292663693428, 0.0154619961977005, -0.014276831410825253, 0.021370546892285347, -0.03366856276988983, ...],
        ...
      ]
    >
  },
  "dense_18" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.027857327833771706, -0.03251894563436508, -0.02887868881225586, 0.0272236168384552, -0.01944740116596222, 0.003230544738471508, -0.02588680572807789, 0.033160071820020676, -0.01840881258249283, 0.0314878411591053, -0.03053002804517746, -0.0382646806538105, 0.005344354547560215, -0.005555506329983473, 0.012634086422622204, 0.0238691046833992, 0.030252980068325996, 0.033063795417547226, 0.03644771873950958, 0.006975817494094372, ...],
        ...
      ]
    >
  },
  "dense_5" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.029138261452317238, 0.01194892916828394, -0.013165006414055824, 0.024874867871403694, -0.018829211592674255, -0.03778839483857155, 0.03692765161395073, -0.028081145137548447, -0.024628592655062675, -0.03685155138373375, -0.011712840758264065, 0.0015483596362173557, 0.01506684347987175, -0.016104310750961304, 8.093042997643352e-4, -0.003082507522776723, 0.009872175753116608, 0.010008074343204498, 0.015622171573340893, ...],
        ...
      ]
    >
  },
  "dense_11" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.0056809913367033005, 0.008367394097149372, -0.02999882586300373, -0.035942114889621735, -0.03549979254603386, 0.028921738266944885, 0.01543259248137474, -0.006477035582065582, -0.007187940180301666, -0.028930341824889183, 0.010264011099934578, -0.020256366580724716, 0.023492414504289627, -0.013146600686013699, -0.01646241545677185, 0.01911110244691372, 0.023302627727389336, 0.019369302317500114, ...],
        ...
      ]
    >
  },
  "normalization_18" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_6" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.01873108558356762, 0.005562593229115009, -0.0030459221452474594, -0.030054861679673195, 0.02643149346113205, 0.00782222580164671, -0.028738228604197502, -0.025134602561593056, 0.030945252627134323, -0.008381690829992294, -0.005828952882438898, -0.02394387684762478, 0.029188040643930435, 0.004821897950023413, -0.001552723115310073, -0.039203014224767685, ...],
        ...
      ]
    >
  },
  "normalization_15" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_23" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_10" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.033126842230558395, 0.03167019039392471, 0.009176330640912056, -0.0015245538670569658, 0.015040370635688305, -0.028125910088419914, -0.0030003273859620094, 0.03882477059960365, -0.03250875696539879, -0.0023387304972857237, 0.027392199262976646, 0.0018461777362972498, -0.02953476272523403, ...],
        ...
      ]
    >
  },
  "dense_13" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.021229662001132965, -0.023160215467214584, -0.02362915128469467, 0.032641906291246414, 0.035696517676115036, -7.597889052703977e-4, -0.030600719153881073, 0.025973094627261162, -0.019005192443728447, 0.025737296789884567, 0.01352191437035799, 0.01484698336571455, ...],
        ...
      ]
    >
  },
  "dense_21" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.02057657577097416, -0.017149468883872032, -0.0011769941775128245, -0.030756918713450432, -0.015738533809781075, -0.010669040493667126, 0.03125642612576485, -0.023122066631913185, 0.0024515208788216114, -0.012312764301896095, -0.02435280755162239, ...],
        ...
      ]
    >
  },
  "dense_8" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.003226200118660927, -0.02678207866847515, -0.009807600639760494, 0.002285153139382601, 0.02261175774037838, -0.025931118056178093, -0.006929176859557629, 0.01971873641014099, 0.016117995604872704, -0.01714867725968361, ...],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.0390324629843235, 0.03699810057878494, -0.03239751234650612, 0.03598751127719879, 0.01669354736804962, -0.024352185428142548, 0.032833945006132126, 0.013367421925067902, 0.037017881870269775, ...],
        ...
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.03536538407206535, 0.03689997270703316, -0.017842505127191544, 0.03011380136013031, -4.258945700712502e-4, -0.02284603752195835, 0.00632682116702199, 0.038707293570041656, ...],
        ...
      ]
    >
  },
  "causal_attention_3" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05904276669025421, -0.03645724058151245, 0.026938289403915405, -0.022431835532188416, 0.05546276271343231, 0.006583884358406067, 0.019741937518119812, -0.06022912263870239, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0072243064641952515, 0.026500403881072998, 0.03577415645122528, 0.05929942429065704, -0.06033988296985626, -0.02542218565940857, 0.06173589825630188, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05389295518398285, -0.05284495651721954, -0.006267979741096497, 0.06011173129081726, -0.043416768312454224, 0.03732989728450775, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.038724884390830994, -0.003134027123451233, -0.04916900396347046, -0.004562869668006897, 0.02397778630256653, ...],
        ...
      ]
    >
  },
  "normalization_21" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_5" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [2.316981554031372e-4, -0.043995022773742676, 0.007348924875259399, -0.02322709560394287, 0.0014031678438186646, -0.04507668316364288, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05333171784877777, -0.004592582583427429, -0.0039058923721313477, -0.004979774355888367, -0.03384046256542206, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.019870340824127197, 0.052950695157051086, 0.031820833683013916, 0.04108205437660217, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05948205292224884, 0.04784397780895233, -0.05489179491996765, ...],
        ...
      ]
    >
  },
  "dense_24" => %{
    "kernel" => #Nx.Tensor<
      f32[768][50257]
      EXLA.Backend
      [
        [-9.80013792286627e-5, -3.2665746402926743e-4, -0.010067851282656193, 0.004079081583768129, 0.008285093121230602, ...],
        ...
      ]
    >
  },
  "dense_14" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.02014617621898651, 0.019741147756576538, 0.0183598343282938, ...],
        ...
      ]
    >
  },
  "causal_attention_4" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.008819535374641418, -0.04820533096790314, -0.04652838408946991, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.033999964594841, 0.019834622740745544, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.028313085436820984, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_10" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >
  },
  "dense_7" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_22" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [...]
    >,
    ...
  },
  "causal_attention_2" => %{...},
  ...
}

5.1.1 Using GPT to generate text

Generating text involves encoding text into token IDs that the LLM processes into logit vectors. The logit vectors are then converted back into token IDs, which are detokenized into a text representation.

# Encode the prompt into a tensor of token IDs; the match asserts an {:ok, _} result.
{:ok, input} = MyGPT.text_to_token_ids(tokenizer, "I know everything")
{:ok,
 #Nx.Tensor<
   s64[1][3]
   [
     [40, 1440, 4395]
   ]
 >}
# Nx indexing example on a {2, 3, 4} tensor filled with 0..23.
example = Nx.iota({2,3,4}) |> IO.inspect()
# Chained access: first matrix, first row, every column (the result is discarded here).
example[0][0][..]
# List access with one entry per axis: first matrix, all rows, column 0.
# As the last expression, this is the value the cell displays.
example[[0,..,0]]
#Nx.Tensor<
  s64[2][3][4]
  [
    [
      [0, 1, 2, 3],
      [4, 5, 6, 7],
      [8, 9, 10, 11]
    ],
    [
      [12, 13, 14, 15],
      [16, 17, 18, 19],
      [20, 21, 22, 23]
    ]
  ]
>
#Nx.Tensor<
  s64[3]
  [0, 4, 8]
>
# Extend the 3-token prompt by 6 generated tokens (last argument), giving a {1, 9} tensor.
token_ids = MyGPT.generate_tokens(predict_fn, params, input, 6)
#Nx.Tensor<
  s64[1][9]
  EXLA.Backend
  [
    [40, 1440, 4395, 34633, 15265, 32282, 42655, 32282, 32282]
  ]
>
# Decode the generated token IDs back into a string and print it.
{:ok, text} = MyGPT.token_ids_to_text(tokenizer, token_ids)
IO.puts(text)
I know everything.tm votre.Number defects.Number.Number
:ok

5.1.2 Calculating the text generation loss

# Two 3-token input sequences, encoded together as one {2, 3} batch.
texts = ["every effort moves", "I really like"]
# NOTE(review): with a list argument the result is bound directly, without an
# {:ok, _} match — the displayed value is a bare tensor.
inputs = MyGPT.text_to_token_ids(tokenizer, texts)
#Nx.Tensor<
  s64[2][3]
  [
    [30115, 5149, 11031],
    [40, 2216, 1093]
  ]
>
# Targets are the inputs shifted by one token: the label at each position is the
# token that follows it.
# `texts` is rebound here; `inputs` was already computed from its previous value.
texts = [" effort moves you", " really like chocolate"]
targets = MyGPT.text_to_token_ids(tokenizer, texts)
#Nx.Tensor<
  s64[2][3]
  [
    [5149, 11031, 499],
    [2216, 1093, 18414]
  ]
>
# Forward pass: one logit vector over the 50257-token vocabulary for every
# position of every sequence, i.e. a {2, 3, 50257} tensor.
logits = predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][3][50257]
  EXLA.Backend
  [
    [
      [-0.01316186785697937, 0.019357509911060333, 0.014704778790473938, 0.020415201783180237, -0.01461111381649971, 0.04169977456331253, -0.02344387397170067, 0.3510170578956604, 0.1543886363506317, -0.16629239916801453, 0.07834135740995407, 0.11583898216485977, 0.7019006609916687, -0.12741434574127197, 0.21520379185676575, -0.07688876986503601, -0.04020526260137558, 0.022352930158376694, 0.023782581090927124, 0.15092983841896057, -0.06123122572898865, 0.017284438014030457, 0.06738126277923584, 0.07279197871685028, 0.03316458314657211, -0.17335070669651031, 0.23228216171264648, 0.11822329461574554, 0.23295998573303223, 0.41274312138557434, -6.950795650482178e-4, -0.14871369302272797, 0.054689787328243256, -0.12075653672218323, -0.02408362552523613, 0.31010890007019043, -0.0019266977906227112, 0.09670630842447281, -0.09819215536117554, -0.1528632938861847, -0.28630802035331726, 0.13456755876541138, 0.014362037181854248, -0.06235301122069359, -0.1348976194858551, 0.017254939302802086, 0.09296403080224991, 0.09324396401643753, -0.0334462970495224, -0.006866224110126495, ...],
      ...
    ],
    ...
  ]
>
# Pick the most likely token ID at each position.
# Softmax is order-preserving, so argmax over the raw logits would select the same IDs.
predicted_new_token =
  logits
  |> Axon.Layers.softmax(axis: -1)
  |> Nx.argmax(axis: -1)
#Nx.Tensor<
  s64[2][3]
  EXLA.Backend
  [
    [26389, 46140, 19850],
    [2293, 34633, 48547]
  ]
>
# Decode the first sequence's targets and predictions so they can be compared as text.
{:ok, targets_text} = MyGPT.token_ids_to_text(tokenizer, targets[0])
{:ok, outputs_text} = MyGPT.token_ids_to_text(tokenizer, predicted_new_token[0])

# IO.inspect returns its argument, so the value the cell displays is outputs_text.
IO.inspect(targets_text, label: "Targets batch 1")
IO.inspect(outputs_text, label: "Outputs batch 1")
Targets batch 1: " effort moves you"
Outputs batch 1: " >\n\nrale/null"
" >\n\nrale/null"

Before training, the model produces random next-token probability vectors. The goal of model training is to ensure that the probability values corresponding to the highlighted target token IDs are maximized.

# Turn the logits into a probability distribution over the vocabulary at each position.
probas = Axon.Layers.softmax(logits, axis: -1)
#Nx.Tensor<
  f32[2][3][50257]
  EXLA.Backend
  [
    [
      [1.9341214283485897e-5, 1.9980516299256124e-5, 1.9887769667548127e-5, 2.000166023208294e-5, 1.9313203665660694e-5, 2.043195127043873e-5, 1.914336644404102e-5, 2.7838423193315975e-5, 2.286914605065249e-5, 1.6595104170846753e-5, 2.11944952752674e-5, 2.200432754762005e-5, 3.953952545998618e-5, 1.7252996258321218e-5, 2.4303100872202776e-5, 1.8147111404687166e-5, 1.8825172446668148e-5, 2.004045745707117e-5, 2.006912836804986e-5, 2.2790185539633967e-5, 1.843348582042381e-5, 1.9939137928304262e-5, 2.0963469069101848e-5, 2.107720501953736e-5, 2.025830326601863e-5, 1.6478383258800022e-5, 2.4721721274545416e-5, 2.2056856323615648e-5, 2.4738485080888495e-5, 2.9610921046696603e-5, 1.9583847461035475e-5, 1.6889403923414648e-5, 2.0699091692222282e-5, 1.736824560794048e-5, 1.9131122826365754e-5, 2.6722582333604805e-5, 1.9559740394470282e-5, 2.1587326045846567e-5, 1.7764605217962526e-5, 1.681946378084831e-5, 1.4718307284056209e-5, 2.2420319510274567e-5, 1.988095391425304e-5, 1.8412818462820724e-5, 1.7124368241638876e-5, 1.9938550394726917e-5, 2.1506693883566186e-5, 2.1512714738491923e-5, 1.8952840036945418e-5, 1.9463363059912808e-5, ...],
      ...
    ],
    ...
  ]
>
# Nx.take_along_axis demo: with axis: 1, each row of the index tensor selects a
# column from the matching row of `t`. Index 0 for all three rows returns the
# first element of each row.
t = Nx.iota({3, 1000}, type: :s64)
Nx.take_along_axis(t, Nx.tensor([[0], [0], [0]]), axis: 1)
#Nx.Tensor<
  s64[3][1]
  [
    [0],
    [1000],
    [2000]
  ]
>
# Look up the probability the model assigns to each target token.
# The targets are reshaped to {3, 1} so that take_along_axis picks one vocabulary
# entry per position; the {3, 1} result is then reshaped back to {3}.
text_index = 0
target_1 = Nx.reshape(targets[text_index], {3,1}) |> IO.inspect()
target_probas_1 = Nx.take_along_axis(probas[text_index], target_1, axis: 1) |> Nx.reshape({3})

# Same lookup for the second sequence in the batch.
text_index = 1
target_2 = Nx.reshape(targets[text_index], {3,1}) |> IO.inspect()
target_probas_2 = Nx.take_along_axis(probas[text_index], target_2, axis: 1) |> Nx.reshape({3})

IO.puts("Text 1: #{inspect(Nx.to_flat_list(target_probas_1))}")
IO.puts("Text 2: #{inspect(Nx.to_flat_list(target_probas_2))}")
#Nx.Tensor<
  s64[3][1]
  [
    [5149],
    [11031],
    [499]
  ]
>
#Nx.Tensor<
  s64[3][1]
  [
    [2216],
    [1093],
    [18414]
  ]
>
Text 1: [1.9279423213447444e-5, 1.6736576071707532e-5, 2.2360376533470117e-5]
Text 2: [1.896309549920261e-5, 1.7693790141493082e-5, 2.41508059843909e-5]
:ok

The goal of training an LLM is to maximize the likelihood of the correct token, which involves increasing its probability relative to other tokens. This way, we ensure the LLM consistently picks the target token—essentially the next word in the sentence—as the next token it generates.

Backpropagation requires a loss function, which calculates the difference between the model’s predicted output (here, the probabilities corresponding to the target token IDs) and the actual desired output. This loss function measures how far off the model’s predictions are from the target values.

# Working with logarithms of probability scores is more manageable in mathematical
# optimization than handling the scores directly.

# Join the target probabilities of both sequences into one {6} tensor and take
# the natural logarithm of each.
log_probas =
  [target_probas_1, target_probas_2]
  |> Nx.concatenate()
  |> Nx.log()
#Nx.Tensor<
  f32[6]
  EXLA.Backend
  [-10.85647201538086, -10.99791431427002, -10.708220481872559, -10.873015403747559, -10.942296981811523, -10.631193161010742]
>
# Average the log probabilities over all six target positions into a scalar.
avg_log_probas = Nx.mean(log_probas)
#Nx.Tensor<
  f32
  EXLA.Backend
  -10.834851264953613
>
# Flip the sign so the quantity is positive and lower is better.
neg_avg_log_probas = Nx.multiply(avg_log_probas, -1)
#Nx.Tensor<
  f32
  EXLA.Backend
  10.834851264953613
>

The goal is to get the average log probability as close to 0 as possible by updating the model’s weights as part of the training process.

However, the common practice isn’t to push the average log probability up to 0 but rather to bring the negative average log probability down to 0. The negative average log probability is simply the average log probability multiplied by –1.

In deep learning, this negative average log probability is known as the cross entropy loss.

The cross entropy loss is a popular measure in machine learning and deep learning that quantifies the difference between two probability distributions.

# Print the shapes before flattening.
# NOTE(review): IO.inspect appends ": " after the label, so the trailing colon in
# these labels shows up doubled ("::") in the printed output.
IO.inspect(logits.shape, label: "Logits shape:")
IO.inspect(targets.shape, label: "Targets shape:")
Logits shape:: {2, 3, 50257}
Targets shape:: {2, 3}
{2, 3}
# Merge the batch and sequence axes so each row is one prediction:
# {2, 3, 50257} -> {6, 50257}.
logits_flat = Nx.flatten(logits, axes: [0,1])
# Flatten the targets to line up with those rows: {2, 3} -> {6}.
targets_flat = Nx.flatten(targets)
#Nx.Tensor<
  s64[6]
  [5149, 11031, 499, 2216, 1093, 18414]
>
# Print the shapes after flattening.
# NOTE(review): IO.inspect appends ": " after the label, so the trailing colon in
# these labels shows up doubled ("::") in the printed output.
IO.inspect(logits_flat.shape, label: "Flatten Logits shape:")
IO.inspect(targets_flat.shape, label: "Flatten Targets shape:")
Flatten Logits shape:: {6, 50257}
Flatten Targets shape:: {6}
{6}
# Small hand-made example: three integer class labels and three rows of
# predicted class probabilities.
# NOTE(review): neither binding is used by the loss computation that follows,
# which works on targets_flat / logits_flat — confirm whether this cell is still needed.
y_true = Nx.tensor([0, 2, 1])
y_pred = Nx.tensor([[0.2, 0.8, 0.0], [0.1, 0.2, 0.7], [0.1, 0.2, 0.7]])
#Nx.Tensor<
  f32[3][3]
  [
    [0.20000000298023224, 0.800000011920929, 0.0],
    [0.10000000149011612, 0.20000000298023224, 0.699999988079071],
    [0.10000000149011612, 0.20000000298023224, 0.699999988079071]
  ]
>
# Cross entropy between the flattened logits and their integer target ids,
# reduced to one scalar with the mean.
loss =
  Axon.Losses.categorical_cross_entropy(targets_flat, logits_flat,
    sparse: true,
    from_logits: true,
    reduction: :mean
  )
#Nx.Tensor<
  f32
  EXLA.Backend
  10.834851264953613
>

Perplexity

Perplexity is a measure often used alongside cross entropy loss to evaluate the performance of models in tasks like language modeling. It can provide a more interpretable way to understand the uncertainty of a model in predicting the next token in a sequence.

Perplexity measures how well the probability distribution predicted by the model matches the actual distribution of the words in the dataset. Similar to the loss, a lower perplexity indicates that the model predictions are closer to the actual distribution.

Perplexity is often considered more interpretable than the raw loss value because it signifies the effective vocabulary size about which the model is uncertain at each step. In the given example, this would translate to the model being unsure about which among roughly 50,759 tokens to generate as the next token (the perplexity value computed below).

# Perplexity is computed here as the exponential of the cross entropy loss.
perplexity = Nx.exp(loss)
#Nx.Tensor<
  f32
  EXLA.Backend
  50759.359375
>

5.1.3 Calculating the training and validation set losses

# NOTE(review): absolute, machine-specific path — point it at your own copy of
# the-verdict.txt before running this cell.
path = 
  "/home/alde/Documents/MyDevelopment/Build_A_Large_Language_Model/the-verdict.txt"
# Assertive match: the cell fails with a MatchError if the file cannot be read.
{:ok, raw_text} = File.read(path)
{:ok,
 "I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and established himself in a villa on the Riviera. (Though I rather thought it would have been Rome or Florence.)\n\n\"The height of his glory\"--that was what the women called it. I can hear Mrs. Gideon Thwing--his last Chicago sitter--deploring his unaccountable abdication. \"Of course it's going to send the value of my picture 'way up; but I don't think of that, Mr. Rickham--the loss to Arrt is all I think of.\" The word, on Mrs. Thwing's lips, multiplied its _rs_ as though they were reflected in an endless vista of mirrors. And it was not only the Mrs. Thwings who mourned. Had not the exquisite Hermia Croft, at the last Grafton Gallery show, stopped me before Gisburn's \"Moon-dancers\" to say, with tears in her eyes: \"We shall not look upon its like again\"?\n\nWell!--even through the prism of Hermia's tears I felt able to face the fact with equanimity. Poor Jack Gisburn! The women had made him--it was fitting that they should mourn him. Among his own sex fewer regrets were heard, and in his own trade hardly a murmur. Professional jealousy? Perhaps. If it were, the honour of the craft was vindicated by little Claude Nutley, who, in all good faith, brought out in the Burlington a very handsome \"obituary\" on Jack--one of those showy articles stocked with random technicalities that I have heard (I won't say by whom) compared to Gisburn's painting. And so--his resolve being apparently irrevocable--the discussion gradually died out, and, as Mrs. Thwing had predicted, the price of \"Gisburns\" went up.\n\nIt was not till three years later that, in the course of a few weeks' idling on the Riviera, it suddenly occurred to me to wonder why Gisburn had given up his painting. On reflection, it really was a tempting problem. 
To accuse his wife would have been too easy--his fair sitters had been denied the solace of saying that Mrs. Gisburn had \"dragged him down.\" For Mrs. Gisburn--as such--had not existed till nearly a year after Jack's resolve had been taken. It might be that he had married her--since he liked his ease--because he didn't want to go on painting; but it would have been hard to prove that he had given up his painting because he had married her.\n\nOf course, if she had not dragged him down, she had equally, as Miss Croft contended, failed to \"lift him up\"--she had not led him back to the easel. To put the brush into his hand again--what a vocation for a wife! But Mrs. Gisburn appeared to have disdained it--and I felt it might be interesting to find out why.\n\nThe desultory life of the Riviera lends itself to such purely academic speculations; and having, on my way to Monte Carlo, caught a glimpse of Jack's balustraded terraces between the pines, I had myself borne thither the next day.\n\nI found the couple at tea beneath their palm-trees; and Mrs. Gisburn's welcome was so genial that, in the ensuing weeks, I claimed it frequently. It was not that my hostess was \"interesting\": on that point I could have given Miss Croft the fullest reassurance. It was just because she was _not_ interesting--if I may be pardoned the bull--that I found her so. For Jack, all his life, had been surrounded by interesting women: they had fostered his art, it had been reared in the hot-house of their adulation. And it was therefore instructive to note what effect the \"deadening atmosphere of mediocrity\" (I quote Miss Croft) was having on him.\n\nI have mentioned that Mrs. Gisburn was rich; and it was immediately perceptible that her husband was extracting from this circumstance a delicate but substantial satisfaction. 
It is, as a rule, the people who scorn money who get most out of it; and Jack's elegant disdain of his wife's big balance enabled him, with an appearance of perfect good-breeding, to transmute it into objects of art and luxury. To the latter, I must add, he remained relatively indifferent; but he was buying Renaissance bronzes and eight" <> ...}
# Tokenize the whole text with the encoding selected by `tokenizer`.
# NOTE(review): `tokenizer` is bound to "gpt-3.5-turbo" earlier in the notebook,
# and ids printed in later outputs (e.g. 57896, 58863) are larger than the
# vocab_size of 50257 in gpt_config_124m — confirm the tokenizer and the
# embedding size are meant to agree.
{:ok, ids} = Tiktoken.encode(tokenizer, raw_text)
{:ok,
 [40, 473, 1846, 2744, 3463, 7762, 480, 285, 22464, 4856, 264, 12136, 35201, 313, 4636, 264, 1695,
  12637, 3403, 313, 708, 433, 574, 912, 2294, 13051, 311, 757, 311, 6865, 430, 11, 304, 279, 2673,
  315, 813, 27025, 11, 568, 1047, 12504, 813, 19354, 11, 12502, 264, 9257, ...]}
# Report the size of the corpus in characters and in tokens.
IO.inspect(String.length(raw_text), label: "Characters")
IO.inspect(length(ids), label: "Tokens")
Characters: 20479
Tokens: 4943
4943

When preparing the data loaders, we split the input text into training and validation set portions. Then we tokenize the text and divide the tokenized text into chunks of a user-specified length. Finally, we shuffle the rows and organize the chunked text into batches, which we can use for model training.

# Share of the characters assigned to the training portion.
train_ratio = 0.90

# Character offset at which the text is cut into the two portions.
split_index =
  raw_text
  |> String.length()
  |> Kernel.*(train_ratio)
  |> floor()

{train_data, validation_data} = String.split_at(raw_text, split_index)

IO.inspect(String.length(train_data), label: "Train Characters")
IO.inspect(String.length(validation_data), label: "Validation Characters")
Train Characters: 18431
Validation Characters: 2048
2048

We split the input text into training and validation set portions. Then we tokenize the text and divide the tokenized text into chunks of a user-specified length. Finally, we shuffle the rows and organize the chunked text into batches, which we can use for model training.

For simplicity, every chunk built here has the same length. In practice, however, it can also be beneficial to train an LLM with variable-length inputs, which helps it generalize better across the different types of inputs it sees when it is being used.

defmodule MyGPT.DatasetV1 do
  @moduledoc """
  Builds next-token-prediction samples from raw text with a sliding window.
  """

  @doc """
  Tokenizes `txt` and slices the token ids into input/target windows.

  ## Parameters

    * `txt` - the text to tokenize.
    * `tokenizer_model` - model name handed to `Tiktoken.encode/2`.
    * `max_length` - number of tokens in every window.
    * `stride` - distance, in tokens, between the starts of two consecutive
      windows.

  ## Returns

  A map `%{input_ids: [...], target_ids: [...]}` holding two lists of
  one-dimensional tensors of `max_length` tokens each. Every target window is
  its input window shifted one token to the right. Both lists are empty when
  the text has no more than `max_length` tokens.

  Raises a `MatchError` if `Tiktoken.encode/2` does not return `{:ok, ids}`.
  """
  def build(txt, tokenizer_model, max_length, stride) do
    {:ok, token_ids} = Tiktoken.encode(tokenizer_model, txt)

    token_ids_tensor = Nx.tensor(token_ids)

    # Largest start index whose target window (input shifted by one) still
    # fits inside the token list: i + max_length <= length - 1.
    last_start = length(token_ids) - max_length - 1

    # The explicit `//1` step keeps the range empty, instead of descending,
    # when the text is shorter than one window. Every start index produced
    # here is valid, so no per-window bounds check is needed and the last
    # window is no longer dropped.
    {input_ids, target_ids} =
      0..last_start//1
      |> Enum.take_every(stride)
      |> Enum.map(fn i ->
        input_chunk = token_ids_tensor[i..(i + max_length - 1)]
        target_chunk = token_ids_tensor[(i + 1)..(i + max_length)]
        {input_chunk, target_chunk}
      end)
      |> Enum.unzip()

    %{input_ids: input_ids, target_ids: target_ids}
  end
end
{:module, MyGPT.DatasetV1, <<70, 79, 82, 49, 0, 0, 12, ...>>, {:build, 4}}
# Window length comes from the model configuration.
context_length = gpt_config_124m[:context_length]
# max_length and stride are both context_length, so consecutive windows start
# exactly one window apart and do not overlap.
train_dataset = MyGPT.DatasetV1.build(train_data, tokenizer, context_length, context_length)
validation_dataset = MyGPT.DatasetV1.build(validation_data, tokenizer, context_length, context_length)
%{
  input_ids: [
    #Nx.Tensor<
      s64[256]
      [361, 6, 29368, 1093, 264, 3838, 315, 7563, 13, 1283, 3287, 956, 21423, 261, 11, 499, 3619, 11, 8009, 4610, 3023, 313, 383, 1120, 11203, 1070, 30666, 10307, 11, 323, 389, 813, 23726, 11, 1555, 279, 18004, 48788, 11, 358, 9508, 311, 6865, 279, 3488, 25, 364, 11787, ...]
    >,
    #Nx.Tensor<
      s64[256]
      [364, 10655, 6, 555, 1063, 832, 1501, 88, 0, 2468, 1176, 358, 574, 16984, 1364, 8434, 956, 1095, 757, 1022, 313, 438, 520, 856, 289, 1220, 6, 842, 358, 12090, 2895, 58863, 13, 7566, 11, 433, 574, 358, 889, 3940, 2895, 58863, 25, 358, 3309, 18083, 13, ...]
    >
  ],
  target_ids: [
    #Nx.Tensor<
      s64[256]
      [6, 29368, 1093, 264, 3838, 315, 7563, 13, 1283, 3287, 956, 21423, 261, 11, 499, 3619, 11, 8009, 4610, 3023, 313, 383, 1120, 11203, 1070, 30666, 10307, 11, 323, 389, 813, 23726, 11, 1555, 279, 18004, 48788, 11, 358, 9508, 311, 6865, 279, 3488, 25, 364, 11787, ...]
    >,
    #Nx.Tensor<
      s64[256]
      [10655, 6, 555, 1063, 832, 1501, 88, 0, 2468, 1176, 358, 574, 16984, 1364, 8434, 956, 1095, 757, 1022, 313, 438, 520, 856, 289, 1220, 6, 842, 358, 12090, 2895, 58863, 13, 7566, 11, 433, 574, 358, 889, 3940, 2895, 58863, 25, 358, 3309, 18083, 13, ...]
    >
  ]
}
# Stack the lists of per-window tensors into one tensor each, adding a new
# leading axis that indexes the windows.
input_ids = Nx.stack(train_dataset[:input_ids])
target_ids = Nx.stack(train_dataset[:target_ids])

# Show the first input window next to its target window.
{input_ids[0], target_ids[0]}
{#Nx.Tensor<
   s64[256]
   [40, 473, 1846, 2744, 3463, 7762, 480, 285, 22464, 4856, 264, 12136, 35201, 313, 4636, 264, 1695, 12637, 3403, 313, 708, 433, 574, 912, 2294, 13051, 311, 757, 311, 6865, 430, 11, 304, 279, 2673, 315, 813, 27025, 11, 568, 1047, 12504, 813, 19354, 11, 12502, 264, 9257, 57896, ...]
 >,
 #Nx.Tensor<
   s64[256]
   [473, 1846, 2744, 3463, 7762, 480, 285, 22464, 4856, 264, 12136, 35201, 313, 4636, 264, 1695, 12637, 3403, 313, 708, 433, 574, 912, 2294, 13051, 311, 757, 311, 6865, 430, 11, 304, 279, 2673, 315, 813, 27025, 11, 568, 1047, 12504, 813, 19354, 11, 12502, 264, 9257, 57896, ...]
 >}
# Same stacking as for the training windows, applied to the validation windows.
v_input_ids = Nx.stack(validation_dataset[:input_ids])
v_target_ids = Nx.stack(validation_dataset[:target_ids])

# Show the first validation input window next to its target window.
{v_input_ids[0], v_target_ids[0]}
{#Nx.Tensor<
   s64[256]
   [361, 6, 29368, 1093, 264, 3838, 315, 7563, 13, 1283, 3287, 956, 21423, 261, 11, 499, 3619, 11, 8009, 4610, 3023, 313, 383, 1120, 11203, 1070, 30666, 10307, 11, 323, 389, 813, 23726, 11, 1555, 279, 18004, 48788, 11, 358, 9508, 311, 6865, 279, 3488, 25, 364, 11787, 499, ...]
 >,
 #Nx.Tensor<
   s64[256]
   [6, 29368, 1093, 264, 3838, 315, 7563, 13, 1283, 3287, 956, 21423, 261, 11, 499, 3619, 11, 8009, 4610, 3023, 313, 383, 1120, 11203, 1070, 30666, 10307, 11, 323, 389, 813, 23726, 11, 1555, 279, 18004, 48788, 11, 358, 9508, 311, 6865, 279, 3488, 25, 364, 11787, 499, ...]
 >}
# Split the stacked inputs into batches of one window each.
stream = Nx.to_batched(input_ids, 1)
# Take the first batch from the stream to inspect it.
x = Enum.at(stream, 0)
#Nx.Tensor<
  s64[1][256]
  [
    [40, 473, 1846, 2744, 3463, 7762, 480, 285, 22464, 4856, 264, 12136, 35201, 313, 4636, 264, 1695, 12637, 3403, 313, 708, 433, 574, 912, 2294, 13051, 311, 757, 311, 6865, 430, 11, 304, 279, 2673, 315, 813, 27025, 11, 568, 1047, 12504, 813, 19354, 11, 12502, 264, 9257, 57896, 11, ...]
  ]
>
defmodule MyGPT.Dataset do
  @moduledoc """
  Builds batched, optionally shuffled next-token-prediction samples from raw
  text with a sliding window.
  """

  @doc """
  Tokenizes `txt`, slices it into input/target windows and batches them.

  ## Parameters

    * `txt` - the text to tokenize.
    * `tokenizer_model` - model name handed to `Tiktoken.encode/2`.
    * `max_length` - number of tokens in every window.
    * `stride` - distance, in tokens, between the starts of two consecutive
      windows.
    * `batch_size` - number of windows per batch (default `2`).
    * `is_shuffle?` - when `true` (the default) the windows are shuffled
      before batching; inputs and targets stay paired.

  ## Returns

  A lazy enumerable built with `Stream.zip/2` that yields
  `{input_batch, target_batch}` tuples. Every target window is its input
  window shifted one token to the right.

  Raises a `MatchError` if `Tiktoken.encode/2` does not return `{:ok, ids}`.
  """
  def build(txt, tokenizer_model, max_length, stride, batch_size \\ 2, is_shuffle? \\ true) do
    {:ok, token_ids} = Tiktoken.encode(tokenizer_model, txt)

    token_ids_tensor = Nx.tensor(token_ids)

    # Largest start index whose target window (input shifted by one) still
    # fits inside the token list: i + max_length <= length - 1.
    last_start = length(token_ids) - max_length - 1

    # The explicit `//1` step keeps the range empty, instead of descending,
    # when the text is shorter than one window. Every start index produced
    # here is valid, so no per-window bounds check is needed and the last
    # window is no longer dropped.
    {input_ids, target_ids} =
      0..last_start//1
      |> Enum.take_every(stride)
      |> Enum.map(fn i ->
        input_chunk = token_ids_tensor[i..(i + max_length - 1)]
        target_chunk = token_ids_tensor[(i + 1)..(i + max_length)]
        {input_chunk, target_chunk}
      end)
      |> Enum.unzip()

    {shuffled_input_ids, shuffled_target_ids} =
      try_shuffle_datasets(input_ids, target_ids, is_shuffle?)

    # NOTE(review): when no window fits, both lists are empty and Nx.stack is
    # called on an empty list — confirm how Nx handles that case.
    input_ids_stream =
      shuffled_input_ids
      |> Nx.stack()
      |> Nx.to_batched(batch_size)

    target_ids_stream =
      shuffled_target_ids
      |> Nx.stack()
      |> Nx.to_batched(batch_size)

    Stream.zip(input_ids_stream, target_ids_stream)
  end

  @doc """
  Shuffles `input_ids` and `target_ids` with the same random permutation when
  the third argument is `true`; returns them unchanged for any other value.

  The result is always a `{input_ids, target_ids}` tuple of lists.
  """
  def try_shuffle_datasets(input_ids, target_ids, true) do
    # Zipping first keeps each input paired with its target through the shuffle.
    Enum.zip(input_ids, target_ids)
    |> Enum.shuffle()
    |> Enum.unzip()
  end

  def try_shuffle_datasets(input_ids, target_ids, _is_shuffle?) do
    {input_ids, target_ids}
  end
end
{:module, MyGPT.Dataset, <<70, 79, 82, 49, 0, 0, 17, ...>>, {:try_shuffle_datasets, 3}}
# Window length comes from the model configuration.
context_length = gpt_config_124m[:context_length]

# batch_size and is_shuffle? are left at their defaults (2 and true);
# stride equals max_length, so windows do not overlap.
training_dataset =
  MyGPT.Dataset.build(train_data, tokenizer, context_length, context_length)

validation_dataset =
  MyGPT.Dataset.build(validation_data, tokenizer, context_length, context_length)
#Function<73.53678557/2 in Stream.zip_with/2>
# Pull the first {input_batch, target_batch} tuple out of the lazy stream.
Enum.at(training_dataset, 0)
{#Nx.Tensor<
   s64[2][256]
   [
     [574, 11, 304, 2144, 11, 10671, 279, 893, 315, 279, 4545, 313, 300, 7762, 5678, 11, 832, 2643, 2231, 433, 11, 1047, 1027, 279, 893, 315, 279, 6596, 13, 578, 14992, 10255, 574, 1071, 311, 617, 14454, 5678, 520, 856, 4333, 596, 7693, 11, 323, 358, 31156, 422, 264, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   s64[2][256]
   [
     [11, 304, 2144, 11, 10671, 279, 893, 315, 279, 4545, 313, 300, 7762, 5678, 11, 832, 2643, 2231, 433, 11, 1047, 1027, 279, 893, 315, 279, 6596, 13, 578, 14992, 10255, 574, 1071, 311, 617, 14454, 5678, 520, 856, 4333, 596, 7693, 11, 323, 358, 31156, 422, 264, ...],
     ...
   ]
 >}

We used a relatively small batch size to reduce the computational resource demand because we were working with a very small dataset. In practice, training LLMs with batch sizes of 1,024 or larger is not uncommon.

5.2 Training an LLM

A typical training loop for training deep neural networks consists of numerous steps, iterating over the batches in the training set for several epochs. In each loop, we calculate the loss for each training set batch to determine loss gradients, which we use to update the model weights so that the training set loss is minimized.

The training loop consists of eight steps: iterating over each epoch, processing batches, resetting gradients, calculating the loss and the new gradients, updating the weights, and finally monitoring steps such as printing losses and generating text samples.

More advanced techniques exist as well, including learning rate warmup, cosine annealing, and gradient clipping.

Adam optimizers are a popular choice for training deep neural networks. However, in our training loop, we opt for the AdamW optimizer. AdamW is a variant of Adam that improves the weight decay approach, which aims to minimize model complexity and prevent overfitting by penalizing larger weights. This adjustment allows AdamW to achieve more effective regularization and better generalization; thus, AdamW is frequently used in the training of LLMs.

# Build the model from the 124M configuration (MyGPT.model is defined earlier
# in the notebook).
model = MyGPT.model({nil, nil, 768}, gpt_config_124m)
# Compile the model with EXLA into an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
# Placeholder input of shape {1, 4} and type s64, used only for initialization.
template = Nx.template({1, 4}, :s64)
# Initialize the parameters from the template, starting from an empty map.
params = init_fn.(template, %{})
%{
  "pos_embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[256][768]
      EXLA.Backend
      [
        [0.00835605338215828, -0.001519892131909728, -0.008858468383550644, 0.005154099315404892, 0.003255062038078904, 0.008193318732082844, -0.009548058733344078, -0.0023421000223606825, -0.009752254001796246, -0.005784115754067898, 0.004367899615317583, -0.003659622510895133, -0.008940774947404861, -0.0011999272974207997, 0.0022269010078161955, 0.003976454492658377, 0.002604701556265354, 0.0012910985387861729, -0.003927278332412243, -0.007751474156975746, -0.005583465099334717, 9.162902715615928e-4, -0.006056143902242184, 0.005630659870803356, -0.0012507104547694325, -0.006014976184815168, -0.004163548815995455, 0.008944415487349033, -0.006644425448030233, 0.004301333334296942, 0.004310044925659895, 0.007871363312005997, 0.00951245054602623, 0.008124735206365585, 0.0035520195960998535, 0.002551331417635083, 0.008987133391201496, 0.009850654751062393, 0.009532613679766655, -0.0049337721429765224, -0.0033276937901973724, -9.341287659481168e-4, -0.0026423835661262274, -0.001537594711408019, 0.001408622250892222, -0.009654521942138672, -0.0013557791244238615, 0.009470796212553978, ...],
        ...
      ]
    >
  },
  "normalization_6" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_4" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.030560854822397232, -0.027250023558735847, -0.005801537539809942, 0.037258580327034, 0.010854002088308334, -0.021780598908662796, -0.001860691118054092, -0.013453823514282703, 0.03244549408555031, 0.008650727570056915, -0.014005966484546661, 0.037117090076208115, -0.03246651962399483, 0.006978277582675219, 0.03205595538020134, -0.02779693529009819, 0.02460232563316822, -0.006824481766670942, 0.01477330457419157, -0.02904941886663437, 0.002669166075065732, 0.01315983198583126, -0.014481235295534134, -0.03453880548477173, -0.011130832135677338, -0.0036918087862432003, 0.003542018588632345, 0.028571907430887222, -0.007932443171739578, 0.011332305148243904, -0.017095070332288742, -0.023716920986771584, 0.03030044212937355, 0.0022417071741074324, -0.006023951806128025, -0.0026704478077590466, 0.02095763012766838, -0.004748350940644741, 0.028788648545742035, 0.03126901760697365, 0.026277264580130577, 0.01788645051419735, 0.010327926836907864, -0.01670343428850174, -0.016724515706300735, ...],
        ...
      ]
    >
  },
  "normalization_5" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_20" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.001402650261297822, 0.026030158624053, 0.014029084704816341, -0.022986186668276787, -0.01998908258974552, 0.021588899195194244, -0.018844565376639366, -0.001816170639358461, 0.029988769441843033, -0.026286019012331963, -0.007959961891174316, 0.023773569613695145, 0.009304172359406948, -0.0050235409289598465, 0.032488297671079636, 0.007145059760659933, -0.01527420710772276, 0.03734298422932625, -0.023597825318574905, -0.03240213170647621, 0.006338563747704029, -0.03135897219181061, 0.032794684171676636, 0.01089505385607481, 0.034490156918764114, -0.02628091163933277, -0.01650899089872837, 0.021328486502170563, 0.009490321390330791, 0.014418328180909157, 0.0035156020894646645, -0.0059839640744030476, -0.03701825812458992, -0.029154622927308083, -0.0335594080388546, 0.033675819635391235, -0.02643868327140808, 0.025859285145998, 0.011411563493311405, -0.007240113336592913, -0.02220505103468895, -0.03930899128317833, -0.001734273275360465, ...],
        ...
      ]
    >
  },
  "normalization_1" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_9" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_22" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.032839447259902954, 0.01748879998922348, -0.009435717016458511, -0.0038825287483632565, 0.03255155310034752, -0.007750996388494968, -0.037842925637960434, -0.024921366944909096, 0.035743165761232376, -0.01913507841527462, -0.024204157292842865, -0.003818999510258436, -0.03740227222442627, -0.0053386809304356575, 0.02387901023030281, -0.018220098689198494, -0.026866378262639046, -0.016436658799648285, 0.025900205597281456, 0.03641452640295029, 0.023092973977327347, -0.012555214576423168, 0.02324499748647213, -0.024318654090166092, -0.022717932239174843, 0.026522766798734665, 0.03162361681461334, -0.039043743163347244, -0.023365072906017303, -0.010126953013241291, 0.03757007047533989, -0.0015227820258587599, -0.03565472736954689, -0.01314252894371748, -0.025106405839323997, -0.014796290546655655, 0.020547039806842804, -0.013771630823612213, -0.0367637537419796, -0.03581513091921806, ...],
        ...
      ]
    >
  },
  "causal_attention_0" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.057891085743904114, 0.031700730323791504, -0.06245023012161255, 0.011017337441444397, 0.036618709564208984, 0.005764394998550415, 0.008794963359832764, 0.06204204261302948, 0.0069023966789245605, 0.05000057816505432, 0.04367762804031372, -0.0025152266025543213, 0.024907901883125305, -0.0412302166223526, 0.006332278251647949, 0.044837549328804016, 0.04584498703479767, -0.04577548801898956, -0.04241083562374115, -0.05188211798667908, 0.047340184450149536, -0.013208165764808655, 0.032210513949394226, 0.007262393832206726, 0.018140941858291626, 0.0046993643045425415, 0.011359810829162598, -0.016532212495803833, -0.024260997772216797, 0.044848084449768066, 0.04419344663619995, -0.0027293264865875244, 0.02463291585445404, 0.04433523118495941, 0.058565422892570496, -0.00868140161037445, 0.04050922393798828, -0.0424494594335556, 0.027760088443756104, 0.03202183544635773, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.017343029379844666, -0.05196136236190796, -0.01847122609615326, 0.002339109778404236, -0.04784758388996124, -0.04393137991428375, 0.06190197169780731, -0.047495126724243164, -0.010552316904067993, -0.05884957313537598, 0.04056979715824127, -0.03719775378704071, 0.010772719979286194, 0.04936544597148895, -0.034493789076805115, -0.017754733562469482, 0.030666843056678772, -0.038368821144104004, 0.023557156324386597, -0.049014195799827576, 0.03288586437702179, -0.0314159095287323, 0.019574686884880066, 4.973858594894409e-4, -0.03216104209423065, -9.765028953552246e-4, 0.030408024787902832, 0.03191480040550232, 6.589740514755249e-4, -0.013176977634429932, -0.009991958737373352, -0.011034354567527771, -0.019491925835609436, 0.02053055167198181, 0.03900234401226044, -0.018408194184303284, -0.02395184338092804, -0.027811110019683838, 0.04266554117202759, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.010971829295158386, 0.021520555019378662, 0.030278638005256653, -0.05376383662223816, 0.05611552298069, 0.05886222422122955, 0.042443737387657166, 0.02341116964817047, -0.0358569473028183, 0.022672876715660095, 0.053060173988342285, 0.043153032660484314, -0.010947123169898987, 0.011890217661857605, -0.006963372230529785, -0.02050580084323883, -0.033719345927238464, -0.053895220160484314, -0.053845107555389404, 0.050527408719062805, -0.05709512531757355, -0.034408554434776306, -0.031364426016807556, -0.03923290967941284, 0.023456141352653503, -0.012669429183006287, 0.014847680926322937, 0.052733391523361206, -0.04924841225147247, 0.030778929591178894, -0.03106127679347992, 0.023757562041282654, 0.05307476222515106, 0.040725722908973694, -0.006383955478668213, -0.046185776591300964, -0.006404787302017212, -0.03518342971801758, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04545803368091583, 0.030514270067214966, 0.015245422720909119, -0.018513858318328857, -0.02845446765422821, 0.0065723806619644165, 0.004867807030677795, -0.027993708848953247, -0.03735765814781189, -0.01571004092693329, -0.048844486474990845, 0.039162665605545044, 0.057723283767700195, 0.04536963999271393, -1.0700523853302002e-4, -0.04877130687236786, -0.009924352169036865, 0.058864712715148926, 0.04070502519607544, 0.00615084171295166, -0.05366234481334686, -0.014606103301048279, -0.006354272365570068, 0.02052880823612213, -0.037248238921165466, -0.059652313590049744, -0.046579405665397644, -0.04805208742618561, 0.0341370552778244, 0.02092897891998291, -0.04758661985397339, 0.055146828293800354, -0.05920557677745819, -0.004037842154502869, -0.015403002500534058, 0.021351873874664307, -0.048443153500556946, ...],
        ...
      ]
    >
  },
  "normalization_14" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_7" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_11" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_2" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_12" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.014665423892438412, -0.03423280641436577, 0.007403323892503977, 0.014595966786146164, 0.0350615419447422, -0.02258511632680893, 0.018806103616952896, -0.03908606991171837, 0.0012804921716451645, -0.004309384617954493, 0.003954596351832151, -0.022425411269068718, -0.013981010764837265, -0.037549443542957306, -0.0170531515032053, 0.025982018560171127, 0.03622934967279434, -0.0034016433637589216, -0.02497323974967003, 0.02541649341583252, -0.005502334330230951, -0.03843020275235176, -0.010772820562124252, -0.010291784070432186, 0.022237934172153473, -0.03743870556354523, 0.03818579018115997, -0.0012284604599699378, 0.024996357038617134, -0.026888195425271988, 0.019582366570830345, 0.03952116146683693, -0.015162509866058826, -0.006583756301552057, ...],
        ...
      ]
    >
  },
  "dense_23" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.022730400785803795, 0.01977645978331566, -0.010936511680483818, -0.00504173943772912, 0.01309595350176096, -0.01170544233173132, 0.024080123752355576, -0.02961093932390213, -0.034172262996435165, -0.03458667919039726, -0.015769068151712418, -0.0029777938034385443, -0.03279142081737518, -0.028085971251130104, 0.022212449461221695, 0.0179872065782547, -0.004728908184915781, -0.03420347720384598, 0.015261314809322357, 0.03054223209619522, -0.006784164812415838, 0.03007461503148079, 0.016574302688241005, 0.02377692610025406, 8.609212818555534e-4, -0.027900047600269318, -0.003086107550188899, 0.03739558160305023, -0.012824345380067825, 0.02526100166141987, 0.017382541671395302, -0.029153350740671158, -0.019281957298517227, ...],
        ...
      ]
    >
  },
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[50257][768]
      EXLA.Backend
      [
        [-0.0023072862531989813, 1.9459724717307836e-4, -0.006358785554766655, 0.008084396831691265, 0.004301616922020912, -0.006599075626581907, -0.003064620541408658, 0.009713186882436275, 0.007698776666074991, 0.004165008198469877, 0.009833386167883873, -0.006449816282838583, 0.004132912028580904, -0.00903132651001215, -2.103328733937815e-5, -0.008257054723799229, 0.009116568602621555, -0.009427277371287346, -0.005403261166065931, 0.008155128918588161, 0.003083169460296631, 0.002735474146902561, 0.006756236311048269, -0.003080449067056179, -0.00871496181935072, 0.008054287172853947, 0.005646452773362398, 2.3839234199840575e-4, -0.004616654012352228, 0.0012502741301432252, -0.005863924045115709, -9.558605961501598e-4, -7.085752440616488e-4, ...],
        ...
      ]
    >
  },
  "normalization_19" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_6" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.04372021555900574, -0.05461260676383972, -0.0163973867893219, 0.04845161736011505, 0.05066724121570587, 0.008870065212249756, -0.01659385859966278, 0.004806995391845703, -0.02691030502319336, -0.05004522204399109, 0.005778223276138306, 0.03468617796897888, 0.004931449890136719, 0.011309951543807983, -0.013383910059928894, 0.03893502056598663, -0.040354788303375244, -0.007732570171356201, -0.056161344051361084, -0.005621924996376038, 0.02063313126564026, 0.018917188048362732, -0.004728302359580994, 0.03797869384288788, -0.03348761796951294, 0.015197128057479858, 0.043700337409973145, -0.051165685057640076, 0.05622991919517517, -0.059008076786994934, 0.053691089153289795, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.059276074171066284, -0.03299407660961151, -0.0012415945529937744, -0.007177069783210754, 0.04831878840923309, 0.05697666108608246, 0.029269084334373474, -0.01404900848865509, 0.04893957078456879, -0.06023797392845154, 0.011406168341636658, 0.04313458502292633, 0.048154667019844055, 0.0016535371541976929, 8.409619331359863e-4, -0.03205558657646179, 0.001989513635635376, -0.05353190004825592, -0.0451153963804245, 0.011023402214050293, -0.01085333526134491, 0.039252832531929016, -0.008883729577064514, -0.02736535668373108, -0.011844038963317871, 0.015324264764785767, 0.058133602142333984, 0.03108358383178711, 0.055795446038246155, 0.045146599411964417, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.026181727647781372, -0.03783077001571655, -0.012845233082771301, 0.013555347919464111, -0.0010508596897125244, 0.03903794288635254, -0.020000562071800232, 0.030419468879699707, 0.022168219089508057, 0.012794554233551025, 0.023246660828590393, -0.026461124420166016, 0.060703083872795105, 0.02328266203403473, 0.02249416708946228, 0.028114140033721924, -0.02454233169555664, -0.047661662101745605, 0.0225488543510437, 0.03679265081882477, 0.04346102476119995, -0.02994680404663086, 0.014844462275505066, -0.007950350642204285, -0.02710092067718506, 0.022752895951271057, 0.04393559694290161, 0.0030765533447265625, 0.032603055238723755, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0037907958030700684, -0.03822736442089081, -0.012732520699501038, 0.06033077836036682, -0.0144425630569458, 0.033779218792915344, 0.05273233354091644, -0.0037374943494796753, 0.042464181780815125, 0.02314344048500061, -0.004419773817062378, 0.021412312984466553, 0.04784923791885376, 0.02112235128879547, 0.014909401535987854, 0.02004323899745941, -0.058507248759269714, 0.0091877281665802, 0.05800172686576843, -0.05034913122653961, 0.028179079294204712, -0.04382990300655365, 0.05480410158634186, 0.03657218813896179, 0.03160884976387024, -0.028384283185005188, 0.009273573756217957, -9.101629257202148e-5, ...],
        ...
      ]
    >
  },
  "causal_attention_7" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.006883874535560608, -0.03828051686286926, 0.054902151226997375, 0.018830999732017517, -0.013112828135490417, -0.01686742901802063, 0.02670101821422577, -0.03444002568721771, 0.06172819435596466, 0.05303303897380829, -0.05448821187019348, 0.04233023524284363, 0.045560434460639954, -0.05396512150764465, 0.015350520610809326, 0.04364544153213501, -2.1423399448394775e-4, 0.057442933320999146, 0.003253921866416931, 4.1791796684265137e-4, -0.006593614816665649, 0.04315215349197388, 0.046812936663627625, 0.03249366581439972, 0.021919190883636475, 0.0179244726896286, 0.04509909451007843, 0.007791832089424133, 6.400793790817261e-4, -0.05561529099941254, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05954357981681824, 0.048682764172554016, 0.005183219909667969, 0.04562261700630188, 0.00500151515007019, 0.008882462978363037, -0.028391718864440918, -0.03354524075984955, 0.061869606375694275, -0.008014991879463196, -0.06026732921600342, 0.04004766047000885, -0.020109444856643677, -0.050216346979141235, 0.019568532705307007, 0.02516806125640869, -0.040182650089263916, -0.03924292325973511, -0.05208887159824371, -0.05112527310848236, 0.057812660932540894, -0.046635642647743225, 0.03268757462501526, -0.01102900505065918, 0.02566593885421753, -0.05000424385070801, 0.045891642570495605, -0.022050216794013977, -0.057632461190223694, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.0015802383422851562, -0.015831753611564636, 0.06140318512916565, -0.03871496021747589, 0.02639557421207428, 0.05958820879459381, 0.030300244688987732, -0.022506386041641235, 0.002844870090484619, -0.028374388813972473, -0.02086128294467926, 0.006628826260566711, -0.059636637568473816, -0.0361771285533905, 0.04415379464626312, 0.05345416069030762, -0.010674983263015747, 0.02986505627632141, -0.04020111262798309, -0.02574719488620758, -0.053329259157180786, 0.0365753173828125, 0.0012433677911758423, -0.04737547039985657, -0.05334906280040741, 0.024996846914291382, 0.05564171075820923, 0.03355228900909424, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.06018565595149994, -0.021540343761444092, -0.02803121507167816, -0.023015543818473816, -0.028129562735557556, -0.0021990984678268433, 0.03805480897426605, -0.00612013041973114, -0.05022038519382477, -0.010736629366874695, 0.03362873196601868, -0.03409759700298309, -0.028146818280220032, 0.06124769151210785, 0.013368412852287292, 0.005100950598716736, 0.020238086581230164, 0.032200589776039124, -0.03428545594215393, -0.019676849246025085, 0.02294372022151947, -0.013442158699035645, -0.05963499844074249, -0.007126361131668091, -0.02534639835357666, -0.05563384294509888, 0.03587965667247772, ...],
        ...
      ]
    >
  },
  "causal_attention_1" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.01717303693294525, 0.019940540194511414, 0.050296127796173096, -0.002328544855117798, 0.062323883175849915, 0.05713982880115509, -0.025680601596832275, 0.032320618629455566, -0.060504987835884094, 0.0422329306602478, -0.01661553978919983, 0.03390580415725708, -0.03289118409156799, 0.020389512181282043, 0.05127638578414917, 0.006738871335983276, -0.04529847204685211, 0.0054217129945755005, 0.044831275939941406, -0.009164422750473022, -0.005577683448791504, -0.0045256465673446655, 0.0026355981826782227, -0.0575108528137207, -0.014784306287765503, -0.04753446578979492, -0.015454992651939392, -0.020032942295074463, -0.060292914509773254, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05628634989261627, -9.068399667739868e-4, -0.04147881269454956, 0.036274418234825134, 0.0496169775724411, -0.04580633342266083, 0.012780413031578064, 0.0485074520111084, -0.016994193196296692, -0.036198198795318604, 0.039924561977386475, -0.04549674689769745, 0.01388666033744812, 0.023099541664123535, 0.0017266273498535156, -0.018895357847213745, 0.00807228684425354, -0.020003706216812134, 0.02583657205104828, 0.02245981991291046, -0.03314535319805145, 0.028385132551193237, -0.004530474543571472, 0.0016636103391647339, 0.05638307332992554, -0.025832220911979675, -0.05743655562400818, -0.0198851078748703, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.028159350156784058, -0.02944427728652954, 0.024292007088661194, 0.044402286410331726, 0.028572797775268555, 0.031183436512947083, 0.0137052983045578, 0.02138739824295044, 0.062221676111221313, -0.024520069360733032, 0.04670412838459015, -0.03341507911682129, -0.04777109622955322, 0.0180971622467041, 0.010214030742645264, -0.04504646360874176, -0.01538500189781189, -0.01290634274482727, 0.043160662055015564, 0.05815134942531586, -0.03161633014678955, -0.04230612516403198, -0.016028136014938354, -0.04547730088233948, 0.007249504327774048, 0.026490405201911926, -0.054174184799194336, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.03876812756061554, 0.04680551588535309, -0.05029653012752533, -0.03658084571361542, -0.04938751459121704, -0.01373186707496643, -0.04721260070800781, -0.027223870158195496, -0.03169974684715271, 0.03572802245616913, -0.049826234579086304, -0.03926685452461243, -0.04355897009372711, -0.018307924270629883, -0.0032237619161605835, 0.059096112847328186, 0.003275945782661438, -0.0449407696723938, -0.004061847925186157, 0.03529302775859833, 0.010559722781181335, -0.05994425714015961, -0.05891597270965576, 0.03503052890300751, -0.011939585208892822, -0.03909049928188324, ...],
        ...
      ]
    >
  },
  "dense_15" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.018354622647166252, -0.02974904328584671, 0.009677497670054436, -0.018656643107533455, 0.03757061809301376, -0.026423633098602295, -0.032439377158880234, 0.010218679904937744, 0.012282201088964939, -0.029823051765561104, -0.015523687936365604, -0.015777381137013435, -0.013000061735510826, 0.004265165887773037, 0.02302444912493229, -0.007102932780981064, -0.03913697972893715, 0.020981576293706894, -6.801345152780414e-4, 0.015887051820755005, 0.0015246951952576637, -0.02159043587744236, 0.0035735052078962326, -0.01733158342540264, -0.02059157006442547, 0.0028385117184370756, -0.014677006751298904, ...],
        ...
      ]
    >
  },
  "dense_9" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.029009317979216576, 0.022388694807887077, 0.012637460604310036, 0.011814161203801632, -4.470804415177554e-4, 0.013693352229893208, 0.010374859906733036, -0.03439570590853691, 0.013014198280870914, -0.0293971486389637, -0.029309295117855072, -0.029889587312936783, 0.014913519844412804, 0.003048758953809738, 0.017659852281212807, 0.012693281285464764, -3.2245321199297905e-4, -0.030619313940405846, 0.03947465121746063, -0.03255375847220421, -0.014647564850747585, -0.00802220031619072, 0.03477989882230759, 0.005550398491322994, -0.02827155403792858, -0.004791891202330589, ...],
        ...
      ]
    >
  },
  "normalization_8" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_10" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.031481996178627014, -0.05817766487598419, 0.013924255967140198, -0.008946642279624939, 0.019340798258781433, 0.03360477089881897, 0.007316336035728455, 0.04098181426525116, -7.37607479095459e-6, -0.017797216773033142, 0.03247946500778198, 0.02823394536972046, -0.05737721920013428, 0.0271918922662735, -0.05776165425777435, -0.016659602522850037, -0.019387677311897278, -0.04518856108188629, 0.03459125757217407, -0.011769741773605347, -0.049047186970710754, -0.020735889673233032, -0.027457118034362793, 0.010530591011047363, -0.05739058554172516, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05456387996673584, -0.011003509163856506, 0.028843805193901062, -0.017997249960899353, 0.008038237690925598, 0.03413446247577667, -0.052613720297813416, 0.04939882457256317, -0.052619799971580505, 0.0501551479101181, -0.0322205126285553, 0.010420143604278564, -0.033878907561302185, -0.023393258452415466, -0.04723133146762848, 0.017873749136924744, 0.04523386061191559, 0.021752387285232544, 0.03423243761062622, -0.0239860862493515, 0.010129496455192566, 0.0015448331832885742, 0.03317761421203613, 0.04482752084732056, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.03794516623020172, -0.04756000638008118, 0.0401889830827713, -0.05420325696468353, 0.05969052016735077, -0.05435197055339813, -0.0523114949464798, 0.011579513549804688, -0.055739328265190125, -0.02140970528125763, -0.008200213313102722, 0.015812620520591736, -0.025843873620033264, 0.060357511043548584, -0.010617494583129883, 0.010230779647827148, -0.06197206676006317, 0.01116025447845459, 0.02956978976726532, 0.036292001605033875, 0.017574280500411987, 0.061848029494285583, -0.04460093379020691, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.011754050850868225, 0.03260748088359833, -0.020954996347427368, 0.05471724271774292, 0.016808584332466125, 0.04684567451477051, -0.0036492496728897095, -0.027954503893852234, 0.042268410325050354, 0.003966853022575378, 0.00828571617603302, 0.05858337879180908, 0.037621259689331055, -0.02492007613182068, -0.04651986062526703, -0.049745768308639526, 0.05255410075187683, 0.023078158497810364, 0.02676534652709961, -0.03187434375286102, -0.058754220604896545, 0.0401504784822464, ...],
        ...
      ]
    >
  },
  "causal_attention_8" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0038306117057800293, -0.04274286329746246, -0.03369738161563873, 0.018675312399864197, -0.042270272970199585, -0.04810948669910431, 3.272444009780884e-4, -0.032044440507888794, -0.019368886947631836, 0.053090378642082214, -0.016367316246032715, 0.019252508878707886, -0.03844578564167023, 0.06042340397834778, 0.023747220635414124, 0.022929087281227112, -0.03250570595264435, 0.059052973985672, 0.002971351146697998, -0.03615255653858185, -0.014015644788742065, -0.04076719284057617, -0.06039714813232422, 0.004087761044502258, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.035819053649902344, 0.022814035415649414, -0.051139041781425476, 2.415776252746582e-4, 0.00493311882019043, -0.011880546808242798, 0.05178681015968323, -0.01603221893310547, -0.01905737817287445, -0.03600709140300751, 0.028011441230773926, -0.05856795608997345, -0.005464985966682434, -0.03435571491718292, 0.044359028339385986, -0.054839253425598145, -0.025955215096473694, -0.01923219859600067, 0.0029056817293167114, -0.016191869974136353, -0.034280672669410706, -0.030457884073257446, -0.053978994488716125, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05873514711856842, -0.051314592361450195, 0.02488495409488678, -0.06098216772079468, 0.00998447835445404, -0.05729646980762482, -0.04548957943916321, 0.03115116059780121, -0.058935925364494324, 0.02110879123210907, 0.040440648794174194, -0.01260027289390564, 0.003024384379386902, 0.003998488187789917, -0.038125500082969666, 0.024170801043510437, -0.009963542222976685, -0.007989495992660522, -0.014721214771270752, -0.03465455770492554, -0.04953807592391968, 0.02060265839099884, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.04030069708824158, -0.004945620894432068, 0.05670471489429474, -0.05550193786621094, 0.028519079089164734, 0.05478107929229736, 0.03855295479297638, -0.014860659837722778, 0.02837066352367401, -0.042844027280807495, -0.0208376944065094, -0.0536067932844162, -0.017931967973709106, 0.024702727794647217, -0.056287601590156555, -0.029698893427848816, -0.007974222302436829, 5.110204219818115e-4, 0.03232190012931824, 0.04638688266277313, -0.05506028234958649, ...],
        ...
      ]
    >
  },
  "normalization_12" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.030151838436722755, 0.026962148025631905, 0.008366762660443783, -0.017452403903007507, 0.033399298787117004, -0.035152025520801544, 0.02672167681157589, -0.01896383799612522, -0.0033247785177081823, -0.020549867302179337, 0.02164139226078987, -0.011851180344820023, -0.031905800104141235, -0.03789837658405304, -0.025157371535897255, 0.0357009656727314, 0.03845556080341339, -0.012642530724406242, -0.03029562532901764, -0.039294127374887466, 0.00793097261339426, ...],
        ...
      ]
    >
  },
  "dense_18" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.018258342519402504, 0.03370407223701477, 0.02826172485947609, -0.02285819500684738, 0.0030094594694674015, 0.017309220507740974, -0.012665912508964539, -0.03530336171388626, -0.012585730291903019, 0.005174923688173294, -0.01619606651365757, -0.03694703057408333, 0.01248558796942234, -0.022054847329854965, 0.0075738755986094475, -9.059601143235341e-5, 0.002120482036843896, 0.035016681998968124, 0.019131338223814964, -0.02836589142680168, ...],
        ...
      ]
    >
  },
  "dense_5" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.029081424698233604, 0.03904595971107483, 0.01739119179546833, 0.03495211899280548, 0.0044847335666418076, -0.021882692351937294, -0.016618069261312485, 0.03275499492883682, 0.008281228132545948, 0.0011905370047315955, 0.0050906892865896225, -2.5951757561415434e-4, 0.011774004437029362, -0.008875139057636261, -0.016247795894742012, 0.022003861144185066, -0.006579072680324316, -0.03853345289826393, 0.02050710842013359, ...],
        ...
      ]
    >
  },
  "dense_11" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-9.236778714694083e-5, -0.013915465213358402, -0.007801218889653683, 0.02294626459479332, -0.039117779582738876, -9.01059465832077e-5, 0.023736126720905304, -0.00909686554223299, 0.026170533150434494, 0.009422400034964085, -0.0059239971451461315, 0.018354225903749466, -0.0016467025270685554, -0.033300627022981644, -0.0017148685874417424, -0.014968481846153736, 0.012550266459584236, 0.03300395980477333, ...],
        ...
      ]
    >
  },
  "normalization_18" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_6" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.01332168560475111, -0.029901396483182907, 0.037792909890413284, -0.021031441166996956, -0.0010616688523441553, -0.018005112186074257, -0.028694339096546173, 0.022041359916329384, 0.02046753652393818, -0.025423353537917137, -0.02736477367579937, 0.018057189881801605, -0.00646343594416976, 0.005720262415707111, -0.009502874687314034, 0.010135331191122532, ...],
        ...
      ]
    >
  },
  "normalization_15" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "normalization_23" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "dense_10" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.016581106930971146, -0.017753632739186287, 0.003387054428458214, 0.01854347623884678, 0.012254512868821621, -0.025735346600413322, 0.029695946723222733, -0.025331372395157814, 0.0064471131190657616, -0.034072838723659515, 0.0019098295597359538, 0.001413667225278914, -0.005184828769415617, ...],
        ...
      ]
    >
  },
  "dense_13" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.03887765109539032, 0.019599169492721558, 0.030299291014671326, -0.03901126608252525, -0.025052526965737343, 0.003722032532095909, -0.03535781428217888, 0.03228810802102089, -0.021121349185705185, -0.007583139929920435, 0.005018583964556456, 0.0168507918715477, ...],
        ...
      ]
    >
  },
  "dense_21" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.03154725953936577, -0.03784885257482529, 8.361259097000584e-5, -0.0012775894720107317, 0.021928664296865463, 0.016265882179141045, 0.023632798343896866, 0.010407665744423866, 0.035026662051677704, -0.001064665731973946, 0.020116452127695084, ...],
        ...
      ]
    >
  },
  "dense_8" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.0033618255984038115, 0.006313174497336149, 0.034277912229299545, -0.035671692341566086, 0.02992657758295536, 0.012056337669491768, 0.009572981856763363, 0.03431883081793785, 0.0192117840051651, -0.03887660428881645, ...],
        ...
      ]
    >
  },
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.02323172800242901, 0.03722824156284332, -0.003946613986045122, -0.01957433670759201, 0.008806322701275349, -0.03601069375872612, -0.01720116101205349, -0.030252018943428993, 0.034761831164360046, ...],
        ...
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.030193248763680458, -0.012681170366704464, -0.010071669705212116, -0.00632505863904953, 0.013524412177503109, -0.0316292978823185, 0.033702149987220764, 0.0054716202430427074, ...],
        ...
      ]
    >
  },
  "causal_attention_3" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.010649889707565308, 0.03993590176105499, -0.03653435409069061, 0.03526468575000763, -0.011710599064826965, 0.03497679531574249, 0.026175886392593384, 0.0170925110578537, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.030536174774169922, -0.017163142561912537, -0.02458237111568451, -0.04745492339134216, 0.02047748863697052, 0.033374056220054626, -0.02810007333755493, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.055831655859947205, 0.0029443055391311646, 0.01922956109046936, -0.036626532673835754, -0.018797963857650757, 0.06059657037258148, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05943429470062256, 0.05346691608428955, 0.01129239797592163, 0.045209914445877075, -0.04347437620162964, ...],
        ...
      ]
    >
  },
  "normalization_21" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
    >
  },
  "causal_attention_5" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.048142120242118835, -0.028445109724998474, -0.00396360456943512, -0.024130433797836304, -0.059604644775390625, -0.015126034617424011, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.00908099114894867, 0.022494390606880188, 0.0017237812280654907, -0.050312504172325134, -0.02526646852493286, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.023866787552833557, 0.040734246373176575, 0.061486825346946716, -0.061327025294303894, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.015503287315368652, -0.022227853536605835, 0.004416540265083313, ...],
        ...
      ]
    >
  },
  "dense_24" => %{
    "kernel" => #Nx.Tensor<
      f32[768][50257]
      EXLA.Backend
      [
        [0.007333068177103996, 0.009198933839797974, 0.001429903320968151, -0.0014841653173789382, 0.008069896139204502, ...],
        ...
      ]
    >
  },
  "dense_14" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0, 0.0, 0.0, 0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.009807120077311993, -0.002012290759012103, -0.009895576164126396, ...],
        ...
      ]
    >
  },
  "causal_attention_4" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05512331426143646, -0.019126906991004944, -0.03597761690616608, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05108453333377838, -0.040571004152297974, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.03417539596557617, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_10" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0, 1.0, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >
  },
  "dense_7" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        ...
      ]
    >
  },
  "normalization_22" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [...]
    >,
    ...
  },
  "causal_attention_2" => %{...},
  ...
}
# Loss used by the training loop: mean cross-entropy between the model's
# logits and the integer target token ids.
#
#   * `y_true` - target ids; flattened to a single axis.
#   * `y_pred` - model output; its first two axes are merged into one
#     (presumably {batch, seq, vocab} -> {batch * seq, vocab} - TODO confirm).
#
# `sparse: true` means the targets are class indices rather than one-hot
# vectors, and `from_logits: true` means the predictions have not been passed
# through a softmax.
loss_fn = fn y_true, y_pred ->
  Axon.Losses.categorical_cross_entropy(
    Nx.flatten(y_true),
    Nx.flatten(y_pred, axes: [0, 1]),
    reduction: :mean,
    from_logits: true,
    sparse: true
  )
end

# Adam optimizer with a fixed learning rate.
# NOTE(review): 0.004 is a fairly large step size for Adam - confirm it is intentional.
optimizer = Polaris.Optimizers.adam(learning_rate: 0.004)

# Build a supervised training loop around `model` using `loss_fn` and
# `optimizer`, attach a validation pass over `validation_dataset`, and run it
# over `training_dataset` for 100 epochs, compiled with EXLA. The empty map is
# the initial model state handed to the loop; the value returned by the run is
# kept as `trained_model_state`.
trained_model_state =
  Axon.Loop.run(
    model
    |> Axon.Loop.trainer(loss_fn, optimizer)
    |> Axon.Loop.validate(model, validation_dataset),
    training_dataset,
    %{},
    epochs: 100,
    compiler: EXLA
  )

18:34:17.327 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 0, loss: 0.0000000
Batch: 0, loss: 7.3697920
Batch: 0, loss: 7.6583509
Batch: 0, loss: 7.5588903
Batch: 0, loss: 7.5619526
Epoch: 5, Batch: 5, loss: 7.3297234
Batch: 0, loss: 7.4858208
Batch: 0, loss: 7.4663801
Batch: 0, loss: 7.4930587
Batch: 0, loss: 7.4903049
Batch: 0, loss: 7.4499235
Epoch: 11, Batch: 1, loss: 6.9059758
Batch: 0, loss: 7.5709586
Batch: 0, loss: 7.4911027
Batch: 0, loss: 7.5220985
Batch: 0, loss: 7.5937834
Epoch: 16, Batch: 6, loss: 6.6809735
Batch: 0, loss: 7.5453634
Batch: 0, loss: 7.6443129
Batch: 0, loss: 7.8299255
Batch: 0, loss: 7.4813557
Batch: 0, loss: 7.5519705
Epoch: 22, Batch: 2, loss: 6.5051050
Batch: 0, loss: 7.6722832
Batch: 0, loss: 7.5782862
Batch: 0, loss: 7.7572470
Batch: 0, loss: 7.7314963
Epoch: 27, Batch: 7, loss: 6.3478270
Batch: 0, loss: 7.8330517
Batch: 0, loss: 7.9837060
Batch: 0, loss: 7.9353132
Batch: 0, loss: 7.8664732
Batch: 0, loss: 7.9199066
Epoch: 33, Batch: 3, loss: 6.1977282
Batch: 0, loss: 7.8347631
Batch: 0, loss: 8.1238775
Batch: 0, loss: 7.5302148
Batch: 0, loss: 7.6193390
Epoch: 38, Batch: 8, loss: 6.0909095
Batch: 0, loss: 8.0386648
Batch: 0, loss: 8.1929007
Batch: 0, loss: 8.3862934
Batch: 0, loss: 8.4985800
Batch: 0, loss: 8.5923557
Epoch: 44, Batch: 4, loss: 5.9588971
Batch: 0, loss: 8.4834251
Batch: 0, loss: 8.1670094
Batch: 0, loss: 8.1920309
Batch: 0, loss: 8.2342968
Batch: 0, loss: 8.1950932
Epoch: 50, Batch: 0, loss: 5.8471518
Batch: 0, loss: 8.5694637
Batch: 0, loss: 8.5044422
Batch: 0, loss: 8.2977266
Batch: 0, loss: 8.2102299
Epoch: 55, Batch: 5, loss: 5.7445393
Batch: 0, loss: 8.4679928
Batch: 0, loss: 8.5601730
Batch: 0, loss: 8.7446404
Batch: 0, loss: 8.7495289
Batch: 0, loss: 8.9133472
Epoch: 61, Batch: 1, loss: 5.6442862
Batch: 0, loss: 8.6313210
Batch: 0, loss: 8.5848646
Batch: 0, loss: 8.6187553
Batch: 0, loss: 8.4123878
Epoch: 66, Batch: 6, loss: 5.5716114
Batch: 0, loss: 8.6355677
Batch: 0, loss: 8.7246437
Batch: 0, loss: 8.5975513
Batch: 0, loss: 8.2936974
Batch: 0, loss: 8.1938820
Epoch: 72, Batch: 2, loss: 5.5084810
Batch: 0, loss: 8.6803112
Batch: 0, loss: 8.6566372
Batch: 0, loss: 8.4769630
Batch: 0, loss: 8.4978361
Epoch: 77, Batch: 7, loss: 5.4460464
Batch: 0, loss: 8.7400770
Batch: 0, loss: 8.6045074
Batch: 0, loss: 8.7099686
Batch: 0, loss: 8.9796181
Batch: 0, loss: 9.0241623
Epoch: 83, Batch: 3, loss: 5.3930869
Batch: 0, loss: 8.9092216
Batch: 0, loss: 8.5543833
Batch: 0, loss: 8.3348980
Batch: 0, loss: 8.4431782
Epoch: 88, Batch: 8, loss: 5.3444195
Batch: 0, loss: 8.9574356
Batch: 0, loss: 8.8693142
Batch: 0, loss: 8.8906231
Batch: 0, loss: 8.8772736
Batch: 0, loss: 8.5540447
Epoch: 94, Batch: 4, loss: 5.3144341
Batch: 0, loss: 8.2158041
Batch: 0, loss: 8.3749762
Batch: 0, loss: 8.4187984
Batch: 0, loss: 8.5182734
Batch: 0, loss: 8.6018639
Batch: 0, loss: 8.7100134
%{
  "pos_embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[256][768]
      EXLA.Backend
      [
        [-0.13128723204135895, -0.10325274616479874, 0.33050787448883057, 0.0012342628324404359, -0.2409193515777588, 0.1410435438156128, 0.08914312720298767, -0.1773548126220703, 0.17983700335025787, 0.27569761872291565, -0.11397161334753036, -0.11459209024906158, 0.13466870784759521, 0.13577212393283844, -0.2894219756126404, 0.08101727068424225, 0.017110124230384827, 0.028370676562190056, -0.26757296919822693, -0.22513410449028015, 0.34566888213157654, 0.2058556079864502, -0.4060262143611908, 0.13103951513767242, 0.08203025907278061, -0.15299487113952637, -0.015485299751162529, -0.2729659974575043, -0.020763149484992027, -0.3183610141277313, -0.27260249853134155, 0.09238959848880768, 0.092988982796669, -0.09268444031476974, 0.2952868342399597, -0.1661265641450882, 0.02826753258705139, 0.0012159794569015503, 0.26356807351112366, 0.18670181930065155, 0.17675593495368958, -0.07774572819471359, 0.040243688970804214, -0.018952257931232452, 0.020476512610912323, -0.042058661580085754, -0.11691082268953323, -0.191908061504364, ...],
        ...
      ]
    >
  },
  "normalization_6" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.976279616355896, 0.9983371496200562, 0.9347731471061707, 0.962969183921814, 0.7346729040145874, 1.008823275566101, 0.7395086884498596, 1.1268373727798462, 0.9809115529060364, 1.1624888181686401, 1.1489393711090088, 1.182489275932312, 1.1918779611587524, 0.8706818222999573, 0.8309196829795837, 1.1274809837341309, 1.0617780685424805, 0.9882297515869141, 0.9079322218894958, 0.7648198008537292, 1.107505440711975, 1.147552728652954, 0.6318383812904358, 1.1805849075317383, 1.039429783821106, 0.9723440408706665, 0.8223330974578857, 1.137025237083435, 1.0423537492752075, 1.0629721879959106, 1.1032946109771729, 0.8445615172386169, 1.1734566688537598, 1.284388780593872, 0.9392902255058289, 0.8950057625770569, 1.1202377080917358, 1.176619052886963, 1.1029174327850342, 1.1828274726867676, 0.9926555752754211, 0.8435946106910706, 1.0670576095581055, 0.9693421125411987, 0.8997613787651062, 0.952686071395874, 1.1207737922668457, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0656806230545044, -0.14769181609153748, -0.030020223930478096, -0.3033551275730133, -0.19672542810440063, -0.12948285043239594, -0.19817645847797394, 0.055259838700294495, -0.2091274857521057, -0.15433314442634583, 0.18365338444709778, 0.09990482777357101, -0.06615034490823746, 0.07417666912078857, 0.3781404495239258, 0.029660077765583992, 0.2673587203025818, -0.12904652953147888, -0.0253304410725832, 0.01580815389752388, 0.007705814205110073, -0.15006957948207855, -0.059319861233234406, -0.04121491312980652, 0.0825919657945633, 0.16814379394054413, -0.22064751386642456, -0.22826391458511353, -0.26785099506378174, -0.1615062952041626, -0.08773721009492874, -0.13997213542461395, -0.16767646372318268, 0.18819493055343628, -0.08408018946647644, 0.11714084446430206, 0.2842397391796112, -0.3574424386024475, -0.2914518415927887, 0.13251180946826935, -0.1806604266166687, -0.23264946043491364, -0.24279005825519562, 0.0787694975733757, -0.19106078147888184, 0.11505305022001266, ...]
    >
  },
  "dense_4" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.04355476424098015, 0.015225606970489025, 8.95271310582757e-4, -0.107908695936203, 0.11525387316942215, 6.082420004531741e-4, 0.08708442747592926, -0.002777251647785306, -0.14427299797534943, -0.06596463918685913, -0.004583990667015314, 0.013134156353771687, 0.07615090161561966, -0.05012185126543045, -0.031140994280576706, -0.18191498517990112, -0.09335724264383316, -0.06272120773792267, -0.03880956768989563, -0.08945201337337494, -0.026239190250635147, -0.11924758553504944, -0.027476347982883453, -0.09647303074598312, -0.20221006870269775, -0.05506009981036186, -0.107866570353508, 0.06856715679168701, 0.07217981666326523, -0.09557495266199112, 0.09120476990938187, -0.04235199838876724, -0.04429486393928528, 0.04204371199011803, -0.07016431540250778, 0.04530926048755646, 0.07431471347808838, -0.09846842288970947, -0.05621148645877838, 0.056164756417274475, -0.11062665283679962, -0.10559915751218796, 0.06870240718126297, -0.03556175157427788, -0.012295370921492577, 0.015994500368833542, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.030409662052989006, 0.06914470344781876, -0.03536171093583107, -0.040470510721206665, 0.214300736784935, 0.018398795276880264, 0.15517772734165192, -0.01885507069528103, -0.012821483425796032, 0.018264705315232277, -0.0031032587867230177, 5.74447913095355e-4, 0.0784018337726593, 0.038978610187768936, -0.038434647023677826, 0.05297403782606125, -0.017684968188405037, -0.004611591342836618, -0.08672378212213516, -0.053236182779073715, 0.022961461916565895, -0.025791959837079048, 0.05171163007616997, -0.04185943678021431, 0.05777277424931526, -0.0324106402695179, -0.0037037336733192205, 0.04541022703051567, 0.20738308131694794, -0.0066484748385846615, 0.008530541323125362, -0.02262500487267971, -0.009712018072605133, -0.0022372808307409286, -0.07329955697059631, 0.11780574917793274, 0.044119156897068024, 0.05288122594356537, 0.008561665192246437, -0.0129309743642807, -0.048855431377887726, -0.02897888608276844, 0.04306621477007866, 0.043353427201509476, 0.04641146957874298, ...],
        ...
      ]
    >
  },
  "normalization_5" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.669127345085144, 0.6805816292762756, 0.7976828217506409, 0.8352634906768799, 0.9580501914024353, 0.5050150156021118, 0.9693787693977356, 0.696363627910614, 0.8083791732788086, 0.8390010595321655, 0.8007017970085144, 0.870387077331543, 0.8499709367752075, 0.889473021030426, 0.8711001873016357, 0.5107711553573608, 0.8008232712745667, 0.8977949619293213, 0.7061402797698975, 0.4224337637424469, 0.7369155287742615, 0.6186343431472778, 1.1267280578613281, 0.6825594305992126, 0.7300768494606018, 0.6856794953346252, 1.0412360429763794, 0.859039843082428, 0.9437779784202576, 0.5632300972938538, 0.9921419620513916, 0.7559762001037598, 0.8173981308937073, 1.0903865098953247, 0.9968200325965881, 0.9936344027519226, 0.912320077419281, 0.9092993140220642, 0.9012632369995117, 0.7320505976676941, 0.937802255153656, 0.9421466588973999, 0.9922917485237122, 0.7487178444862366, 0.9281713366508484, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0031946797389537096, 0.12008783966302872, -0.12894879281520844, 0.1119285598397255, -0.10126188397407532, 0.051974039524793625, 0.029135316610336304, 0.15606839954853058, 0.07040970772504807, -0.04139747470617294, 0.08844441175460815, -0.14084434509277344, -0.025836080312728882, -0.2123531848192215, -0.2279147058725357, 0.02840578183531761, -0.016201728954911232, -0.026863791048526764, 0.1967143714427948, 0.17070244252681732, 0.038636527955532074, 0.18427211046218872, 0.0699673444032669, 0.029008477926254272, -0.018631385639309883, -0.057705219835042953, -0.07649531215429306, 0.09288813918828964, 0.19673500955104828, -0.026264110580086708, 0.014067205600440502, -0.020541660487651825, 0.03036583587527275, 0.0989772379398346, 0.12042278051376343, 0.01269456185400486, 0.09385214745998383, 0.1545635610818863, 0.1207835003733635, 0.003928808961063623, 0.08579232543706894, 0.1426999419927597, 0.14266203343868256, 0.06499376893043518, ...]
    >
  },
  "dense_20" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.0044816466979682446, -0.051838550716638565, 0.0041940053924918175, -0.018988460302352905, 0.045580409467220306, -0.022986318916082382, -0.06307004392147064, -0.011173025704920292, -0.04697903245687485, -0.028922008350491524, 0.06065948307514191, 7.540884544141591e-4, -0.03927231207489967, -0.04482021555304527, 0.01808062568306923, 0.026454975828528404, -0.037252724170684814, 0.054460812360048294, 0.03015664778649807, 0.09195298701524734, -0.008049468509852886, -0.02553507499396801, 0.009719045832753181, -0.021418455988168716, -0.09765792638063431, -0.02209583856165409, 1.0974994802381843e-4, -0.15184500813484192, 0.04014754295349121, -0.027318604290485382, 0.014955122955143452, -0.0375075526535511, 0.08267144113779068, 0.023335043340921402, -0.013761971145868301, 0.02540511079132557, -0.018841106444597244, 0.028025029227137566, 0.04785478115081787, -0.02252003364264965, 0.02121218666434288, -0.04236849769949913, -0.028721213340759277, 0.15094783902168274, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.035374224185943604, -0.20597676932811737, -0.02660421095788479, 0.022771045565605164, 0.06976617127656937, 0.012238830327987671, -0.0022192641627043486, -0.010858519934117794, 0.00867910124361515, -0.027534225955605507, -0.018232312053442, -0.015823666006326675, -0.03445006534457207, -0.02541390247642994, 0.01832074485719204, 0.03881501033902168, -0.057316992431879044, 0.14359959959983826, 0.04704686626791954, -0.026731492951512337, 0.038732875138521194, -0.060284651815891266, -0.051571354269981384, -6.788732134737074e-4, -0.2699446380138397, -0.03623112663626671, -0.0700412392616272, -0.10685974359512329, 0.034319620579481125, 0.0523863211274147, 0.010116176679730415, -0.010161573998630047, -0.07588993012905121, -0.032711148262023926, 0.0038086306303739548, 0.004615825600922108, 1.1983872536802664e-4, 0.04118034616112709, 0.07733182609081268, -0.030115796253085136, 0.023859065026044846, -0.04665843024849892, -0.04051484540104866, ...],
        ...
      ]
    >
  },
  "normalization_1" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.8882691264152527, 1.0090155601501465, 0.5405170917510986, 1.1575959920883179, 0.666041910648346, 0.8661141991615295, 0.7999036312103271, 0.6727589964866638, 0.7879235744476318, 0.8182779550552368, 0.9743760824203491, 0.5371227860450745, 0.6991158127784729, 0.4712108075618744, 0.8720059394836426, 0.8539559245109558, 0.9083633422851562, 1.075925350189209, 0.7013532519340515, 0.6780218482017517, 0.7569238543510437, 0.8280938267707825, 0.9199353456497192, 0.5841556787490845, 0.5195330381393433, 0.900633692741394, 0.5795818567276001, 0.647293746471405, 0.8817631006240845, 0.8172031044960022, 0.9616697430610657, 0.5880069732666016, 0.5644152164459229, 0.7345616817474365, 0.706516683101654, 1.0417511463165283, 0.2774600684642792, 0.4580109417438507, 0.8110355734825134, 1.0161628723144531, 0.9425165057182312, 0.8503072261810303, 0.8294627666473389, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.2735598385334015, 0.18094606697559357, -0.4384869635105133, 0.0071744672022759914, -0.23785264790058136, -0.6014081835746765, -0.15710879862308502, 0.08121699094772339, -0.4111187756061554, 0.4210231602191925, 0.37548407912254333, 0.1679743230342865, -0.27268028259277344, 0.3791903853416443, -0.08484373986721039, -0.36037853360176086, 0.026189561933279037, -0.08442389219999313, -0.012706195935606956, 0.005214248783886433, 0.3358478248119354, 0.20001326501369476, 0.3746623694896698, 0.4028330147266388, -0.31535962224006653, 0.14244164526462555, -0.4729384481906891, 0.35011428594589233, -0.06551029533147812, -0.3472256064414978, 0.21798096597194672, -0.3236621022224426, 0.09019414335489273, 0.25643134117126465, -0.27338817715644836, -0.12381718307733536, 0.09456278383731842, 0.5233346819877625, 0.5095019340515137, 0.4905881881713867, 0.23646800220012665, 0.006336736027151346, ...]
    >
  },
  "normalization_9" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.7601646780967712, 1.0575873851776123, 0.833929717540741, 0.8752350211143494, 1.1219244003295898, 0.8563557267189026, 1.008758306503296, 1.1346663236618042, 0.5930135846138, 0.6713327765464783, 1.0255649089813232, 1.0547596216201782, 1.3281452655792236, 0.7083097100257874, 0.648703396320343, 0.9649523496627808, 0.6726261973381042, 0.9629827737808228, 0.6620683073997498, 0.6895036697387695, 0.7732593417167664, 0.8105130791664124, 1.036435604095459, 0.8421568274497986, 0.7799838781356812, 0.7684429287910461, 0.725635290145874, 0.9275967478752136, 0.8096864223480225, 0.5955697894096375, 0.7125076651573181, 0.6794856786727905, 1.3544409275054932, 1.0881547927856445, 0.8968175053596497, 0.8811891078948975, 0.8572181463241577, 0.9301641583442688, 0.5367276072502136, 1.1217517852783203, 0.8602713346481323, 0.6892555952072144, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.17955945432186127, 0.039918284863233566, 0.05118217319250107, 0.1689285933971405, 0.017944535240530968, 0.18175041675567627, 0.18090574443340302, 0.08146417140960693, 0.12605684995651245, -0.1746881753206253, 0.09894595295190811, -0.11336728930473328, -0.03381138667464256, -0.3609730005264282, -0.1554592102766037, 0.24808934330940247, -0.22175371646881104, 0.16448788344860077, 0.4377762973308563, 0.44366177916526794, 0.1966020166873932, 0.2805810570716858, 0.04645632579922676, -0.15196315944194794, 0.16270124912261963, -0.011697367765009403, -0.03425142914056778, 0.0769764855504036, 0.08929343521595001, 0.22643645107746124, -0.2901487946510315, 0.026222683489322662, -0.09797918796539307, 0.055557940155267715, 0.4047852158546448, 0.15461792051792145, -0.13756106793880463, 0.01601230725646019, -0.0022855973802506924, 0.05769926682114601, -0.010181233286857605, ...]
    >
  },
  "dense_22" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.025983087718486786, -0.014882506802678108, 0.03572889417409897, -0.03285176679491997, -0.07769246399402618, -0.017144272103905678, -0.030782755464315414, 0.054986607283353806, 0.0820801854133606, -0.002821727190166712, 0.0470515713095665, -0.022503044456243515, -0.012873765081167221, -0.03375611826777458, 0.08271979540586472, 0.06973311305046082, 0.05366113781929016, 0.058859143406152725, 0.015187464654445648, -0.02380977012217045, -0.030898567289114, -0.014092245139181614, -0.02445690706372261, 0.049474433064460754, 0.05607358366250992, 0.09577149897813797, 0.08512849360704422, -0.03195606544613838, -0.026346810162067413, -0.02108968235552311, -0.03384871035814285, 0.03208340331912041, 0.031305957585573196, -0.0047063566744327545, -0.021106021478772163, -0.0017622661544010043, -0.025616023689508438, -0.027356192469596863, -0.01818583719432354, -0.02201717719435692, -0.08372113853693008, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.041509054601192474, -0.07663726806640625, -0.01721254736185074, 0.004556768108159304, -0.04343554750084877, 0.006617851555347443, -0.042958639562129974, 0.011971570551395416, 0.06218034029006958, -0.005412611179053783, 0.011948412284255028, -0.010113690048456192, 0.055139198899269104, -0.047646135091781616, -0.12393534928560257, -0.036856722086668015, -0.016968408599495888, -0.03729138895869255, 0.00813064444810152, -0.004013491794466972, -0.015281339175999165, -0.01608598791062832, -0.009003383107483387, -0.019049592316150665, -0.06853911280632019, -0.014113642275333405, -0.00767920445650816, 0.016962898895144463, -0.027451511472463608, -0.05491471290588379, -0.057572830468416214, 0.06901074200868607, 0.012415540404617786, 0.11876581609249115, -0.024727026000618935, 1.0221428965451196e-4, 0.06175108999013901, -0.0686023160815239, 0.012735210359096527, 0.0020371642895042896, ...],
        ...
      ]
    >
  },
  "causal_attention_0" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.11505205184221268, 0.2514360547065735, -0.11258097738027573, -0.14808319509029388, 0.33139365911483765, -0.24039693176746368, 0.20929482579231262, -0.17939041554927826, -0.06635420769453049, 0.1525164097547531, -0.005929156206548214, -0.0535774752497673, -0.5411754846572876, 0.02358352020382881, -0.17205773293972015, -0.07191508263349533, -0.13897471129894257, 0.03905788064002991, 0.3725956678390503, 0.017313851043581963, 0.054634060710668564, -0.15265773236751556, -0.019046548753976822, -4.179324896540493e-4, -0.06938805431127548, 0.04996294528245926, -0.1627296358346939, 0.014806467108428478, -0.026507647708058357, -0.14492377638816833, 0.25210678577423096, -0.2958992123603821, -0.09973254054784775, -0.13291919231414795, -0.2708480656147003, -0.010680124163627625, 0.08697282522916794, 0.13701745867729187, 0.1424722969532013, 0.04901214316487312, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.17492826282978058, 0.5080336928367615, -0.5048732161521912, 0.35095804929733276, -0.16526944935321808, -0.4512418508529663, -0.17394514381885529, -0.1938553750514984, -0.1139238253235817, 0.14298734068870544, -0.1942778378725052, -0.019794659689068794, -0.28326359391212463, -0.08129452168941498, 0.2588968575000763, 0.05206902697682381, -0.4140560030937195, 0.2815587520599365, -0.35883986949920654, 0.2350740134716034, 0.2576012909412384, 0.28182607889175415, 0.14827598631381989, -0.26348546147346497, -0.09118418395519257, 0.2349678874015808, 0.39138567447662354, -0.05817278474569321, 0.06031375378370285, -0.030808134004473686, 0.4431280195713043, 0.30415162444114685, 0.311904639005661, -8.275318396044895e-5, -0.4624999165534973, 0.2631886899471283, -0.20660735666751862, -0.25670573115348816, 0.21840710937976837, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05208028480410576, 0.39713969826698303, -0.28976747393608093, 0.15262018144130707, -0.002116926945745945, -0.31039494276046753, -0.11510646343231201, -0.2538120746612549, 0.11721578240394592, -0.003938170615583658, -0.09595898538827896, -0.23389551043510437, -0.10887853056192398, -0.11900849640369415, 0.12880443036556244, -0.010344219394028187, -0.21714043617248535, 0.18059201538562775, -0.053492628037929535, 0.09465359151363373, 0.06434983760118484, 0.0720265582203865, -0.13290490210056305, -0.07528096437454224, -0.12849026918411255, 0.2925311326980591, 0.06788992881774902, 0.051520515233278275, -0.18984587490558624, 0.09663024544715881, 0.07977241277694702, 0.025895042344927788, 0.10449623316526413, -0.03415248543024063, -0.3716948628425598, 0.0952053964138031, -0.14616499841213226, -0.12325423210859299, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.2895628809928894, 0.1582932323217392, -0.09316343069076538, -0.3422802984714508, 0.36091434955596924, -0.41932979226112366, -0.10843870043754578, -0.38483113050460815, 0.20754651725292206, 0.33959946036338806, 0.04828161001205444, -0.21546930074691772, 0.29844871163368225, 0.0262361578643322, 0.39029693603515625, -0.34534382820129395, 0.040186699479818344, -0.15093274414539337, 0.1076795905828476, 0.26330533623695374, -0.40137287974357605, -0.3775118291378021, 0.2689940631389618, 0.31608307361602783, 0.10556785762310028, -0.28570038080215454, -0.18723583221435547, 0.2719678282737732, 0.3283793032169342, -0.30990058183670044, -0.1267853081226349, -0.2505991458892822, -0.3339806795120239, 0.30719900131225586, 0.31037914752960205, 0.4498017430305481, -0.4520857632160187, ...],
        ...
      ]
    >
  },
  "normalization_14" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.7688844203948975, 1.4032864570617676, 1.1531037092208862, 1.0858877897262573, 1.355336308479309, 1.1822845935821533, 1.217619776725769, 1.1940451860427856, 0.7554122805595398, 0.6242034435272217, 1.0660746097564697, 1.2633442878723145, 0.8631431460380554, 1.2257742881774902, 0.9718759655952454, 0.8947205543518066, 1.039861798286438, 1.1669151782989502, 0.6018592715263367, 0.7097852230072021, 1.019943118095398, 0.8997196555137634, 0.734862208366394, 0.9951802492141724, 0.7384564876556396, 1.2744379043579102, 1.311685562133789, 1.1607112884521484, 1.1183061599731445, 0.7916414141654968, 0.7345647215843201, 1.158532977104187, 1.284753680229187, 1.1233901977539062, 1.2105326652526855, 0.8009575605392456, 1.3823654651641846, 1.391650676727295, 1.326175332069397, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.3637162744998932, 0.0017139220144599676, 0.24348455667495728, 0.03816577419638634, 0.02377811074256897, 0.25012022256851196, 0.04602523893117905, 0.18923906981945038, 0.1049588993191719, -0.3769398629665375, 0.07236865162849426, -0.11038442701101303, 0.2795722484588623, -0.16120971739292145, 0.10311773419380188, 0.22906631231307983, 0.093011774122715, 0.1225256621837616, 0.43358004093170166, 0.3958436846733093, 0.4171266257762909, -0.0787905678153038, -0.007767575327306986, -0.3188043534755707, 0.37655386328697205, 0.02017715945839882, -0.013378755189478397, -0.27999347448349, -0.11878695338964462, 0.33813223242759705, -0.306076318025589, 0.3352339267730713, -0.06761858612298965, -0.3191472589969635, 0.18086998164653778, 0.2953146696090698, 0.08243796974420547, -0.08437398821115494, ...]
    >
  },
  "normalization_7" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.7287175059318542, 0.953284502029419, 0.8301360607147217, 1.0001031160354614, 0.9228729009628296, 0.8907386064529419, 0.9352411031723022, 1.0361543893814087, 1.0143001079559326, 0.6863902807235718, 1.2018766403198242, 1.0922095775604248, 1.3778258562088013, 0.807249903678894, 0.5378776788711548, 0.7783093452453613, 0.6332129240036011, 0.9725655317306519, 0.7596699595451355, 0.764099657535553, 0.7365170121192932, 0.6895506381988525, 1.1963672637939453, 0.7400028109550476, 0.8340644836425781, 0.7447218894958496, 0.907943844795227, 0.8062208294868469, 0.7738038897514343, 0.7783903479576111, 0.5778080821037292, 0.6866471767425537, 0.7758758664131165, 0.8342713713645935, 1.0022727251052856, 0.7005257606506348, 0.6443851590156555, 0.728439211845398, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.13799762725830078, 0.15667153894901276, -0.3177345395088196, -0.06135115399956703, 0.07031365483999252, -0.18177258968353271, 0.006121802143752575, -0.005101750139147043, -0.15698660910129547, -0.020436588674783707, 0.1478266716003418, -0.07381902635097504, -0.23566430807113647, -0.3654991388320923, -0.45716872811317444, 0.16072525084018707, -0.07424641400575638, -2.510088961571455e-5, 0.5276729464530945, 0.2836211621761322, 0.3169234097003937, 0.3603701591491699, 0.08403345197439194, 0.12383824586868286, -0.012530870735645294, 0.11565893143415451, -0.29156070947647095, 0.14054006338119507, 0.2861216068267822, -0.024329056963324547, -0.2751427888870239, -0.3304060101509094, -0.025516996160149574, 0.0074087632820010185, 0.19204401969909668, 0.2688041925430298, 0.05042148008942604, ...]
    >
  },
  "dropout_2" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [1071914113, 394885060]
    >
  },
  "dropout_14" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [1676324637, 2169359021]
    >
  },
  "normalization_11" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.6537600755691528, 1.021589756011963, 0.8740715980529785, 1.0782865285873413, 0.867885410785675, 0.934373140335083, 0.9258654713630676, 1.2223132848739624, 1.0549830198287964, 0.6443033814430237, 0.8942501544952393, 1.066359281539917, 1.23695969581604, 0.7287760972976685, 0.8780415654182434, 0.7896318435668945, 0.8951253890991211, 0.602243959903717, 0.7258689999580383, 0.8550219535827637, 0.7595850229263306, 0.9387528896331787, 1.3649874925613403, 0.8345512747764587, 0.9940775632858276, 0.9451969861984253, 0.9120372533798218, 0.8794366121292114, 0.7566035985946655, 1.0275284051895142, 0.656404435634613, 0.8756612539291382, 1.206744909286499, 1.0928542613983154, 0.9749126434326172, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.1585337370634079, 0.1701352447271347, -0.07980729639530182, 0.33637097477912903, 0.04268687218427658, 0.06791402399539948, -0.01671365648508072, -0.15008039772510529, 0.013158587738871574, -0.3382038474082947, -0.1124115139245987, 0.10749763250350952, -0.22406715154647827, -0.22823086380958557, -0.16003911197185516, 0.18537285923957825, -0.14479972422122955, -0.051999084651470184, 0.2646397650241852, 0.23121966421604156, 0.1510889232158661, 0.2291404902935028, 0.07818228751420975, -0.07321730256080627, 0.08364376425743103, -0.14352895319461823, 0.11234854906797409, 0.04303711652755737, 0.2061213254928589, 0.16729316115379333, -0.3431406617164612, 0.09662745893001556, 0.16902518272399902, 0.008093971759080887, ...]
    >
  },
  "normalization_2" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.9955676794052124, 0.9743684530258179, 1.2751744985580444, 0.6833834648132324, 1.2043081521987915, 1.046404242515564, 1.2768059968948364, 1.1447440385818481, 0.8881211280822754, 0.7418510913848877, 0.7313302755355835, 0.8148407936096191, 1.1637109518051147, 0.8906869292259216, 0.8815494179725647, 0.6707848310470581, 0.8231105208396912, 0.9543188214302063, 0.9810612201690674, 0.6214656233787537, 0.7477407455444336, 0.9520954489707947, 0.9968622326850891, 0.8744527101516724, 1.0231186151504517, 1.074076771736145, 1.055703043937683, 0.6110471487045288, 0.8233792781829834, 1.069803237915039, 0.8296988010406494, 1.119225025177002, 0.9564214944839478, 0.962261974811554, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.02593332529067993, 0.001940784975886345, -0.1885371208190918, 0.06457509845495224, -0.06946443766355515, 0.08030861616134644, -0.005466314498335123, 0.006403574254363775, 0.07841882854700089, -0.0504097081720829, 0.0359046533703804, -0.33440303802490234, 0.27101728320121765, -0.05079322308301926, -0.2537595331668854, -0.1052444726228714, -0.3589431047439575, -0.21315552294254303, 0.234177827835083, 0.17828381061553955, 0.14291547238826752, 0.22692500054836273, -0.2849924862384796, 0.06333529949188232, -0.08192180842161179, 0.06410709768533707, -0.07209733873605728, 0.08725271373987198, 0.18199002742767334, -0.03335550054907799, -0.21029230952262878, -0.18178102374076843, 0.08496683835983276, ...]
    >
  },
  "dropout_9" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [2683469369, 916559154]
    >
  },
  "dropout_11" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [2387778865, 1306972002]
    >
  },
  "dense_12" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.024731086567044258, -0.015216008760035038, -0.07290135324001312, -0.010556008666753769, 0.01640007458627224, -0.08263561874628067, -0.010143207386136055, -0.002852850593626499, -0.034865353256464005, -0.0044746408239007, -0.03093351423740387, -0.15647418797016144, -0.011149008758366108, -0.011131196282804012, 0.031118500977754593, 0.081869937479496, -0.01122331153601408, 0.0027156381402164698, -0.07862426340579987, -0.005347146652638912, -0.04251427203416824, -0.05761629715561867, -0.03124002367258072, -0.061618901789188385, 0.025258928537368774, -0.04266304895281792, -0.03528517112135887, 0.12068681418895721, 0.0228628758341074, -0.036252573132514954, -0.03249562904238701, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.008217046968638897, -0.05246155709028244, -0.05548520386219025, -0.009699796326458454, -0.06511358171701431, -0.04095799848437309, -6.137189920991659e-4, -2.9471394373103976e-4, 0.013986802659928799, -0.10009171813726425, -0.05554055795073509, -0.0636277124285698, -0.05083630979061127, -0.02393398992717266, 0.024377714842557907, -0.12548141181468964, -0.026375215500593185, -0.09152113646268845, -0.055831145495176315, 0.03127060830593109, -0.015754524618387222, -0.002821930916979909, -0.047330837696790695, 0.003117528511211276, 0.04197460412979126, -0.05865616723895073, -0.0677279680967331, -0.0294184572994709, -0.05238322541117668, 0.0015352993505075574, ...],
        ...
      ]
    >
  },
  "dense_23" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.009843180887401104, 0.0011251100804656744, -0.002210816368460655, -0.01815475896000862, -0.01641346700489521, -0.006333646364510059, -0.026434822008013725, -0.029409456998109818, -0.03262704610824585, 0.054779715836048126, -0.01769084297120571, 0.008352200500667095, 0.03344497084617615, 0.018778741359710693, -0.009678652510046959, -0.00841380376368761, 0.03236386179924011, 0.036636654287576675, -0.015711862593889236, 0.03333629295229912, -0.018676333129405975, -0.025183822959661484, -0.02007056027650833, 0.01262013241648674, -0.03137780353426933, 0.014069235883653164, 0.031322162598371506, -0.014229623600840569, -0.01966366544365883, -0.018201271072030067, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.01132954005151987, 0.021764900535345078, 0.042097728699445724, 0.0021775781642645597, 0.03637059032917023, -5.289832479320467e-4, -0.003559721168130636, -0.013169883750379086, 0.031209155917167664, 0.006671429146081209, 0.0034836085978895426, -0.006546341348439455, -0.02274567261338234, -0.03273801878094673, 0.02541232854127884, 0.047575660049915314, -0.011794036254286766, -0.019173212349414825, -0.027178920805454254, -0.023004600778222084, -4.909464041702449e-4, -0.006986131425946951, -0.0030406611040234566, -0.0038151622284203768, 0.01399194821715355, 0.021934526041150093, -0.021527253091335297, -0.014740190468728542, 0.02101035602390766, ...],
        ...
      ]
    >
  },
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[50257][768]
      EXLA.Backend
      [
        [-0.32303544878959656, -0.1547531932592392, 0.15046918392181396, 0.03694256395101547, -0.4073770344257355, 0.40883639454841614, -0.05042625963687897, -0.13812848925590515, -0.44260936975479126, 0.042683906853199005, 0.09064517915248871, -0.17477451264858246, -0.028587598353624344, 0.36422139406204224, -0.24332250654697418, 0.4001982808113098, -0.13118943572044373, -0.34468162059783936, -0.17254599928855896, -0.19412396848201752, -0.040157563984394073, 0.19192802906036377, 0.22668500244617462, 0.21164163947105408, 0.32522645592689514, -0.11300615966320038, -0.13457538187503815, -0.3376449644565582, 0.16616423428058624, ...],
        ...
      ]
    >
  },
  "dropout_8" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [3904612005, 59822595]
    >
  },
  "normalization_19" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.8786206245422363, 0.977189838886261, 0.9535285234451294, 0.9416677951812744, 0.8935660719871521, 1.1148587465286255, 1.1209862232208252, 1.0759830474853516, 1.2573795318603516, 0.7130045890808105, 1.0781350135803223, 1.311759352684021, 1.7548468112945557, 0.9651852250099182, 1.0420305728912354, 1.1724156141281128, 0.8893648982048035, 0.975919246673584, 1.0375721454620361, 1.1725431680679321, 0.8626344203948975, 1.3305847644805908, 1.2337439060211182, 0.9258536100387573, 0.8689802885055542, 1.1040157079696655, 1.1660010814666748, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.20075379312038422, 0.06406572461128235, 0.12223661690950394, 0.0578937828540802, 0.08842409402132034, 0.13347914814949036, 0.05970439314842224, -0.029217541217803955, 0.09374240785837173, -0.21513016521930695, -0.018488746136426926, 0.0017099551623687148, -0.08513133972883224, -0.09809098392724991, 0.04233935847878456, 0.1992577463388443, -0.15861204266548157, 0.08759336918592453, 0.23087632656097412, 0.1191282793879509, 0.12830597162246704, 0.016102248802781105, 0.011360338889062405, -0.13128332793712616, 0.23939116299152374, -0.023765666410326958, ...]
    >
  },
  "dropout_6" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [49463572, 2319138657]
    >
  },
  "dropout_23" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [496009241, 567025577]
    >
  },
  "causal_attention_6" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.005474309902638197, -0.008159071207046509, 0.04245653375983238, -0.003517773235216737, 0.02489471435546875, 0.07975700497627258, 0.00429315073415637, 0.05666806921362877, 0.03440393880009651, -0.07256941497325897, 0.0659099668264389, -0.04771120101213455, 0.008794579654932022, -0.07758080959320068, 0.018121356144547462, -0.006548670586198568, -0.015661224722862244, -0.04024868458509445, 0.02129288949072361, -0.04879722744226456, 0.025005873292684555, 0.005623295437544584, 0.04078862816095352, -0.010245051234960556, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.25480490922927856, 0.0068677314557135105, -0.36943551898002625, -0.2532135248184204, -0.2968670427799225, -0.45883166790008545, -0.39619356393814087, -0.2941643297672272, 0.31375232338905334, 0.28081953525543213, -0.35271313786506653, 0.15246084332466125, 0.3486085832118988, -0.2409464418888092, 0.3810553252696991, 0.18756243586540222, -0.23538103699684143, -0.16381123661994934, 0.2577248811721802, 0.31122246384620667, -0.19215799868106842, -0.31090375781059265, 0.3078189492225647, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.1683349758386612, -0.07505340129137039, -0.1776185929775238, -0.3298611640930176, -0.18433807790279388, -0.16616573929786682, -0.21618106961250305, -0.3471245765686035, 0.4272610545158386, 0.16816367208957672, -0.14912068843841553, 0.23053409159183502, 0.1724993735551834, -0.2502102553844452, 0.1889902502298355, 0.07785528898239136, -0.17520083487033844, 4.069672722835094e-4, 0.2340579777956009, 0.20814143121242523, -0.09603521227836609, -0.3189537227153778, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.026001185178756714, 0.1164478287100792, 0.01576114073395729, 3.8986621075309813e-4, 0.05665240064263344, -0.06860406696796417, -0.077884741127491, 0.0020160211715847254, 0.1010938435792923, -0.00489264540374279, 0.014109453186392784, 0.004938635043799877, 0.013957269489765167, -1.0953181481454521e-4, 0.045285310596227646, 0.006409699097275734, -0.12221642583608627, -0.030588507652282715, -0.015953488647937775, -0.008744288235902786, 0.06652205437421799, ...],
        ...
      ]
    >
  },
  "causal_attention_7" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.014401543885469437, -0.08210781961679459, 0.03707284852862358, 0.04480956122279167, -0.025309355929493904, 0.03636057674884796, 0.06515347212553024, 0.034062087535858154, 0.056122615933418274, -0.12001346796751022, -0.01367500051856041, 0.03935448080301285, 0.025217553600668907, -0.003645562566816807, -0.06464672088623047, -0.032569948583841324, -0.037425216287374496, -0.05927759036421776, -0.028583591803908348, -0.06957610696554184, -0.02805006317794323, 0.0031358744017779827, 0.08971302956342697, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.17740115523338318, -0.26783549785614014, -0.21780116856098175, 0.09614187479019165, -0.0841803178191185, 0.16642539203166962, 0.16042102873325348, 0.06820662319660187, -0.30418792366981506, -0.18555568158626556, 0.1549672782421112, -0.03423169255256653, 0.18163011968135834, -0.22919447720050812, 0.16750094294548035, 0.2446310967206955, 0.08903711289167404, -0.21513791382312775, -0.19718340039253235, 0.19014732539653778, 0.11035270243883133, 0.19292938709259033, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.13402864336967468, 0.23710651695728302, 0.25655093789100647, -0.21579796075820923, 0.20090727508068085, -0.09114925563335419, -0.10657808184623718, 0.014876720495522022, 0.05937391147017479, 0.241450235247612, -0.22072163224220276, -0.1562371402978897, -0.37841227650642395, 0.36161449551582336, 0.07387161254882812, -0.2070108950138092, -0.07916700839996338, -0.0025227591395378113, 0.0285616684705019, -0.13905370235443115, -0.12145528942346573, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.050232306122779846, -0.024067915976047516, -0.07783021777868271, 0.06960273534059525, -0.03935161605477333, -0.05371995270252228, 0.043202634900808334, 0.0760946273803711, -0.03973529115319252, 0.01737343519926071, 0.034734684973955154, -0.003634505206719041, 0.03177136182785034, -0.1413903832435608, -0.07761117815971375, 0.0020925807766616344, -0.05971799045801163, -0.03562401235103607, -0.0256649199873209, -0.06548570841550827, ...],
        ...
      ]
    >
  },
  "causal_attention_1" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.07838151603937149, -0.0261519905179739, 0.02588944509625435, -0.006915892008692026, -0.0021938197314739227, -0.09281930327415466, -0.05995152145624161, -0.07047520577907562, -0.13170771300792694, 0.026708200573921204, -0.03788400813937187, -0.0403260737657547, 0.08190172910690308, 0.04935585334897041, 0.13235844671726227, -0.052911385893821716, 0.06833945959806442, -0.1433049738407135, -0.0913701057434082, -0.07202108204364777, -0.09943951666355133, -0.01671873778104782, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.1804468184709549, 0.4604094922542572, 0.15288865566253662, 0.10773996263742447, 0.14474307000637054, 0.37515804171562195, -0.2118193656206131, -0.22631989419460297, 0.5530677437782288, 0.2164766788482666, -0.4044777452945709, 0.31179162859916687, -0.5569564700126648, -0.24661296606063843, -0.1504637748003006, 0.48781636357307434, -0.059213511645793915, 0.062315311282873154, 0.4075791835784912, -0.3149406909942627, -0.16357609629631042, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.07258270680904388, -0.00496251555159688, -0.015258222818374634, 0.1958148330450058, -0.08403982222080231, 0.021955518051981926, 0.13028812408447266, 0.08879725635051727, -0.14083105325698853, -0.08326917141675949, 0.07492461055517197, -0.06772874295711517, 0.05693696439266205, -0.01142183504998684, -0.06184885650873184, -0.11089394986629486, -0.004206709563732147, -0.053158484399318695, -0.02360639162361622, 0.05678749457001686, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0628456324338913, -0.1594214290380478, 0.1831291764974594, -0.1020272746682167, 0.009997512213885784, 0.20362673699855804, 0.024814946576952934, 0.08186056464910507, 0.18390952050685883, 0.1218249648809433, 0.03355574607849121, 0.16193200647830963, 0.1030723825097084, -0.09385787695646286, 0.0250809695571661, 0.011946634389460087, 0.09945859760046005, 0.009870940819382668, 0.0014449091395363212, ...],
        ...
      ]
    >
  },
  "dropout_1" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [773735686, 1298200513]
    >
  },
  "dense_15" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.033280134201049805, 0.0041899653151631355, -0.013486161828041077, 0.014018313027918339, -0.022471409291028976, -0.014006017707288265, -0.008320542983710766, -0.030460044741630554, -0.028545960783958435, 0.04536983370780945, -0.02982671558856964, 0.01696634478867054, 0.018789168447256088, 0.016128091141581535, -0.015098990872502327, -0.022569550201296806, 0.023259006440639496, 0.03751234710216522, -0.026029782369732857, 0.02426978386938572, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.03769754618406296, 0.0016493636649101973, 0.03393382206559181, -0.026154886931180954, -0.06716742366552353, 0.028866611421108246, -0.006096403580158949, -0.058365460485219955, 0.018212120980024338, 0.04810015857219696, -0.018201809376478195, 0.011072233319282532, 0.024075524881482124, 0.005004778504371643, 0.028287699446082115, 0.02789614349603653, 0.01989627256989479, 0.022731946781277657, -0.009285451844334602, ...],
        ...
      ]
    >
  },
  "dense_9" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.03121029958128929, 0.004105882253497839, -0.013922619633376598, -0.02844170480966568, -0.0102127930149436, -0.01797102391719818, -0.01386560034006834, -0.02791016735136509, -0.03187980130314827, 0.04305558651685715, -0.029218971729278564, 0.007087813224643469, 0.029190437868237495, 0.013355233706533909, -0.03826792910695076, -0.025429747998714447, 0.02941240556538105, 0.05568688362836838, -0.027958227321505547, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.07633969187736511, -0.04240958020091057, 0.08034719526767731, 0.04090924933552742, 0.08719650655984879, 0.03187438100576401, 0.04812890663743019, 0.032732464373111725, -0.0383421927690506, -0.06913375109434128, -0.027150096371769905, -0.13525985181331635, 0.019025051966309547, -0.07607143372297287, 0.057940926402807236, 0.024896077811717987, -0.005658571608364582, 0.10768505185842514, ...],
        ...
      ]
    >
  },
  "normalization_8" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.2404940128326416, 0.8329317569732666, 1.3361403942108154, 0.6342620849609375, 0.8047819137573242, 1.1019365787506104, 0.834610641002655, 1.0301810503005981, 0.8308852314949036, 0.94856858253479, 1.0262738466262817, 1.2845882177352905, 0.8550918698310852, 0.734901487827301, 0.8214701414108276, 0.9203702211380005, 0.5292835235595703, 1.032288670539856, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.17467853426933289, -0.04344610869884491, -0.2161131501197815, -0.14058817923069, -0.17945833504199982, -0.2036036103963852, 0.019165921956300735, -0.11127305030822754, -0.06410326808691025, -0.012674011290073395, -0.07838745415210724, 0.11232764273881912, 0.27520084381103516, 0.26874393224716187, 0.08866811543703079, -0.2706127166748047, 0.18268372118473053, ...]
    >
  },
  "dropout_10" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [3792903481, 3394597351]
    >
  },
  "dropout_3" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [921497900, 2456842150]
    >
  },
  "dropout_15" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [1028020613, 3568050902]
    >
  },
  "causal_attention_10" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.01847536675632, 0.08848455548286438, 0.05879909545183182, -0.011100459843873978, -0.0026253913529217243, 0.07538871467113495, 0.0388086698949337, -0.048253655433654785, 0.04075698181986809, 0.0419774055480957, 0.011266776360571384, 0.09800852090120316, 0.08451832830905914, 0.014776917174458504, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.0586051382124424, -0.08797039836645126, -7.342506432905793e-4, -0.17470332980155945, -0.005726651754230261, 0.08396408706903458, -0.09287691861391068, 0.04687626287341118, -0.027748137712478638, 0.15278883278369904, 0.07859783619642258, 0.09883075952529907, 0.03740178793668747, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.03305619955062866, 0.04778533801436424, 0.025493618100881577, 0.10078153759241104, -0.0767362043261528, -0.027010340243577957, 0.13192489743232727, -0.06664367765188217, -0.004256818443536758, -0.09519480913877487, -0.06090454384684563, -0.006716672331094742, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.061558518558740616, 0.021770888939499855, -0.02040247805416584, 0.06100579723715782, -0.004383729305118322, -0.025383440777659416, -0.005708214361220598, 0.050491951406002045, 0.09237834066152573, -0.03344252333045006, -0.024055540561676025, ...],
        ...
      ]
    >
  },
  "causal_attention_8" => %{
    "out_proj" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04284971207380295, -0.02115190401673317, 0.03848973661661148, -0.06188362464308739, -0.059944137930870056, 0.01264828909188509, 0.014585893601179123, -0.04924559220671654, 0.016959449276328087, 0.10092634707689285, -0.04067743569612503, 0.0591898150742054, -0.03991405665874481, ...],
        ...
      ]
    >,
    "w_key" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.16719892621040344, -0.32467323541641235, 0.03138533607125282, -0.09619949012994766, -0.15892265737056732, 0.28406986594200134, -0.15274712443351746, -0.008167644962668419, -0.4055551588535309, 0.18020747601985931, -0.17026761174201965, -0.12456604838371277, ...],
        ...
      ]
    >,
    "w_query" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.25998005270957947, 0.05616820603609085, -0.29344120621681213, -0.05881772190332413, 0.2063792645931244, -0.2858749032020569, 0.14600811898708344, -0.012882765382528305, 0.10957272350788116, -0.2004784792661667, 0.14097826182842255, ...],
        ...
      ]
    >,
    "w_value" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.030553070828318596, 0.041534774005413055, -0.03597399219870567, -0.08644933998584747, 0.024126175791025162, -0.01605439931154251, -0.04043141379952431, 0.026280589401721954, 0.08956951647996902, -0.08002328872680664, ...],
        ...
      ]
    >
  },
  "dropout_7" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [1517465556, 680546948]
    >
  },
  "normalization_12" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [1.0005691051483154, 0.8601945042610168, 1.3499969244003296, 0.6964654922485352, 1.1036350727081299, 1.150476336479187, 0.8913223147392273, 1.1150590181350708, 0.8373574018478394, 0.8103222250938416, 0.7924319505691528, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.1763605773448944, -0.10315840691328049, -0.1653309166431427, -0.08805394917726517, -0.3151269555091858, -0.259617418050766, -0.4000778794288635, -0.01937214657664299, -0.20391380786895752, -0.014705691486597061, ...]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.03717408329248428, -5.34638820681721e-4, -0.020702335983514786, -0.05229564756155014, -0.004672816954553127, -0.03581037372350693, -0.02960226498544216, -0.010178990662097931, -0.019440125674009323, 0.032921042293310165, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.059313539415597916, -0.03014758974313736, -0.030082615092396736, 0.01212656032294035, 0.014297493733465672, -0.0032631224021315575, -0.097275011241436, -0.0011344181839376688, -0.00340064219199121, ...],
        ...
      ]
    >
  },
  "dense_18" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [0.013365418650209904, 0.034775666892528534, 6.796774105168879e-4, 0.04617267847061157, 0.003196151228621602, 0.04513091221451759, 0.01421383861452341, -0.011356797069311142, -0.04627877473831177, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.09282093495130539, 0.04957180097699165, 0.0023574030492454767, -0.04585099592804909, 0.0576811358332634, 0.030636806041002274, -0.07704762369394302, 0.006161978002637625, ...],
        ...
      ]
    >
  },
  "dense_5" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.029110770672559738, -0.008371036499738693, -0.02050478383898735, -0.033897027373313904, 0.0027992427349090576, -0.03036465309560299, -0.014069393277168274, -0.025496412068605423, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.014254692941904068, 0.07411962002515793, 0.04338942468166351, -0.0642935261130333, 0.03273622691631317, -0.09052415192127228, -0.12970858812332153, ...],
        ...
      ]
    >
  },
  "dense_11" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.03261579945683479, 0.003924164921045303, -0.013643702492117882, -0.022517675533890724, -0.01994054578244686, -0.016114424914121628, -0.005500431172549725, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.00269136275164783, -0.01807646080851555, 0.03980712220072746, 0.0516231544315815, 0.010472417809069157, 0.015428143553435802, ...],
        ...
      ]
    >
  },
  "dropout_21" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [3047494511, 3948892101]
    >
  },
  "normalization_18" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.8780226111412048, 0.9574083089828491, 1.164774775505066, 1.2807426452636719, 1.2348146438598633, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.0046830568462610245, 0.02824675291776657, 0.06250514090061188, 0.29412591457366943, ...]
    >
  },
  "dense_6" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.15515416860580444, 0.04569678008556366, -0.030762122943997383, -0.2362818717956543, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.012649457901716232, 0.006101831793785095, 0.0037688175216317177, ...],
        ...
      ]
    >
  },
  "normalization_15" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.8705834746360779, 0.9248753190040588, 0.8575374484062195, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.18332846462726593, 0.025834091007709503, ...]
    >
  },
  "normalization_23" => %{
    "scale" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.9016574025154114, 1.2897508144378662, ...]
    >,
    "shift" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.2510068118572235, ...]
    >
  },
  "dropout_24" => %{
    "key" => #Nx.Tensor<
      u32[2]
      EXLA.Backend
      [3831323645, ...]
    >
  },
  "dense_10" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [...]
    >,
    ...
  },
  "dense_13" => %{...},
  ...
}
# Encode a fixed prompt, then generate 6 new tokens twice with the same
# prediction function so the two continuations can be compared side by side.
{:ok, input} = MyGPT.text_to_token_ids(tokenizer, "I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and")
# First run: the parameters produced by training (`trained_model_state`).
token_ids = MyGPT.generate_tokens(predict_fn, trained_model_state, input, 6)
{:ok, text1} = MyGPT.token_ids_to_text(tokenizer, token_ids)
# Second run: `params`, defined earlier in the notebook.
# NOTE(review): presumably the untrained/initial parameters - confirm.
# `token_ids` is rebound here; only the decoded texts are kept apart.
token_ids = MyGPT.generate_tokens(predict_fn, params, input, 6)
{:ok, text2} = MyGPT.token_ids_to_text(tokenizer, token_ids)

IO.inspect(text1)
IO.inspect(text2)
"I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, andis height qualitiesoring ofburn"
"I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and\tcatch\tcatch\tcatch\tcatch\tcatch\tcatch"
"I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and\tcatch\tcatch\tcatch\tcatch\tcatch\tcatch"

At the beginning of the training, both the training and validation set losses sharply decrease, which is a sign that the model is learning. However, the training set loss continues to decrease past the second epoch, whereas the validation loss stagnates. This is a sign that the model is still learning, but it’s overfitting to the training set past epoch 2.

5.3 Decoding strategies to control randomness

This section covers text generation strategies (also called decoding strategies) that produce more original text, using temperature scaling and top-k sampling to improve the predictions.

So far, the token generated at each step is the one with the largest probability score among all tokens in the vocabulary. This means that the LLM will always generate the same output for the same input.

5.3.1 Temperature scaling

Temperature scaling is a technique that adds a probabilistic selection process to the next-token generation task.

defmodule DecodingStrategies do
  @moduledoc """
  Helpers for probabilistic next-token selection: multinomial sampling and
  temperature-scaled softmax.
  """

  @doc """
  Draws `num_samples` indices according to `probabilities`.

  `probabilities` is a tensor of probabilities; its cumulative sum is taken
  along axis 0 and flattened, so a rank-1 tensor is the intended input.
  Each sample is the index of the first cumulative probability that is
  greater than or equal to a uniform random draw.

  The random key is seeded from `:rand.uniform(max_random_number)`, so at most
  `max_random_number` distinct seeds are possible.

  Returns a list of `num_samples` integer indices.
  """
  def multinomial(probabilities, num_samples, max_random_number \\ 1000) do
    seed = :rand.uniform(max_random_number)
    key = Nx.Random.key(seed)
    {random_values, _new_key} = Nx.Random.uniform(key, shape: {num_samples})

    # Materialise the cumulative distribution once instead of once per sample.
    cumulative_probs =
      probabilities
      |> Nx.cumulative_sum(axis: 0)
      |> Nx.to_flat_list()

    # Floating-point rounding can leave the final cumulative sum slightly
    # below 1.0; a draw above it would otherwise yield `nil`. Fall back to the
    # last index in that case.
    last_index = length(cumulative_probs) - 1

    random_values
    |> Nx.to_flat_list()
    |> Enum.map(fn value ->
      Enum.find_index(cumulative_probs, fn prob -> prob >= value end) || last_index
    end)
  end

  @doc """
  Applies softmax over the last axis of `logits` after dividing them by
  `temperature`.

  `temperature` must be greater than 0; any other value raises
  `FunctionClauseError`. A temperature of 1 leaves the logits unscaled.
  """
  def softmax_with_temperature(logits, temperature \\ 1.0) when temperature > 0 do
    logits
    |> Nx.divide(temperature)
    |> Axon.Activations.softmax(axis: -1)
  end
end
{:module, DecodingStrategies, <<70, 79, 82, 49, 0, 0, 12, ...>>, {:softmax_with_temperature, 2}}
# Nine example logits, one per candidate next token, used by the sampling
# examples in the following cells.
next_token_logits = 
  Nx.tensor([4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79])
#Nx.Tensor<
  f32[9]
  [4.510000228881836, 0.8899999856948853, -1.899999976158142, 6.75, 1.6299999952316284, -1.6200000047683716, -1.8899999856948853, 6.28000020980835, 1.7899999618530273]
>
# Turn the logits into probabilities, sample a single index from them, and
# reshape the resulting one-element list into a {1, 1} tensor.
next_token_id =
  next_token_logits
  |> Axon.Layers.softmax(axis: -1)
  |> DecodingStrategies.multinomial(1)
  |> Nx.tensor()
  |> Nx.new_axis(0)

IO.inspect(next_token_id, label: "next_token_id")
next_token_id: #Nx.Tensor<
  s64[1][1]
  [
    [3]
  ]
>
#Nx.Tensor<
  s64[1][1]
  [
    [3]
  ]
>

We can further control the distribution and selection process via a concept called temperature scaling. Temperature scaling is just a fancy description for dividing the logits by a number greater than 0.

Temperatures greater than 1 result in more uniformly distributed token probabilities, and temperatures smaller than 1 result in more confident (sharper or more peaked) distributions.

Using a temperature of 1 is the same as not using any temperature scaling.

# Compute the probability distribution for the same logits at three
# temperatures (5, 1 and 0.1) so their sharpness can be compared.
prob_5 = DecodingStrategies.softmax_with_temperature(next_token_logits, 5)
prob_1 = DecodingStrategies.softmax_with_temperature(next_token_logits, 1)
prob_0_1 = DecodingStrategies.softmax_with_temperature(next_token_logits, 0.1)

IO.inspect(prob_5, label: "Temperature 5")
IO.inspect(prob_1, label: "Temperature 1")
IO.inspect(prob_0_1, label: "Temperature 0.1")
Temperature 5: #Nx.Tensor<
  f32[9]
  [0.15464816987514496, 0.07497484236955643, 0.04291204735636711, 0.24205201864242554, 0.08693429082632065, 0.04538368061184883, 0.04299795627593994, 0.22033578157424927, 0.08976118266582489]
>
Temperature 1: #Nx.Tensor<
  f32[9]
  [0.06090706214308739, 0.0016312535153701901, 1.0019362525781617e-4, 0.5721200704574585, 0.0034190029837191105, 1.3256912643555552e-4, 1.0120050865225494e-4, 0.3575764000415802, 0.004012236371636391]
>
Temperature 0.1: #Nx.Tensor<
  f32[9]
  [1.8529870693395623e-10, 3.51893963265341e-26, 2.6890267327905793e-38, 0.9909866452217102, 5.756917559685691e-23, 4.422023081385532e-37, 2.971829782368292e-38, 0.009013324975967407, 2.8514262205604768e-22]
>
#Nx.Tensor<
  f32[9]
  [1.8529870693395623e-10, 3.51893963265341e-26, 2.6890267327905793e-38, 0.9909866452217102, 5.756917559685691e-23, 4.422023081385532e-37, 2.971829782368292e-38, 0.009013324975967407, 2.8514262205604768e-22]
>

Applying very small temperatures, such as 0.1, will result in sharper distributions such that the behavior of the multinomial function selects the most likely token almost 100% of the time, approaching the behavior of the argmax function. Likewise, a temperature of 5 results in a more uniform distribution where other tokens are selected more often. This can add more variety to the generated texts but also more often results in nonsensical text.

5.3.2 Top-k sampling

Top-k sampling, when combined with probabilistic sampling and temperature scaling, can improve the text generation results. In top-k sampling, we can restrict the sampled tokens to the top-k most likely tokens and exclude all other tokens from the selection process by masking their probability scores.

Using top-k sampling with k = 3, we focus on the three tokens associated with the highest logits and mask out all other tokens with negative infinity ( –inf ) before applying the softmax function. This results in a probability distribution with a probability value 0 assigned to all non-top-k tokens.

# Keep the 3 largest logits; `top_logits` holds their values and `top_pos`
# their positions in `next_token_logits`.
top_k = 3
{top_logits, top_pos} = Nx.top_k(next_token_logits, k: top_k)
{#Nx.Tensor<
   f32[3]
   [6.75, 6.28000020980835, 4.510000228881836]
 >,
 #Nx.Tensor<
   s64[3]
   [3, 7, 0]
 >}
# Position (within `top_logits`) of the smallest logit that survived top-k;
# its value is the cut-off below which logits are masked.
min_index = Nx.argmin(top_logits)
neg_inf_tensor = Nx.broadcast(Nx.Constants.neg_infinity(), next_token_logits.shape)
# Replace every logit strictly below the cut-off with -Inf so that softmax
# assigns it a probability of 0.
new_logits = Nx.select(Nx.less(next_token_logits, top_logits[min_index]), neg_inf_tensor, next_token_logits)
#Nx.Tensor<
  f32[9]
  [4.510000228881836, -Inf, -Inf, 6.75, -Inf, -Inf, -Inf, 6.28000020980835, -Inf]
>
# Softmax over the masked logits: the -Inf entries end up with probability 0.
topk_probas = Axon.Activations.softmax(new_logits, axis: -1)
#Nx.Tensor<
  f32[9]
  [0.06148479878902435, 0.0, 0.0, 0.5775469541549683, 0.0, 0.0, 0.0, 0.3609682321548462, 0.0]
>

5.3.3 Modifying the text generation function

defmodule GPTModel do
  @moduledoc """
  GPT-style Axon model definition together with tokenization helpers and a
  text generation loop that supports top-k filtering and temperature scaling.
  """

  # Captures the notebook-level `gpt_config_124m` binding at compile time and
  # uses it as the default option list for `model/2`.
  @gpt_config_124m gpt_config_124m

  @doc """
  Builds the Axon model.

  `opts` is read for `:drop_rate`, `:vocab_size`, `:emb_dim`,
  `:context_length` and `:n_layers`. The number of transformer blocks is taken
  from `:n_layers` and falls back to 12 when that key is absent.
  """
  def model(input_shape \\ {2, 4, 768}, opts \\ @gpt_config_124m) do
    Axon.input("sequence", shape: input_shape)
    |> embedding_block(opts)
    |> Axon.dropout(rate: opts[:drop_rate])
    |> transformer_blocks(opts[:n_layers] || 12, opts)
    |> Transformer.Layers.normalization(opts)
    |> Axon.dense(opts[:vocab_size], use_bias: false)
  end

  @doc """
  Adds a token embedding and a positional embedding of `input` together.
  """
  def embedding_block(input, opts) do
    token_emb = Axon.embedding(input, opts[:vocab_size], opts[:emb_dim])
    pos_emb = Transformer.Layers.pos_embedding(input, opts[:context_length], opts[:emb_dim])

    Axon.add(token_emb, pos_emb)
  end

  @doc """
  Chains `n_blocks` transformer blocks after `input`.

  With `n_blocks` of 0 the input is returned unchanged.
  """
  def transformer_blocks(input, n_blocks, transformer_opts) do
    # The explicit `//1` step keeps the range empty for 0 instead of counting
    # down from 1 to 0.
    for _n_block <- 1..n_blocks//1, reduce: input do
      model_acc ->
        Transformer.Layers.block(model_acc, transformer_opts)
    end
  end

  @doc """
  Encodes text into token ids with Tiktoken.

  For a single string, returns `{:ok, tensor}` where the tensor is the `:s64`
  token list with a new leading axis.

  For a list of strings, each one is encoded individually and the results are
  stacked along axis 1 and squeezed. Note that this clause returns the bare
  tensor, not an `{:ok, tensor}` tuple. `Nx.stack/2` requires all encoded
  texts to have the same shape.
  """
  def text_to_token_ids(tokenizer, texts) when is_list(texts) do
    token_ids_list =
      for text <- texts do
        {:ok, token_ids} = text_to_token_ids(tokenizer, text)
        token_ids
      end

    Nx.stack(token_ids_list, axis: 1) |> Nx.squeeze()
  end

  def text_to_token_ids(tokenizer, text) do
    {:ok, tokens} = Tiktoken.encode(tokenizer, text)
    {:ok, Nx.tensor(tokens, type: :s64) |> Nx.new_axis(0)}
  end

  @doc """
  Flattens `token_ids` and decodes them with Tiktoken.

  Returns whatever `Tiktoken.decode/2` returns.
  """
  def token_ids_to_text(tokenizer, token_ids) do
    tokens_ids = Nx.to_flat_list(token_ids)
    Tiktoken.decode(tokenizer, tokens_ids)
  end

  @doc """
  Appends `max_new_token` generated tokens to `input` using an already built
  `predict_fn`.

    * `k` - number of highest logits kept before sampling; 0 disables top-k
      filtering.
    * `temperature` - values greater than 0 scale the logits and sample from
      the resulting distribution; values of 0 or below select the most likely
      token (greedy decoding).
  """
  def generate_tokens(
        predict_fn,
        model_params,
        input,
        max_new_token,
        k \\ 0,
        temperature \\ 1
      ) do
    generate_tokens_impl(predict_fn, model_params, input, max_new_token, k, temperature)
  end

  @doc """
  Same as `generate_tokens/6`, but builds the prediction function from `model`
  with the EXLA compiler.

  When `model_params` is an empty map, parameters are first initialised from a
  template matching the shape of `input`.
  """
  def generate_tokens_with_model(
        model,
        model_params,
        input,
        max_new_token,
        k \\ 0,
        temperature \\ 1
      )

  def generate_tokens_with_model(
        model,
        model_params,
        input,
        max_new_token,
        k,
        temperature
      )
      when model_params == %{} do
    {init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
    template = Nx.template(Nx.shape(input), :s64)
    init_model_params = init_fn.(template, model_params)
    generate_tokens_impl(predict_fn, init_model_params, input, max_new_token, k, temperature)
  end

  def generate_tokens_with_model(
        model,
        model_params,
        input,
        max_new_token,
        k,
        temperature
      ) do
    {_init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
    generate_tokens_impl(predict_fn, model_params, input, max_new_token, k, temperature)
  end

  # Generation loop: predicts logits for the accumulated sequence, selects one
  # token from the logits of the last position and appends it along axis 1.
  # NOTE(review): the sampling path flattens the probabilities, so this
  # presumably expects a batch size of 1 - confirm against callers.
  defp generate_tokens_impl(predict_fn, model_params, input, max_new_token, k, temperature) do
    # `//1` makes `max_new_token == 0` a no-op; a plain `1..0` range counts
    # down and would generate two tokens.
    for _new_token_index <- 1..max_new_token//1, reduce: input do
      input_acc ->
        logit = predict_fn.(model_params, input_acc)

        # Keep only the logits of the last sequence position.
        predicted_new_token =
          logit[[.., -1]]
          |> top_k(k)
          |> softmax_with_temperature(temperature)
          |> Nx.new_axis(0)

        Nx.concatenate([input_acc, predicted_new_token], axis: 1)
    end
  end

  # Draws `num_samples` indices from `probabilities` by comparing uniform
  # random draws against the flattened cumulative distribution.
  defp multinomial(probabilities, num_samples, max_random_number \\ 1000) do
    seed = :rand.uniform(max_random_number)
    key = Nx.Random.key(seed)
    {random_values, _new_key} = Nx.Random.uniform(key, shape: {num_samples})

    # Materialise the cumulative distribution once instead of once per sample.
    cumulative_probs =
      probabilities
      |> Nx.cumulative_sum(axis: -1)
      |> Nx.to_flat_list()

    # Rounding can leave the final cumulative sum slightly below 1.0; fall back
    # to the last index rather than returning `nil` for such a draw.
    last_index = length(cumulative_probs) - 1

    random_values
    |> Nx.to_flat_list()
    |> Enum.map(fn value ->
      Enum.find_index(cumulative_probs, fn prob -> prob >= value end) || last_index
    end)
  end

  # Temperature of 0 or below: greedy decoding, pick the most likely token.
  # Including 0 here avoids a FunctionClauseError (and a division by zero).
  defp softmax_with_temperature(logits, temperature) when temperature <= 0 do
    logits
    |> Axon.Activations.softmax(axis: -1)
    |> Nx.argmax(axis: -1)
  end

  # Positive temperature: scale the logits, then sample one token index.
  defp softmax_with_temperature(logits, temperature) when temperature > 0 do
    logits
    |> Nx.divide(temperature)
    |> Axon.Activations.softmax(axis: -1)
    |> multinomial(1)
    |> Nx.tensor()
  end

  # `k == 0` disables top-k filtering.
  defp top_k(logits, k) when k == 0, do: logits

  # Masks every logit strictly below the smallest of the `k` largest logits
  # with -Inf so that softmax assigns it a probability of 0.
  defp top_k(logits, k) do
    {top_logits, _top_pos} = Nx.top_k(logits, k: k)
    threshold = Nx.reduce_min(top_logits)
    neg_inf_tensor = Nx.broadcast(Nx.Constants.neg_infinity(), logits.shape)
    Nx.select(Nx.less(logits, threshold), neg_inf_tensor, logits)
  end
end
{:module, GPTModel, <<70, 79, 82, 49, 0, 0, 36, ...>>, {:top_k, 2}}
# Generate 6 tokens from the same prompt twice with the trained parameters.
{:ok, input} = GPTModel.text_to_token_ids(tokenizer, "I HAD always ")
# First run: GPTModel with top-k = 5 and temperature = 0.1.
token_ids = GPTModel.generate_tokens(predict_fn, trained_model_state, input, 6, 5, 0.1)
{:ok, text1} = GPTModel.token_ids_to_text(tokenizer, token_ids)
# Second run goes through MyGPT.generate_tokens, defined earlier in the
# notebook. NOTE(review): presumably the baseline decoding for comparison -
# confirm this is intentional and not a leftover module name.
token_ids = MyGPT.generate_tokens(predict_fn, trained_model_state, input, 6)
{:ok, text2} = GPTModel.token_ids_to_text(tokenizer, token_ids)

IO.inspect(text1)
IO.inspect(text2)
"I HAD always ale, except,ishing's"
"I HAD always ale, except cheap moment,"
"I HAD always ale, except cheap moment,"

5.5 Loading pretrained weights from OpenAI

Generation is where things get even more exciting. In this section we will use the extremely popular GPT-2 model to generate text continuations.

Generation generally is an iterative process, where the model predicts the sentence token by token, adhering to some constraints. Again, we will make use of a higher-level API based on Nx.Serving.

# Load the GPT-2 model, tokenizer and generation config from the Hugging Face
# repository "openai-community/gpt2".
# Note: `tokenizer` is rebound here; it previously held the string "gpt-3.5-turbo".
{:ok, gpt2} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})

# Build a text-generation serving from the three loaded pieces.
serving = Bumblebee.Text.generation(gpt2, tokenizer, generation_config)

# Read the prompt from a Kino text input and run it through the serving.
text_input = Kino.Input.text("Text", default: "Yesterday, I was reading a book and")
text = Kino.Input.read(text_input)
Nx.Serving.run(serving, text)
%{
  results: [
    %{
      text: " be able to do that. I want to be able to do that. I want to be able",
      token_summary: %{input: 3, output: 20, padding: 0}
    }
  ]
}

There are also gpt2-medium and gpt2-large, heavier versions of the model with many more parameters.

Now we can go about training the model! First, we need to extract the Axon model and parameters from the Bumblebee model map:

# Pull the Axon model and its parameters out of the loaded Bumblebee model map.
%{model: model, params: params} = gpt2
%{
  spec: %Bumblebee.Text.Gpt2{
    architecture: :for_causal_language_modeling,
    vocab_size: 50257,
    max_positions: 1024,
    hidden_size: 768,
    num_blocks: 12,
    num_attention_heads: 12,
    intermediate_size: nil,
    activation: :gelu_approx_tanh,
    scale_attention_weights: true,
    dropout_rate: 0.1,
    embeddings_dropout_rate: 0.1,
    attention_dropout_rate: 0.1,
    classifier_dropout_rate: 0.1,
    layer_norm_epsilon: 1.0e-5,
    initializer_scale: 0.02,
    num_labels: 2,
    id_to_label: %{},
    use_cross_attention: false,
    pad_token_id: 50256
  },
  params: #Axon.ModelState<
    Parameters: 163037184 (652.15 MB)
    Trainable Parameters: 163037184 (652.15 MB)
    Trainable State: 0, (0 B)
  >,
  model: #Axon<
    inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
    outputs: "container_28"
    nodes: 705
  >
}
# Reconfigure the tokenizer: no fixed sequence length, pad on the left,
# omit token type ids, and include each sequence's length in the output.
tokenizer =
      Bumblebee.configure(tokenizer,
        length: nil,
        pad_direction: :left,
        return_token_type_ids: false,
        return_length: true
      )
%Bumblebee.Text.PreTrainedTokenizer{
  native_tokenizer: #Tokenizers.Tokenizer<[
    vocab_size: 50257,
    byte_fallback: false,
    continuing_subword_prefix: "",
    dropout: nil,
    end_of_word_suffix: "",
    fuse_unk: false,
    model_type: "bpe",
    unk_token: nil
  ]>,
  type: :gpt2,
  special_tokens: %{eos: "<|endoftext|>", bos: "<|endoftext|>", unk: "<|endoftext|>"},
  additional_special_tokens: [],
  add_special_tokens: true,
  length: nil,
  pad_direction: :left,
  truncate_direction: :right,
  return_attention_mask: true,
  return_token_type_ids: false,
  return_special_tokens_mask: false,
  return_offsets: false,
  return_length: true,
  template_options: []
}

The Axon model actually outputs a map with :logits, :hidden_states, and :attentions. You can see this by using Axon.get_output_shape/2 with an input. This method symbolically executes the graph and gets the resulting shapes:

# Tokenize the prompt; the result is a map of tensors
# ("input_ids", "attention_mask", "length").
input = Bumblebee.apply_tokenizer(tokenizer, "I want to")
%{
  "attention_mask" => #Nx.Tensor<
    u32[1][3]
    EXLA.Backend
    [
      [1, 1, 1]
    ]
  >,
  "input_ids" => #Nx.Tensor<
    u32[1][3]
    EXLA.Backend
    [
      [40, 765, 284]
    ]
  >,
  "length" => #Nx.Tensor<
    s32[1]
    EXLA.Backend
    [3]
  >
}

For training, we only care about the :logits key, so we’ll extract that by attaching an Axon.nx/2 layer to the model:

# Append a layer that keeps only the :logits entry of the model's output map.
# The capture operator was HTML-escaped as `&amp;`, which is not valid Elixir.
gpt2_model = Axon.nx(model, & &1.logits)
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "nx_136"
  nodes: 706
>
# Compile the model into an {init_fn, predict_fn} pair; only predict_fn is
# needed because the parameters are already loaded.
{_init_fn, predict_fn} = Axon.build(gpt2_model)
{#Function<134.18331142/2 in Nx.Defn.Compiler.fun/2>,
 #Function<134.18331142/2 in Nx.Defn.Compiler.fun/2>}
# Run the forward pass on the tokenized prompt; yields the logits tensor.
result = predict_fn.(params, input)
#Nx.Tensor<
  f32[1][3][50257]
  EXLA.Backend
  [
    [
      [-39.308448791503906, -39.010066986083984, -41.837467193603516, -41.781246185302734, -40.84248352050781, -40.89142990112305, -38.62623596191406, -40.154056549072266, -38.097896575927734, -41.04249954223633, -40.9429931640625, -36.262168884277344, -37.39033889770508, -36.03800964355469, -38.52249526977539, -40.54604721069336, -39.718971252441406, -39.7431640625, -40.27290344238281, -40.314857482910156, -40.54868698120117, -41.00197219848633, -40.9098014831543, -40.914119720458984, -41.297733306884766, -37.69235610961914, -39.106632232666016, -41.460182189941406, -40.526241302490234, -40.43655014038086, -38.97370147705078, -41.32615661621094, -39.90999984741211, -40.565555572509766, -40.7227897644043, -40.8016471862793, -40.875083923339844, -40.86553955078125, -40.39710998535156, -40.221649169921875, -38.78817367553711, -40.58393096923828, -40.43303298950195, -40.767242431640625, -40.72999572753906, -40.78556442260742, -40.461753845214844, -41.084720611572266, -41.600372314453125, -41.25688552856445, ...],
      ...
    ]
  ]
>
defmodule Decoder do
  @moduledoc """
  Turns model logits into the next token id, with optional top-k filtering and
  temperature-scaled sampling.
  """

  @doc """
  Picks the next token from `prediction` and appends it to `token_ids`.

  Only the logits at the last sequence position (`prediction[[.., -1]]`) are used.

  ## Parameters

    * `prediction` - logits tensor, e.g. of shape `{1, seq_len, vocab_size}`.
    * `token_ids` - tensor with the ids produced so far, e.g. `{1, seq_len}`.
    * `k` - number of highest logits to keep before sampling; `0` (the default)
      disables top-k filtering.
    * `temperature` - when greater than `0`, the logits are divided by it and a
      token is sampled from the resulting distribution; when `0` or negative,
      the highest-scoring token is chosen (greedy decoding).

  Returns `token_ids` with the new id concatenated along axis 1.
  """
  def prediction_to_id(prediction, token_ids, k \\ 0, temperature \\ 1) do
    predicted_new_token =
      prediction[[.., -1]]
      |> top_k(k)
      |> softmax_with_temperature(temperature)
      |> Nx.new_axis(0)

    Nx.concatenate([token_ids, predicted_new_token], axis: 1)
  end

  # Draws `num_samples` indices from `probabilities` by inverse-CDF sampling.
  # The tensor is flattened, so it is treated as a single distribution.
  # A fresh PRNG key is seeded from `:rand.uniform(max_random_number)` on every
  # call, so at most `max_random_number` distinct random streams are possible.
  defp multinomial(probabilities, num_samples, max_random_number \\ 1000) do
    seed = :rand.uniform(max_random_number)
    key = Nx.Random.key(seed)
    {random_values, _new_key} = Nx.Random.uniform(key, shape: {num_samples})

    # Build the CDF once; it does not change between samples.
    cumulative_probs =
      probabilities
      |> Nx.cumulative_sum(axis: -1)
      |> Nx.to_flat_list()

    # The first index at which the CDF reaches its final value is the last entry
    # with non-zero probability. It is used when rounding leaves the final
    # cumulative sum at or below the drawn value, which would otherwise yield nil.
    total = List.last(cumulative_probs)
    fallback_index = Enum.find_index(cumulative_probs, fn prob -> prob >= total end)

    random_values
    |> Nx.to_flat_list()
    |> Enum.map(fn value ->
      # Strict comparison: a draw of exactly 0.0 must not select a leading entry
      # whose probability is zero (e.g. one masked out by top-k filtering).
      Enum.find_index(cumulative_probs, fn prob -> prob > value end) || fallback_index
    end)
  end

  # temperature > 0: scale the logits, softmax over the last axis, sample one id.
  defp softmax_with_temperature(logits, temperature) when temperature > 0 do
    logits
    |> Nx.divide(temperature)
    |> Axon.Layers.softmax(axis: -1)
    |> multinomial(1)
    |> Nx.tensor()
  end

  # temperature <= 0: greedy decoding. Previously only negative values were
  # accepted and a temperature of exactly 0 raised FunctionClauseError.
  # Softmax is monotonic, so the argmax is taken directly on the logits.
  defp softmax_with_temperature(logits, _temperature),
    do: Nx.argmax(logits, axis: -1)

  # Top-k filtering; `k == 0` disables it.
  defp top_k(logits, k) when k == 0, do: logits

  defp top_k(logits, k) do
    {top_logits, _top_indices} = Nx.top_k(logits, k: k)

    # Smallest retained logit per row. Reducing over the last axis only (instead
    # of the whole tensor) keeps rows independent when there is more than one.
    threshold = Nx.reduce_min(top_logits, axes: [-1], keep_axes: true)

    masked = Nx.broadcast(Nx.Constants.neg_infinity(), Nx.shape(logits))

    logits
    |> Nx.less(threshold)
    |> Nx.select(masked, logits)
  end
end
{:module, Decoder, <<70, 79, 82, 49, 0, 0, 17, ...>>, {:top_k, 2}}
# Take the token ids out of the tokenizer output and decode them back to text
# as a round-trip check.
%{"input_ids" =>  token_ids} = input
Bumblebee.Tokenizer.decode(tokenizer, token_ids)
["I want to"]
# Append one predicted token using the defaults (k = 0, temperature = 1).
predicted_tokens = Decoder.prediction_to_id(result, token_ids)
#Nx.Tensor<
  s64[1][4]
  EXLA.Backend
  [
    [40, 765, 284, 7716]
  ]
>
# Decode the extended id sequence back into text.
Bumblebee.Tokenizer.decode(tokenizer, predicted_tokens)
["I want to generate"]

This section showed how to load pretrained weights (parameters) and use them for next-token prediction in Elixir.

Summary

  • When LLMs generate text, they output one token at a time.
  • By default, the next token is generated by converting the model outputs into probability scores and selecting the token from the vocabulary that corresponds to the highest probability score, which is known as “greedy decoding.”
  • Using probabilistic sampling and temperature scaling, we can influence the diversity and coherence of the generated text.
  • Training and validation set losses can be used to gauge the quality of text generated by the LLM during training.
  • Pretraining an LLM involves changing its weights to minimize the training loss.
  • The training loop for LLMs itself is a standard procedure in deep learning, using a conventional cross entropy loss and AdamW optimizer.
  • Pretraining an LLM on a large text corpus is time- and resource-intensive, so we can load openly available weights as an alternative to pretraining the model on a large dataset ourselves.