Chapter 6: Fine-tuning for classification
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:table_rex, "~> 3.1.1"},
  {:bumblebee, "~> 0.6.0"},
  {:explorer, "~> 0.7.1"},
  {:req, "~> 0.4.5"},
  {:kino_vega_lite, "~> 0.1.11"}
])
Nx.global_default_backend(EXLA.Backend)Introduction
# Load the pretrained GPT-2 model, its tokenizer and its generation config
# from the "openai-community/gpt2" repository (`{:hf, ...}` repository tuple).
{:ok, gpt2} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})
# Combine the three pieces into a text-generation serving.
serving = Bumblebee.Text.generation(gpt2, tokenizer, generation_config)
text_input = Kino.Input.text("Text", default: "Yesterday, I was reading a book and")text = Kino.Input.read(text_input)
Nx.Serving.run(serving, text)%{
  results: [
    %{
      text: " I was thinking, \"What's going on here?\" I was thinking, \"What's going on",
      token_summary: %{input: 8, output: 20, padding: 0}
    }
  ]
}%{model: model, params: params} = gpt2
# Reconfigure the tokenizer: no fixed length, left padding, no token type ids,
# and include the sequence length in the tokenizer output.
tokenizer =
      Bumblebee.configure(tokenizer,
        length: nil,
        pad_direction: :left,
        return_token_type_ids: false,
        return_length: true
      )
# Sample input used to probe the model in the following cells.
input = Bumblebee.apply_tokenizer(tokenizer, "I want to")
# Wrap the model so that its output is only the `logits` field.
gpt2_model = Axon.nx(model, & &1.logits)
# Only the prediction function is needed; the init function is discarded.
{_init_fn, predict_fn} = Axon.build(gpt2_model)
result = predict_fn.(params, input)#Nx.Tensor<
  f32[1][3][50257]
  EXLA.Backend
  [
    [
      [-39.308448791503906, -39.010066986083984, -41.837467193603516, -41.781246185302734, -40.84248352050781, -40.89142990112305, -38.62623596191406, -40.154056549072266, -38.097896575927734, -41.04249954223633, -40.9429931640625, -36.262168884277344, -37.39033889770508, -36.03800964355469, -38.52249526977539, -40.54604721069336, -39.718971252441406, -39.7431640625, -40.27290344238281, -40.314857482910156, -40.54868698120117, -41.00197219848633, -40.9098014831543, -40.914119720458984, -41.297733306884766, -37.69235610961914, -39.106632232666016, -41.460182189941406, -40.526241302490234, -40.43655014038086, -38.97370147705078, -41.32615661621094, -39.90999984741211, -40.565555572509766, -40.7227897644043, -40.8016471862793, -40.875083923339844, -40.86553955078125, -40.39710998535156, -40.221649169921875, -38.78817367553711, -40.58393096923828, -40.43303298950195, -40.767242431640625, -40.72999572753906, -40.78556442260742, -40.461753845214844, -41.084720611572266, -41.600372314453125, -41.25688552856445, ...],
      ...
    ]
  ]
>defmodule GPTModel do
  # Tokenizer options shared by encoding and decoding. Keeping them in a single
  # place guarantees both directions always use the same configuration.
  @tokenizer_opts [
    length: nil,
    pad_direction: :left,
    return_token_type_ids: false,
    return_length: true
  ]

  @doc """
  Tokenizes `text` with `tokenizer`.

  The tokenizer is first reconfigured with the module's shared options (no fixed
  length, left padding, no token type ids, length returned). Returns the
  tokenizer output as produced by `Bumblebee.apply_tokenizer/2`.
  """
  def text_to_token_ids(tokenizer, text) do
    tokenizer
    |> Bumblebee.configure(@tokenizer_opts)
    |> Bumblebee.apply_tokenizer(text)
  end

  @doc """
  Decodes `token_ids` back into text with `tokenizer`.

  The tokenizer is reconfigured with the same shared options used for encoding.
  The decoded result is treated as a collection and its first element is
  returned (presumably `token_ids` is a batch with one sequence, as passed by
  `generate_text/7` - confirm for other callers).
  """
  def token_ids_to_text(tokenizer, token_ids) do
    tokenizer
    |> Bumblebee.configure(@tokenizer_opts)
    |> Bumblebee.Tokenizer.decode(token_ids)
    |> Enum.at(0)
  end
  @doc """
  Generates `max_new_token` additional tokens after `text` and returns the
  decoded result (prompt plus continuation).

  ## Parameters

    * `model` - Axon model whose prediction is a logits tensor
    * `params` - model parameters passed to the built prediction function
    * `tokenizer` - tokenizer used both to encode `text` and decode the result
    * `text` - prompt to continue
    * `max_new_token` - number of tokens to append; `0` appends nothing
    * `k` - forwarded to `top_k/2`; `0` (default) disables top-k filtering
    * `temperature` - forwarded to `softmax_with_temperature/2`; a positive
      value scales the logits before sampling (default `1`)

  The model is run once per generated token on the full sequence so far.
  """
  def generate_text(model, params, tokenizer, text, max_new_token, k \\ 0, temperature \\ 1) do
    {_init_fn, predict_fn} = Axon.build(model)
    initial_input = text_to_token_ids(tokenizer, text)

    # The explicit `//1` step makes `max_new_token == 0` a no-op. A bare `1..0`
    # is a descending range and would run the loop body twice.
    %{"input_ids" => generated_ids} =
      for _step <- 1..max_new_token//1, reduce: initial_input do
        %{"input_ids" => ids, "attention_mask" => mask, "length" => len} = model_input ->
          logits = predict_fn.(params, model_input)

          # Only the logits of the last position decide the next token.
          next_token =
            logits[[.., -1]]
            |> top_k(k)
            |> softmax_with_temperature(temperature)
            |> Nx.new_axis(0)

          # Extend every tokenizer field so the next step sees a consistent input.
          %{
            "input_ids" => Nx.concatenate([ids, next_token], axis: 1),
            "attention_mask" => Nx.concatenate([mask, Nx.tensor([[1]])], axis: 1),
            "length" => Nx.add(len, 1)
          }
      end

    token_ids_to_text(tokenizer, generated_ids)
  end
  # Draws `num_samples` indices from the distribution described by
  # `probabilities` by comparing uniform random draws against the cumulative sum.
  #
  # The random key is seeded from `:rand.uniform(max_random_number)`, so at most
  # `max_random_number` distinct seeds are possible.
  # Returns a list of integer indices, one per sample.
  defp multinomial(probabilities, num_samples, max_random_number \\ 1000) do
    seed = :rand.uniform(max_random_number)
    key = Nx.Random.key(seed)
    {random_values, _new_key} = Nx.Random.uniform(key, shape: {num_samples})

    # Materialise the cumulative distribution once instead of once per sample.
    cumulative_probs =
      probabilities
      |> Nx.cumulative_sum(axis: -1)
      |> Nx.to_flat_list()

    last_index = length(cumulative_probs) - 1

    Enum.map(Nx.to_flat_list(random_values), fn value ->
      # Floating-point rounding can leave the final cumulative value slightly
      # below 1.0. A draw above it matches nothing, so fall back to the last
      # index instead of returning nil (which would crash `Nx.tensor/1` later).
      Enum.find_index(cumulative_probs, fn prob -> prob >= value end) || last_index
    end)
  end
  # Picks the next token id from `logits`.
  #
  # Non-positive temperature: greedy decoding, i.e. the argmax of the softmax.
  # Zero is included here because it previously matched no clause at all and
  # raised a FunctionClauseError (and it cannot be used as a divisor below).
  defp softmax_with_temperature(logits, temperature) when temperature <= 0 do
    logits
    |> Axon.Layers.softmax(axis: -1)
    |> Nx.argmax(axis: -1)
  end

  # Positive temperature: scale the logits, apply softmax and draw one sample
  # from the resulting distribution.
  defp softmax_with_temperature(logits, temperature) when temperature > 0 do
    scaled_logits = Nx.divide(logits, temperature)

    scaled_logits
    |> Axon.Layers.softmax(axis: -1)
    |> multinomial(1)
    |> Nx.tensor()
  end
  # Top-k filtering: keeps the `k` largest logits along the last axis and
  # replaces every other entry with negative infinity. `k == 0` disables the
  # filter and returns `logits` untouched.
  defp top_k(logits, k) when k == 0, do: logits

  defp top_k(logits, k) do
    {top_logits, _top_pos} = Nx.top_k(logits, k: k)

    # Smallest retained logit of each row. The reduced axis is kept so the
    # threshold broadcasts against `logits`, which makes the cut-off per row
    # rather than global across a batch (identical for a single row).
    threshold = Nx.reduce_min(top_logits, axes: [-1], keep_axes: true)

    neg_inf_tensor = Nx.broadcast(Nx.Constants.neg_infinity(), logits.shape)
    Nx.select(Nx.less(logits, threshold), neg_inf_tensor, logits)
  end
end{:module, GPTModel, <<70, 79, 82, 49, 0, 0, 24, ...>>, {:top_k, 2}}GPTModel.generate_text(gpt2_model, params, tokenizer, "I want to", 15)"I want to generate snow kinds of unscientific updates. This might seem like a crazy"6.1 Different categories of fine-tuning
The most common ways to fine-tune language models are instruction fine-tuning and classification fine-tuning. Instruction fine-tuning involves training a language model on a set of tasks using specific instructions to improve its ability to understand and execute tasks described in natural language prompts.
In classification fine-tuning, the model is trained to recognize a specific set of class labels, such as “spam” and “not spam.”
The key point is that a classification fine-tuned model is restricted to predicting classes it has encountered during its training. On the other hand, it is easier to develop a specialized model than a generalist model that works well across various tasks.
Choosing the right approach
Instruction fine-tuning improves a model’s ability to understand and generate responses based on specific user instructions. Instruction fine-tuning is best suited for models that need to handle a variety of tasks based on complex user instructions, improving flexibility and interaction quality. Classification fine-tuning is ideal for projects requiring precise categorization of data into predefined classes, such as sentiment analysis or spam detection.
While instruction fine-tuning is more versatile, it demands larger datasets and greater computational resources to develop models proficient in various tasks. In contrast, classification fine-tuning requires less data and compute power, but its use is confined to the specific classes on which the model has been trained.
6.2 Preparing the dataset
require Explorer.DataFrame, as: DF
File.cd!(__DIR__)
:ok{:ok, data} = File.read("sms+spam+collection/SMSSpamCollection")
String.length(data) |> dbg
data = String.replace(data, "\"", "\\\"")
String.length(data) |> dbg
477203477550477550original_df =
  data
  |> DF.load_csv!(delimiter: "\t", header: false, eol_delimiter: "\n") 
  |> DF.rename(["labels", "text"])#Explorer.DataFrame<
  Polars[5574 x 2]
  labels string ["ham", "ham", "spam", "ham", "ham", ...]
  text string ["Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
   "Ok lar... Joking wif u oni...",
   "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
   "U dun say so early hor... U c already then say...",
   "Nah I don't think he goes to usf, he lives around here though", ...]
>df = DF.distinct(original_df)
#df = original_df#Explorer.DataFrame<
  Polars[5171 x 2]
  labels string ["ham", "ham", "spam", "ham", "ham", ...]
  text string ["Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...",
   "Ok lar... Joking wif u oni...",
   "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
   "U dun say so early hor... U c already then say...",
   "Nah I don't think he goes to usf, he lives around here though", ...]
>frec = Explorer.Series.frequencies( df["labels"])#Explorer.DataFrame<
  Polars[2 x 2]
  values string ["ham", "spam"]
  counts integer [4518, 653]
>For simplicity, and because we prefer a small dataset (which will facilitate faster fine-tuning of the LLM), we avoid an imbalanced dataset by undersampling it so that every label has the same number of examples.
# Split the rows by label, then size the undersample from the spam rows
# themselves instead of relying on "spam" being the second row of `frec`.
ham_df = DF.filter(df, labels == "ham")
spam_df = DF.filter(df, labels == "spam")
num_spam = DF.n_rows(spam_df)
#ham_s = Explorer.Series.sample(ham_df, 10)
ham_df = DF.sample(ham_df, num_spam, seed: 103)#Explorer.DataFrame<
  Polars[653 x 2]
  labels string ["ham", "ham", "ham", "ham", "ham", ...]
  text string ["Nothing. Can...", "But i have to. I like to have love and arrange.",
   "Goodmorning, today i am late for 1hr.",
   "Die... I accidentally deleted e msg i suppose 2 put in e sim archive. Haiz... I so sad...",
   "I remain unconvinced that this isn't an elaborate test of my willpower", ...]
>df = DF.concat_rows([ham_df, spam_df]) |> DF.shuffle(seed: 103)#Explorer.DataFrame<
  Polars[1306 x 2]
  labels string ["ham", "ham", "spam", "ham", "spam", ...]
  text string ["It means u could not keep ur words.", "Got it..mail panren paru..",
   "URGENT! Your mobile No *********** WON a £2,000 Bonus Caller Prize on 02/06/03! This is the 2nd attempt to reach YOU! Call 09066362220 ASAP! BOX97N7QP, 150ppm",
   "Will you be here for food",
   "U are subscribed to the best Mobile Content Service in the UK for £3 per ten days until you send STOP to 83435. Helpline 08706091795.",
   ...]
>frec = Explorer.Series.frequencies( df["labels"])#Explorer.DataFrame<
  Polars[2 x 2]
  values string ["ham", "spam"]
  counts integer [653, 653]
>dataset = DF.mutate(df, labels: if(labels == "ham", do: 0.0, else: 1.0))#Explorer.DataFrame<
  Polars[1306 x 2]
  labels f64 [0.0, 0.0, 1.0, 0.0, 1.0, ...]
  text string ["It means u could not keep ur words.", "Got it..mail panren paru..",
   "URGENT! Your mobile No *********** WON a £2,000 Bonus Caller Prize on 02/06/03! This is the 2nd attempt to reach YOU! Call 09066362220 ASAP! BOX97N7QP, 150ppm",
   "Will you be here for food",
   "U are subscribed to the best Mobile Content Service in the UK for £3 per ten days until you send STOP to 83435. Helpline 08706091795.",
   ...]
>frec = Explorer.Series.frequencies(dataset["labels"])#Explorer.DataFrame<
  Polars[2 x 2]
  values f64 [0.0, 1.0]
  counts integer [653, 653]
>total = Explorer.Series.n_distinct(dataset["text"])
# 70% of the rows go to training and 10% to validation (both rounded to the
# nearest integer); the test split receives whatever remains.
train_size = round(0.7 * total)
val_size = round(0.1 * total)
test_size = total - (train_size + val_size)
# Generate a list of random indexes
indexes = Enum.shuffle(0..(total - 1))[556, 509, 879, 443, 899, 1263, 226, 996, 997, 561, 249, 1159, 508, 55, 950, 286, 438, 366, 328,
 743, 468, 1002, 691, 195, 870, 142, 777, 951, 701, 148, 1120, 67, 847, 307, 1079, 619, 547, 831,
 517, 227, 1032, 243, 803, 54, 490, 263, 646, 979, 1111, 346, ...]# Split the indexes according to the size of each set
train_indexes = Enum.slice(indexes, 0, train_size)
val_indexes = Enum.slice(indexes, train_size, val_size)
test_indexes = Enum.slice(indexes, train_size + val_size, test_size)[324, 641, 108, 954, 1225, 1026, 379, 40, 306, 851, 629, 74, 537, 577, 502, 144, 1247, 539, 50,
 1230, 1269, 309, 549, 795, 533, 412, 512, 311, 928, 111, 189, 1020, 210, 1118, 105, 277, 118, 542,
 157, 200, 1074, 1216, 434, 13, 272, 409, 611, 752, 1023, 36, ...]# Create the subsets using `Explorer.DataFrame.slice/2`
train_df = Explorer.DataFrame.slice(dataset, train_indexes) |> IO.inspect()
val_df = Explorer.DataFrame.slice(dataset, val_indexes) |> IO.inspect()
test_df = Explorer.DataFrame.slice(dataset, test_indexes) |> IO.inspect()#Explorer.DataFrame<
  Polars[914 x 2]
  labels f64 [0.0, 1.0, 1.0, 1.0, 1.0, ...]
  text string ["I've been trying to reach him without success",
   "SMS AUCTION - A BRAND NEW Nokia 7250 is up 4 auction today! Auction is FREE 2 join & take part! Txt NOKIA to 86021 now!",
   "Thanks 4 your continued support Your question this week will enter u in2 our draw 4 £100 cash. Name the NEW US President? txt ans to 80082",
   "Dear Matthew please call 09063440451 from a landline, your complimentary 4*Lux Tenerife holiday or £1000 CASH await collection. ppm150 SAE T&Cs Box334 SK38XH.",
   "SMS. ac JSco: Energy is high, but u may not know where 2channel it. 2day ur leadership skills r strong. Psychic? Reply ANS w/question. End? Reply END JSCO",
   ...]
>
#Explorer.DataFrame<
  Polars[131 x 2]
  labels f64 [0.0, 1.0, 1.0, 1.0, 0.0, ...]
  text string ["Not to worry. I'm sure you'll get it.",
   "Spook up your mob with a Halloween collection of a logo & pic message plus a free eerie tone, txt CARD SPOOK to 8007 zed 08701417012150p per logo/pic",
   "SMS SERVICES For your inclusive text credits pls gotto www.comuk.net login 3qxj9 unsubscribe with STOP no extra charge help 08702840625 comuk.220cm2 9AE",
   "Get a brand new mobile phone by being an agent of The Mob! Plus loads more goodies! For more info just text MAT to 87021.",
   "Chinatown got porridge, claypot rice, yam cake, fishhead beehoon... Either we eat cheap den go cafe n tok or go nydc or somethin...",
   ...]
>
#Explorer.DataFrame<
  Polars[261 x 2]
  labels f64 [1.0, 1.0, 1.0, 1.0, 0.0, ...]
  text string ["Thanks for your Ringtone Order, Reference T91. You will be charged GBP 4 per week. You can unsubscribe at anytime by calling customer services on 09057039994",
   "YOU VE WON! Your 4* Costa Del Sol Holiday or £5000 await collection. Call 09050090044 Now toClaim. SAE, TC s, POBox334, Stockport, SK38xh, Cost£1.50/pm, Max10mins",
   "FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end",
   "Hungry gay guys feeling hungry and up 4 it, now. Call 08718730555 just 10p/min. To stop texts call 08712460324 (10p/min)",
   "Wish i were with you now!", ...]
>#Explorer.DataFrame<
  Polars[261 x 2]
  labels f64 [1.0, 1.0, 1.0, 1.0, 0.0, ...]
  text string ["Thanks for your Ringtone Order, Reference T91. You will be charged GBP 4 per week. You can unsubscribe at anytime by calling customer services on 09057039994",
   "YOU VE WON! Your 4* Costa Del Sol Holiday or £5000 await collection. Call 09050090044 Now toClaim. SAE, TC s, POBox334, Stockport, SK38xh, Cost£1.50/pm, Max10mins",
   "FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end",
   "Hungry gay guys feeling hungry and up 4 it, now. Call 08718730555 just 10p/min. To stop texts call 08712460324 (10p/min)",
   "Wish i were with you now!", ...]
>{:ok, train_csv} = DF.dump_csv(train_df)
{:ok, val_csv} = DF.dump_csv(val_df)
{:ok, test_csv} = DF.dump_csv(test_df){:ok,
 "labels,text\n1.0,\"Thanks for your Ringtone Order, Reference T91. You will be charged GBP 4 per week. You can unsubscribe at anytime by calling customer services on 09057039994\"\n1.0,\"YOU VE WON! Your 4* Costa Del Sol Holiday or £5000 await collection. Call 09050090044 Now toClaim. SAE, TC s, POBox334, Stockport, SK38xh, Cost£1.50/pm, Max10mins\"\n1.0,\"FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end\"\n1.0,\"Hungry gay guys feeling hungry and up 4 it, now. Call 08718730555 just 10p/min. To stop texts call 08712460324 (10p/min)\"\n0.0,Wish i were with you now!\n1.0,Filthy stories and GIRLS waiting for your\n1.0,\"Update_Now - Xmas Offer! Latest Motorola, SonyEricsson & Nokia & FREE Bluetooth! Double Mins & 1000 Txt on Orange. Call MobileUpd8 on 08000839402 or call2optout/F4Q=\"\n1.0,\"For ur chance to win £250 cash every wk TXT: PLAY to 83370. T's&C's www.music-trivia.net custcare 08715705022, 1x150p/wk.\"\n1.0,For sale - arsenal dartboard. Good condition but no doubles or trebles!\n1.0,\"You are being contacted by our Dating Service by someone you know! To find out who it is, call from your mobile or landline 09064017305 PoBox75LDNS7 \"\n1.0,goldviking (29/M) is inviting you to be his friend. Reply YES-762 or NO-762 See him: www.SMS.ac/u/goldviking STOP? Send STOP FRND to 62468\n0.0,Yar he quite clever but aft many guesses lor. He got ask me 2 bring but i thk darren not so willing 2 go. Aiya they thk leona still not attach wat.\n0.0,Is it your yahoo boys that bring in the perf? Or legal.\n1.0,You have WON a guaranteed £1000 cash or a £2000 prize. To claim yr prize call our customer service representative on 08714712412 between 10am-7pm Cost 10p\n1.0,Reply to win £100 weekly! What professional sport does Tiger Woods play? Send STOP to 87239 to end service\n1.0,Reminder: You have not downloaded the content you have already paid for. Goto http://doit. 
mymoby. tv/ to collect your content.\n0.0,What about this one then.\n0.0,I am taking half day leave bec i am not well\n1.0,74355 XMAS iscoming & ur awarded either £500 CD gift vouchers & free entry 2 r £100 weekly draw txt MUSIC to 87066 TnC\n0.0,I think that tantrum's finished so yeah I'll be by at some point\n1.0,\"Customer service announcement. We recently tried to make a delivery to you but were unable to do so, please call 07099833605 to re-schedule. Ref:9280114\"\n1.0,\"0A$NETWORKS allow companies to bill for SMS, so they are responsible for their \\\"\"suppliers\\\"\", just as a shop has to give a guarantee on what they sell. B. G.\"\n0.0,Just chill for another 6hrs. If you could sleep the pain is not a surgical emergency so see how it unfolds. Okay\n1.0,We tried to contact you re your reply to our offer of a Video Phone 750 anytime any network mins Half Price Line Rental Camcorder Reply or call 08000930705\n0.0,Yes :)it completely in out of form:)clark also utter waste.\n0.0,\"Evening * v good if somewhat event laden. Will fill you in, don't you worry … Head * ok but throat * wrecked. See you at six then!\"\n0.0,\"Just gettin a bit arty with my collages at the mo, well tryin 2 ne way! Got a roast in a min lovely i shall enjoy that!\"\n0.0,That seems unnecessarily affectionate\n0.0,Mm feeling sleepy. today itself i shall get that dear\n1.0,I am hot n horny and willing I live local to you - text a reply to hear strt back from me 150p per msg Netcollex LtdHelpDesk: 02085076972 reply Stop to end\n0.0,I can take you at like noon\n1.0,FREE for 1st week! No1 Nokia tone 4 ur mobile every week just txt NOKIA to 8077 Get txting and tell ur mates. www.getzed.co.uk POBox 36504 W45WQ 16+ norm150p/tone\n0.0,\"Yeah, give me a call if you've got a minute\"\n0.0,How do friends help us in problems? 
They give the most stupid suggestion that Lands us into another problem and helps us forgt the previous problem\n1.0,\"Download as many ringtones as u like no restrictions, 1000s 2 choose. U can even send 2 yr buddys. Txt Sir to 80082 £3 \"\n0.0,As in i want custom officer discount oh.\n0.0,U dun say so early hor... U c already then say...\n0.0,Ok." <> ...}File.write!("sms+spam+collection/train.csv", train_csv)
File.write!("sms+spam+collection/val.csv", val_csv)
File.write!("sms+spam+collection/test.csv", test_csv):ok6.3 Creating data loaders
Previously, we utilized a sliding window technique to generate uniformly sized text chunks, which we then grouped into batches for more efficient model training. Each chunk functioned as an individual training instance.
We are now working with a spam dataset that contains text messages of varying lengths. To batch these messages as we did with the text chunks, we have two primary options:
- Truncate all messages to the length of the shortest message in the dataset or batch. This option is computationally cheaper, but it may result in significant information loss if shorter messages are much smaller than the average or longest messages, potentially reducing model performance.
- Pad all messages to the length of the longest message in the dataset or batch.
To implement batching, where all messages are padded to the length of the longest message in the dataset, we add padding tokens to all shorter messages. For this purpose, we use “<|endoftext|>” as a padding token.
Instead of appending the string “<|endoftext|>” to each of the text messages directly, we can add the token ID corresponding to “<|endoftext|>”
Bumblebee.Tokenizer.all_special_tokens(tokenizer)["<|endoftext|>", "<|endoftext|>", "<|endoftext|>"]Bumblebee.Tokenizer.id_to_token(tokenizer, 50256)"<|endoftext|>"Bumblebee.apply_tokenizer(tokenizer, "<|endoftext|>")%{
  "attention_mask" => #Nx.Tensor<
    u32[1][1]
    EXLA.Backend
    [
      [1]
    ]
  >,
  "input_ids" => #Nx.Tensor<
    u32[1][1]
    EXLA.Backend
    [
      [50256]
    ]
  >,
  "length" => #Nx.Tensor<
    s32[1]
    EXLA.Backend
    [1]
  >
}# Identify the longest sequence
# Number of token ids the tokenizer produces for a single message.
text_to_length = fn text ->
  encoded = Bumblebee.apply_tokenizer(tokenizer, text)
  Nx.size(encoded["input_ids"])
end

# Token count of every message in the dataset.
text_length_series = Explorer.Series.transform(dataset["text"], text_to_length)
Explorer.Series.argmax(text_length_series) |> dbg
max_length = Explorer.Series.max(text_length_series)962204text_to_length.(dataset["text"][962])204IO.inspect(dataset["text"][962])
Bumblebee.apply_tokenizer(tokenizer, dataset["text"][962])["input_ids"]"The last thing i ever wanted to do was hurt you. And i didn't think it would have. You'd laugh, be embarassed, delete the tag and keep going. But as far as i knew, it wasn't even up. The fact that you even felt like i would do it to hurt you shows you really don't know me at all. It was messy wednesday, but it wasn't bad. The problem i have with it is you HAVE the time to clean it, but you choose not to. You skype, you take pictures, you sleep, you want to go out. I don't mind a few things here and there, but when you don't make the bed, when you throw laundry on top of it, when i can't have a friend in the house because i'm embarassed that there's underwear and bras strewn on the bed, pillows on the floor, that's something else. You used to be good about at least making the bed."#Nx.Tensor<
  u32[1][204]
  EXLA.Backend
  [
    [464, 938, 1517, 1312, 1683, 2227, 284, 466, 373, 5938, 345, 13, 843, 1312, 1422, 470, 892, 340, 561, 423, 13, 921, 1549, 6487, 11, 307, 4072, 283, 21390, 11, 12233, 262, 7621, 290, 1394, 1016, 13, 887, 355, 1290, 355, 1312, 2993, 11, 340, 2492, 470, 772, 510, 13, ...]
  ]
>Nx.pad(Nx.tensor([[1, 2, 3]]), 0, [{0, 0, 0}, {0, 5, 0}])#Nx.Tensor<
  s32[1][8]
  EXLA.Backend
  [
    [1, 2, 3, 0, 0, 0, 0, 0]
  ]
>defmodule MyGPT.Classifier.Dataset do
  @doc """
  Builds a stream of `{inputs, labels}` batches from a CSV file.

  The CSV is expected to have a `text` and a `labels` column. Every text is
  tokenized and padded with `pad_token_id` to a common length, and every label
  is converted to a signed 32-bit integer tensor with one element.

  ## Parameters

    * `csv_file` - path of the CSV file to read
    * `tokenizer` - tokenizer applied to each `text` value
    * `max_length` - target sequence length; when `nil` the longest tokenized
      text of the file is used. In both cases the final padded length is this
      value plus one (see `compute_max_length/3`)
    * `pad_token_id` - token id used for padding (default `50256`)
    * `batch_size` - number of samples per batch (default `2`)

  ## Returns

  A stream of `{%{"input_ids" => input_batch}, label_batch}` tuples. Only the
  token ids are kept; the tokenizer's attention mask is not part of the batch.
  """
  def build(csv_file, tokenizer, max_length \\ nil, pad_token_id \\ 50256, batch_size \\ 2) do
    {:ok, data} = File.read(csv_file)
    df = DF.load_csv!(data)
    max_length = compute_max_length(df, tokenizer, max_length)

    # Map + unzip keeps the row order without appending to a list on every
    # iteration, which would be quadratic in the number of rows.
    {inputs, labels} =
      df
      |> DF.to_rows_stream()
      |> Enum.map(fn row ->
        input =
          tokenizer
          |> Bumblebee.apply_tokenizer(row["text"])
          |> pad_sample(pad_token_id, max_length)

        label =
          row["labels"]
          |> Nx.tensor()
          |> Nx.as_type({:s, 32})
          |> Nx.new_axis(0)

        {input, label}
      end)
      |> Enum.unzip()

    nx_inputs =
      inputs
      |> Nx.stack()
      |> Nx.to_batched(batch_size)
      |> Stream.map(&(%{"input_ids" => &1}))

    nx_labels =
      labels
      |> Nx.stack()
      |> Nx.to_batched(batch_size)

    Stream.zip(nx_inputs, nx_labels)
  end
  @doc """
  Pads the `"input_ids"` of a tokenizer output on the right with `pad_token_id`
  and returns them as a flat tensor of `max_length` elements.

  The `"input_ids"` tensor is padded along its second axis only, so it is
  expected to have two axes. The amount of padding is `max_length` minus the
  number of token ids; when that difference is negative it is passed to
  `Nx.pad/3` unchanged (NOTE(review): Nx is expected to trim in that case -
  confirm against the Nx documentation).
  """
  def pad_sample(tokenizer_output, pad_token_id, max_length) do
    token_ids = tokenizer_output["input_ids"]
    right_padding = max_length - Nx.size(token_ids)

    token_ids
    |> Nx.pad(pad_token_id, [{0, 0, 0}, {0, right_padding, 0}])
    |> Nx.reshape({max_length})
  end
  # Resolves the length every sample is padded to: the given `max_length`, or
  # the longest tokenized `text` of `df` when `max_length` is nil, plus one in
  # both cases.
  defp compute_max_length(df, tokenizer, max_length) do
    base_length =
      case max_length do
        nil ->
          df["text"]
          |> Explorer.Series.transform(fn text -> text_to_length(tokenizer, text) end)
          |> Explorer.Series.max()

        given ->
          given
      end

    base_length + 1
  end
  # Number of token ids the tokenizer produces for `text`.
  defp text_to_length(tokenizer, text) do
    encoded = Bumblebee.apply_tokenizer(tokenizer, text)
    Nx.size(Map.fetch!(encoded, "input_ids"))
  end
end
alias MyGPT.ClassifierMyGPT.Classifiertraining_dataset = Classifier.Dataset.build("sms+spam+collection/train.csv", tokenizer, 204, 50256, 8)
validation_dataset = Classifier.Dataset.build("sms+spam+collection/val.csv", tokenizer, 204, 50256, 8)
test_dataset = Classifier.Dataset.build("sms+spam+collection/test.csv", tokenizer, 204, 50256, 8)#Function<73.53678557/2 in Stream.zip_with/2>Enum.at(validation_dataset, 5){%{
   "input_ids" => #Nx.Tensor<
     u32[8][205]
     EXLA.Backend
     [
       [40, 760, 257, 1178, 661, 314, 460, 2277, 510, 290, 5089, 284, 262, 3763, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, ...],
       ...
     ]
   >
 },
 #Nx.Tensor<
   s32[8][1]
   EXLA.Backend
   [
     [0],
     [1],
     [0],
     [1],
     [1],
     [0],
     [0],
     [0]
   ]
 >}Enum.count(training_dataset) |> IO.inspect(label: "training batches")
Enum.count(validation_dataset) |> IO.inspect(label: "validation batches")
Enum.count(test_dataset) |> IO.inspect(label: "test batches")training batches: 115
validation batches: 17
test batches: 33336.4 Initializing a model with pretrained weights
We must prepare the model for classification fine-tuning to identify spam messages. To begin the model preparation process, we employ the same configurations we used to pretrain on unlabeled data.
input_text = "Every effort moves you forward"
GPTModel.generate_text(gpt2_model, params, tokenizer, input_text, 15) |> IO.putsEvery effort moves you forward and you keep progressing. A wedding is something that's tough and you're:ok# Check if the model can classify by default:
input_text = """
Is the following text 'spam'? Answer with 'yes' or 'no':
'You are a winner you have been specially
selected to receive $1000 cash or a $2000 award.'
"""
GPTModel.generate_text(gpt2_model, params, tokenizer, input_text, 15) |> IO.putsIs the following text 'spam'? Answer with 'yes' or 'no':
'You are a winner you have been specially
selected to receive $1000 cash or a $2000 award.'
You've got the glory of winning this prize. You now feel as though:okThe model is struggling to follow instructions. This result is expected, as it has only undergone pretraining and lacks instruction fine-tuning. So, let’s prepare the model for classification fine-tuning.
6.5 Adding a classification head
We must modify the pretrained LLM to prepare it for classification fine-tuning. To do so, we replace the original output layer, which maps the hidden representation to a vocabulary of 50,257, with a smaller output layer that maps to two classes: 0 (“not spam”) and 1 (“spam”).
Here it was discussed how to remove a layer from an Axon model.
# Pop the model's final node. The inspected node is the `:container` that
# groups the model outputs; `model_tail` is what remains and exposes `logits`.
{node, model_tail} = model |> Axon.pop_node() 
#test_model = Axon.nx(test_model, & &1.logits)
dbg(node)
# Build a prediction function from the `logits` entry alone.
{_init_fn, predict_fn} = Axon.build(model_tail.logits)
model_tail.logits |> dbg
result = predict_fn.(params, input)%Axon.Node{
  id: 870,
  name: #Function<194.98717769/2 in Axon.name/2>,
  mode: :both,
  parent: [867, 863, 868, 865, 869],
  parameters: [],
  args: [:layer, :layer, :layer, :layer, :layer],
  op: #Function<333.98717769/6 in Axon.restructure/2>,
  policy: #Axon.MixedPrecision.Policy<>,
  hooks: [],
  opts: [],
  global_options: [],
  op_name: :container,
  meta: nil,
  stacktrace: [
    {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 346]},
    {Bumblebee, :build_model, 2, [file: ~c"lib/bumblebee.ex", line: 366]},
    {Bumblebee, :load_model, 2, [file: ~c"lib/bumblebee.ex", line: 592]},
    {:elixir, :eval_external_handler, 3, [file: ~c"src/elixir.erl", line: 405]},
    {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 750]},
    {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 494]}
  ]
}#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "optional_185"
  nodes: 704
>#Nx.Tensor<
  f32[1][3][50257]
  EXLA.Backend
  [
    [
      [-39.308448791503906, -39.010066986083984, -41.837467193603516, -41.781246185302734, -40.84248352050781, -40.89142990112305, -38.62623596191406, -40.154056549072266, -38.097896575927734, -41.04249954223633, -40.9429931640625, -36.262168884277344, -37.39033889770508, -36.03800964355469, -38.52249526977539, -40.54604721069336, -39.718971252441406, -39.7431640625, -40.27290344238281, -40.314857482910156, -40.54868698120117, -41.00197219848633, -40.9098014831543, -40.914119720458984, -41.297733306884766, -37.69235610961914, -39.106632232666016, -41.460182189941406, -40.526241302490234, -40.43655014038086, -38.97370147705078, -41.32615661621094, -39.90999984741211, -40.565555572509766, -40.7227897644043, -40.8016471862793, -40.875083923339844, -40.86553955078125, -40.39710998535156, -40.221649169921875, -38.78817367553711, -40.58393096923828, -40.43303298950195, -40.767242431640625, -40.72999572753906, -40.78556442260742, -40.461753845214844, -41.084720611572266, -41.600372314453125, -41.25688552856445, ...],
      ...
    ]
  ]
>classifier_model = 
  # Start from the graph that produces the logits.
  model_tail.logits 
  # Pop the last node twice; after each pop the first remaining parent graph is
  # followed. NOTE(review): presumably this strips the language-modeling head
  # (see the "language_modeling_head.output" warning below) - confirm which
  # nodes are actually removed.
  |> Axon.pop_node() 
  |> elem(1)
  |> Enum.at(0)
  |> Axon.pop_node()
  |> elem(1)
  |> Enum.at(0)
  #|> Axon.freeze()
  #|> Axon.unfreeze(up: 90)
  # New output layer with two units, one per class.
  |> Axon.dense(2)
  |> dbg
{init_fn, predict_fn} = Axon.build(classifier_model)
#classifier_model |> dbg
# Initialise the classifier parameters, passing the pretrained GPT-2
# parameters as the initial state.
class_params = init_fn.(input, params)
result = predict_fn.(class_params, input)
20:44:37.749 [warning] found unexpected key in the initial parameters map: "language_modeling_head.output"
#Nx.Tensor<
  f32[1][3][2]
  EXLA.Backend
  [
    [
      [-0.27341315150260925, 4.142691612243652],
      [3.1811649799346924, 0.4369714558124542],
      [4.146763324737549, 4.505625247955322]
    ]
  ]
>{input, class_params, params}{%{
   "attention_mask" => #Nx.Tensor<
     u32[1][3]
     EXLA.Backend
     [
       [1, 1, 1]
     ]
   >,
   "input_ids" => #Nx.Tensor<
     u32[1][3]
     EXLA.Backend
     [
       [40, 765, 284]
     ]
   >,
   "length" => #Nx.Tensor<
     s32[1]
     EXLA.Backend
     [3]
   >
 },
 #Axon.ModelState<
   Parameters: 124441346 (497.77 MB)
   Trainable Parameters: 124441346 (497.77 MB)
   Trainable State: 0, (0 B)
 >,
 #Axon.ModelState<
   Parameters: 163037184 (652.15 MB)
   Trainable Parameters: 163037184 (652.15 MB)
   Trainable State: 0, (0 B)
 >}Fine-tuning selected layers vs. all layers
Since we start with a pretrained model, it’s not necessary to fine-tune all model layers. In neural network-based language models, the lower layers generally capture basic language structures and semantics applicable across a wide range of tasks and datasets. So, fine-tuning only the last layers (i.e., layers near the output), which are more specific to nuanced linguistic patterns and task-specific features, is often sufficient to adapt the model to new tasks. A nice side effect is that it is computationally more efficient to fine-tune only a small number of layers.
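Training only the newly added output layer would be enough in principle, but fine-tuning additional layers can noticeably improve the predictive performance of the model. For that reason, the last transformer block and the final LayerNorm module, which connects this block to the output layer, are also meant to be trainable. Note that in the cell above the `Axon.freeze` and `Axon.unfreeze` steps are commented out, so nothing is frozen yet: every parameter in `class_params` is still trainable, which is why the frozen-parameter map below is empty.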
frozen_params = Axon.ModelState.frozen_parameters(class_params)%{}trainable_params = Axon.ModelState.trainable_parameters(class_params)%{
  "decoder.blocks.7.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0015067235799506307, 0.011825020425021648, -0.015962932258844376, -0.009843949228525162, 0.017166582867503166, -0.0023309593088924885, 0.01716618798673153, -0.01785222440958023, 0.02480374090373516, -0.04544716328382492, -0.04233347252011299, 0.19252289831638336, 0.023084215819835663, 0.010606026276946068, 0.020862935110926628, -0.014164811000227928, 0.04407064616680145, -0.001367615768685937, 0.03417485952377319, 0.0010759946890175343, 0.020503727719187737, -0.018043607473373413, -0.00980320293456316, 0.028919916599988937, -0.003041631542146206, -0.041081931442022324, -0.030980169773101807, 0.0037019671872258186, 0.009434754960238934, -0.004209163598716259, 0.0016339614521712065, 0.02209661342203617, -0.014821838587522507, -0.02030492015182972, 0.03273279219865799, -0.04265918210148811, 0.00601721927523613, 0.00928163155913353, -0.028216741979122162, -0.007809154689311981, -0.03414953500032425, -0.011486138217151165, -0.006398558616638184, -0.014157279394567013, 0.010680632665753365, -0.06726251542568207, -0.0386282317340374, 0.007625897414982319, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.021402638405561447, 0.050967343151569366, 0.04545518010854721, -0.0109089445322752, -0.04193441942334175, -0.023326827213168144, 0.09840112924575806, 0.10209610313177109, 0.05283593386411667, 0.034540046006441116, 0.05829596519470215, -0.1514427661895752, 0.0445757620036602, -0.014143920503556728, 0.009306855499744415, 0.07754021883010864, 0.11058628559112549, -0.11392591893672943, 0.05147108808159828, 0.014121315442025661, 0.0954991802573204, 0.032277483493089676, 0.09881661087274551, 0.13257817924022675, -0.02898217923939228, -0.08612077683210373, -0.10402683913707733, 0.032209958881139755, 0.07464122027158737, -0.04770876094698906, -0.12873615324497223, 0.04697330668568611, -0.04240507632493973, -0.058180730789899826, 9.088746155612171e-4, 0.05110365152359009, 0.20511674880981445, -0.020681431517004967, -0.01085528265684843, -0.1648135483264923, 0.038136307150125504, -0.01178289670497179, -0.005067809019237757, -0.11831668764352798, 0.08509404212236404, -0.008230645209550858, 0.08802532404661179, ...],
        ...
      ]
    >
  },
  "decoder.blocks.0.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.045023269951343536, 0.03226305544376373, -0.08854541927576065, -0.008162474259734154, 0.11771649867296219, -0.2388763129711151, -0.0758744403719902, 0.029564158990979195, 0.04648246616125107, -0.0027773359324783087, -0.0038145091384649277, 0.020842736586928368, 0.009267736226320267, 0.03588847070932388, -0.1409134566783905, 0.06138546019792557, -0.06896162778139114, -0.012951075099408627, 0.06189567223191261, -0.017779534682631493, 0.021359939128160477, -0.05501718819141388, 0.02197490818798542, 0.0011829776922240853, 0.04943571239709854, 0.04192988574504852, 0.07189816981554031, -0.020239252597093582, 0.06509828567504883, 0.10963835567235947, 0.04001740366220474, 0.03582198917865753, 0.06442175805568695, -0.031552474945783615, 0.020920773968100548, -0.05902986228466034, 0.019308188930153847, -0.03758474066853523, 0.04784710705280304, 0.1294996589422226, 0.0656597912311554, 0.16015131771564484, 0.0036071811337023973, 0.013143805786967278, 0.08715710043907166, -0.05141830816864967, 0.10630246251821518, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.10660640895366669, 0.1527840495109558, 0.03310481086373329, 0.023700203746557236, 0.02993733063340187, -0.07244370877742767, -0.55028235912323, 0.10002829134464264, 0.10845957696437836, -0.06486132740974426, -0.050516318529844284, -0.04687585309147835, -0.019445782527327538, -0.031208522617816925, -0.1474572867155075, 0.00985053088515997, 0.14931969344615936, 0.07806558161973953, -9.190309210680425e-4, 0.23059925436973572, -0.021436616778373718, 0.04766186699271202, -0.07279295474290848, 0.06045343354344368, 0.013344594277441502, 0.16240741312503815, 0.11852358281612396, -0.02879641018807888, 0.05925842374563217, 0.12395480275154114, -0.09108705073595047, -0.01235614251345396, 0.03735384717583656, -0.022464079782366753, -0.045593757182359695, -0.27166813611984253, -0.04527653008699417, 0.03668154403567314, 0.08654572069644928, 0.023413583636283875, -0.08427132666110992, -0.03452301397919655, 0.03188594430685043, -0.043258294463157654, -0.05696522071957588, 0.018179383128881454, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.self_attention.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.020700350403785706, -0.26381751894950867, -0.1094379648566246, 0.2875203490257263, 0.2774174213409424, 0.04697851091623306, 0.062320489436388016, 0.025348320603370667, -0.1447840929031372, 0.23507045209407806, -0.22111418843269348, -0.04685695096850395, 0.04695908725261688, -0.08256081491708755, 0.049968790262937546, 0.0957123190164566, 0.008569050580263138, 0.36251145601272583, 0.026079699397087097, 0.14226137101650238, 0.30309706926345825, -0.051855895668268204, -0.041972529143095016, 0.045346442610025406, 0.09475845843553543, 0.03463638201355934, 0.11646518856287003, -0.13341329991817474, 0.10428164899349213, -0.16626955568790436, 0.008327489718794823, 0.11973997205495834, 0.1891246736049652, 0.27689802646636963, 0.07897856086492538, 0.2317073494195938, 3.121290064882487e-4, 0.014473449438810349, -0.22102972865104675, -0.024565676227211952, 0.0030907581094652414, -0.10433603823184967, 0.2853952944278717, 0.022545315325260162, 0.10485214740037918, 0.2518227994441986, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.013124971650540829, -0.06660282611846924, -0.0697135180234909, -0.030710719525814056, 0.042854610830545425, 0.008408921770751476, -0.03381262347102165, 0.06493787467479706, -0.03271748125553131, 0.12311036139726639, 0.1337415724992752, 0.10496947169303894, 0.0691475197672844, 0.015230960212647915, -0.07322822511196136, -0.05089348554611206, 0.09375757724046707, -0.13991476595401764, -0.05577455088496208, 0.04202907159924507, -0.028392035514116287, -0.07780732214450836, -0.05117521435022354, -0.010147273540496826, -0.09198098629713058, -0.15541940927505493, -0.09189142286777496, 0.04543245583772659, 0.05888330563902855, 0.007717895321547985, -0.055146679282188416, -0.10239534080028534, 0.01507049985229969, -0.01887754164636135, -0.06463737785816193, -0.10168686509132385, -0.04603245481848717, -0.03781770542263985, 0.08949298411607742, -0.014968409202992916, 0.01719915308058262, 0.01886896975338459, 0.040570300072431564, 0.052193399518728256, -0.04610079526901245, ...],
        ...
      ]
    >
  },
  "decoder.blocks.6.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.045025456696748734, 0.009542928077280521, 0.016986900940537453, 0.013239625841379166, 0.022165829315781593, -0.02997750975191593, -0.05318791791796684, 0.030291369184851646, 0.021082192659378052, 0.014847509562969208, -1.6169920513675606e-7, 0.0210530087351799, 0.056806933134794235, 0.009754392318427563, -0.06860926002264023, 0.03423593193292618, -0.02938588336110115, 0.015446092002093792, 0.010956122539937496, -0.02942276932299137, 0.03855720907449722, 0.01702086068689823, -0.005053986329585314, 0.017972847446799278, 0.014045481570065022, 0.021662695333361626, 0.02321026474237442, 0.04095667600631714, 0.0021895733661949635, 0.016517585143446922, -0.004257701337337494, -0.013539420440793037, -0.01618128828704357, -0.007974829524755478, -0.004954650532454252, -0.03999199718236923, -0.02632948011159897, -0.01814391277730465, 0.004013486206531525, 0.007407412398606539, 0.001718125189654529, 0.019525639712810516, 0.0014584340387955308, 0.022858886048197746, 5.235132412053645e-4, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.25878557562828064, 0.2482416182756424, 0.27244293689727783, 0.2566904127597809, 0.2704152762889862, 0.2627618908882141, 0.29589709639549255, 0.18097728490829468, 0.2705056071281433, 0.2587742209434509, 0.26464471220970154, 0.2627030909061432, 0.2763660252094269, 0.26602861285209656, 0.2792193293571472, 0.262694776058197, 0.2730119526386261, 0.2646535038948059, 0.24892735481262207, 0.24463066458702087, 0.2625494599342346, 0.2543585002422333, 0.2724483907222748, 0.23977135121822357, 0.2607770562171936, 0.25832539796829224, 0.2607426345348358, 0.2576335072517395, 0.2551948130130768, 0.2608013451099396, 0.2707134485244751, 0.2568177878856659, 0.2546869218349457, 0.24516521394252777, 0.2724584937095642, 0.27050986886024475, 0.22996947169303894, 0.2724611461162567, 0.2426939308643341, 0.24848425388336182, 0.2568362355232239, 0.2662496268749237, 0.2592116594314575, 0.2666124701499939, ...]
    >
  },
  "decoder.blocks.11.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.105723537504673, 0.11879222095012665, 0.006737913005053997, 0.05425174906849861, -0.020953308790922165, 0.00772686256095767, 0.05538606271147728, 0.05174220725893974, 0.02458597719669342, 0.05650372430682182, 0.07260715216398239, 0.015075696632266045, 0.07118107378482819, -0.02622274123132229, -0.004771863576024771, 0.13389313220977783, -0.04042479023337364, -0.09605908393859863, 0.05143333226442337, -0.0777275338768959, -0.038933414965867996, -0.01103443093597889, 0.22768071293830872, -0.051514096558094025, -0.09891602396965027, 0.0014073842903599143, -0.04144133999943733, 0.034522294998168945, -0.009654851630330086, 0.005805726628750563, 0.04466322436928749, -0.047319527715444565, -0.009854177013039589, 0.01743702031672001, 0.1440771073102951, -0.12892374396324158, 0.18401136994361877, -0.017713360488414764, -0.21362674236297607, 0.06416985392570496, 0.07322429120540619, 0.08671480417251587, 0.0764789879322052, 0.057795871049165726, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.06805934756994247, 0.09200220555067062, -0.03102017380297184, -0.20520317554473877, 0.06265048682689667, 0.17043784260749817, 0.037615444511175156, -0.24390925467014313, -0.06584632396697998, -0.36999693512916565, 0.19122055172920227, 0.07945571839809418, 0.008697953075170517, 0.19716542959213257, 0.17863412201404572, -0.01731661520898342, -0.3208216726779938, 0.21022184193134308, 0.080193892121315, 0.07145563513040543, -0.24923168122768402, -0.050284698605537415, -0.2518293559551239, 0.2810163199901581, -0.28573331236839294, 0.12340781837701797, -0.0859503522515297, 0.08418150991201401, -0.03493265062570572, 0.13855530321598053, -0.241315558552742, 0.15032462775707245, -0.25630655884742737, -0.03904435783624649, 0.0643249899148941, 0.26296404004096985, 0.44665995240211487, -0.001308600651100278, 0.16780276596546173, -0.14047899842262268, -0.03252283111214638, -0.3422131836414337, -0.15174603462219238, ...],
        ...
      ]
    >
  },
  "decoder.blocks.3.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.01611318439245224, 1.9620818784460425e-4, -0.033592935651540756, -0.014356587082147598, -0.010978053323924541, -0.005805983208119869, 0.0022555338218808174, 0.007051336579024792, 0.014428694732487202, -0.019858360290527344, -0.058242809027433395, -0.027789494022727013, -0.021741623058915138, 0.02190697006881237, -0.0035076746717095375, 0.002187394769862294, -0.022275205701589584, -0.009000930935144424, 0.0014367415569722652, -0.02442094311118126, -0.02785932831466198, -0.027073200792074203, 0.009673072025179863, -0.01975858025252819, -0.006040024105459452, -0.005454181227833033, -0.012051430530846119, 0.007628345862030983, -0.0019122350495308638, -0.031798556447029114, -0.03979314863681793, 0.0018334127962589264, 5.315942107699811e-4, -0.02355266362428665, 0.019198285415768623, -0.01209992729127407, 0.011508915573358536, 0.033759962767362595, 0.003140085143968463, -0.43169018626213074, -0.004638292361050844, 0.029909860342741013, -0.026687445119023323, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.02292151376605034, 0.0338389091193676, -0.047796934843063354, 0.09596998244524002, 0.045752089470624924, -0.030947083607316017, -0.06484285742044449, -0.021927157416939735, 0.08957719802856445, -0.06526286900043488, 0.048477549105882645, -0.007700455840677023, -0.09442861378192902, -0.13685542345046997, -0.10288318246603012, -0.09903551638126373, 0.17210477590560913, -0.014380788430571556, -0.081391841173172, -0.09592964500188828, -0.20048144459724426, -0.014696544967591763, 0.08353224396705627, 0.10173220932483673, -0.05892166867852211, -0.0928846076130867, 0.07746145129203796, 0.12805286049842834, -0.14344947040081024, 0.12638360261917114, 0.03233393654227257, 0.044042039662599564, -0.0408482551574707, -0.08512827754020691, 0.17795124650001526, -0.18981613218784332, -0.04125617817044258, 0.018192069604992867, 0.19313742220401764, -0.08462966978549957, 0.13086268305778503, 0.18106544017791748, ...],
        ...
      ]
    >
  },
  "decoder.blocks.3.self_attention.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.014248297549784184, -0.240375816822052, -0.018260201439261436, 0.11563233286142349, 0.18061132729053497, 0.04692136496305466, -0.009259623475372791, -0.05987045168876648, -0.10537739843130112, 0.10199987143278122, -0.10119312256574631, 0.054955385625362396, 0.08877909928560257, -0.034204259514808655, -0.030978145077824593, -0.14948132634162903, 0.013923843391239643, 0.22173760831356049, -0.044307757169008255, 0.11177702248096466, 0.25232768058776855, -0.011090553365647793, -0.02736673317849636, -0.05436551198363304, 0.035498980432748795, 0.06144917383790016, 0.10569152981042862, -0.06352245807647705, -0.014585292898118496, -0.08527898788452148, 0.06996647268533707, 0.15424810349941254, 0.2633742094039917, 0.07891955971717834, 0.10157502442598343, 0.15565523505210876, -0.05072539672255516, -0.015324545092880726, -0.13633786141872406, -0.0920839011669159, 0.02864094078540802, -0.1776469349861145, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.06602837890386581, -0.005861711222678423, -0.03592158108949661, -0.17321670055389404, -0.040567655116319656, -0.12914139032363892, -0.039053335785865784, 0.055556997656822205, -0.03793487325310707, -0.040671199560165405, -0.10253544896841049, -0.010256040841341019, -0.046774111688137054, -0.04418528825044632, -0.19765710830688477, 0.03096650168299675, 0.04684501513838768, -0.0827573910355568, 0.06590650230646133, 0.0626329854130745, -0.03907793387770653, -0.08920115977525711, 0.12399065494537354, 0.11888366937637329, 0.04018740355968475, -0.06762197613716125, -0.007350870408117771, 0.0679413229227066, -0.05717582628130913, -0.006823679432272911, -0.037478987127542496, -0.022401340305805206, -0.026354540139436722, -0.16326986253261566, 0.026277117431163788, -0.04454047232866287, -0.19500863552093506, -0.10626175254583359, 0.09581570327281952, 9.76486480794847e-4, 0.03719896450638771, ...],
        ...
      ]
    >
  },
  "decoder.blocks.7.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.003023461438715458, -0.0034503634087741375, -0.05022025108337402, 0.1445682793855667, -0.061383385211229324, -0.18525145947933197, 0.05382299795746803, 0.09621239453554153, 0.22644995152950287, 0.13868243992328644, -0.10454718768596649, -0.0445207841694355, 0.17207269370555878, 0.13456888496875763, -0.08831684291362762, -0.02446838468313217, -0.031328264623880386, 0.12339051067829132, 0.055783338844776154, -0.06627967208623886, 0.12010461091995239, -0.005991296377032995, -0.012475112453103065, -0.0091325668618083, -0.12514269351959229, -0.03106873109936714, -0.07901277393102646, 0.10679206997156143, 0.023926347494125366, 0.19130189716815948, -0.07669995725154877, -3.717075742315501e-4, -0.022119715809822083, -0.12196514755487442, -0.006166365463286638, -0.003968436270952225, 0.17010942101478577, -0.08083401620388031, -0.06381057947874069, 0.012830626219511032, 0.008812617510557175, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.02272615395486355, -0.03426380455493927, 0.10830746591091156, -0.0698532909154892, -0.04444117844104767, -0.1163613572716713, -0.10804831981658936, 0.027915891259908676, -0.06091436743736267, 0.11493780463933945, -0.12441886961460114, -0.12758924067020416, -0.08719192445278168, -0.08544804900884628, 0.01870325393974781, 0.06337859481573105, -0.05752550810575485, 0.08484397083520889, 0.10050169378519058, -8.142094011418521e-4, 0.04208269715309143, -0.0010115094482898712, 0.016743609681725502, -0.14341019093990326, -0.11040765047073364, -0.015847202390432358, 0.06745976209640503, 0.028722982853651047, 0.034336477518081665, -0.02375819906592369, -0.04028434678912163, 0.04014326259493828, -0.0828038826584816, 0.04731905460357666, -0.030171439051628113, 0.007872296497225761, -0.0554569847881794, -0.057010941207408905, -0.03682035207748413, 0.003613386768847704, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [3.5320030292496085e-4, -0.02352035790681839, 0.03579441457986832, -0.00819009356200695, -0.013341247104108334, 0.22570614516735077, 0.014364250004291534, 0.01957814209163189, -0.0029258071444928646, -0.005615453235805035, 0.006173497997224331, 0.03394452482461929, 0.04145542159676552, 0.02271554060280323, 0.027301738038659096, 0.023053526878356934, -0.030637608841061592, 0.0011183557799085975, 0.009655783884227276, -0.0031652075704187155, 0.004611491225659847, -8.690875256434083e-4, -0.0060260300524532795, -0.009716540575027466, 0.014348393306136131, -0.010869566351175308, -0.015734069049358368, 0.02538462169468403, -0.02972058765590191, -5.710391560569406e-4, 9.462513262405992e-4, 0.022110195830464363, 0.004091514740139246, 9.700193768367171e-4, 0.009012877009809017, -0.013626644387841225, 0.38099005818367004, 0.01451069489121437, -0.023601410910487175, 0.018054811283946037, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.05810696259140968, 0.06249383091926575, -0.16588400304317474, 0.09654977917671204, 0.09147700667381287, 0.048064541071653366, -0.09367276728153229, 0.014363005757331848, 0.10144679248332977, 0.0233471617102623, 0.006492226850241423, 0.024173308163881302, -0.08379942923784256, 0.035400889813899994, 0.011924143880605698, 0.06395914405584335, -0.025091668590903282, 0.23154862225055695, 0.04634566977620125, -0.03464064002037048, 0.019783668220043182, -0.09660620987415314, -0.08186054974794388, -0.18600694835186005, -0.11908779293298721, 0.21912004053592682, -0.09747152775526047, -0.010526714846491814, -0.015264575369656086, -0.12838749587535858, 0.061404332518577576, -0.07993568480014801, -0.016672959551215172, 0.007367889396846294, 0.06115908920764923, -0.00501360883936286, -0.009648752398788929, -0.11571478843688965, 0.022876223549246788, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0654396265745163, 0.10812345892190933, 0.04286130890250206, -0.2760961353778839, 0.08625604957342148, -0.2286910116672516, -0.04461458697915077, 7.275498937815428e-4, 0.10620112717151642, 0.10470309853553772, -0.08921228349208832, 0.07638224959373474, 0.1359548419713974, 0.011987983249127865, -0.13395175337791443, -0.012165425345301628, -0.02094808965921402, -0.24327068030834198, 0.19702652096748352, -0.11469811946153641, 0.06039329245686531, -0.0885312408208847, 0.03518642485141754, -0.032149139791727066, 0.051021404564380646, 0.05498753488063812, 0.08870009332895279, 0.06739293038845062, 0.0710294097661972, 0.13408461213111877, 0.029301894828677177, -0.00930361170321703, 0.025734776630997658, -0.24615506827831268, -0.08097236603498459, -0.06227843463420868, 0.005404133815318346, -0.11257304251194, -0.07373207062482834, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.02398330718278885, -0.20910729467868805, -0.029130421578884125, -0.014159053564071655, 0.12113775312900543, -0.02065293677151203, -0.11999223381280899, -0.269276887178421, 7.631823536939919e-4, 0.04650465399026871, -0.045628879219293594, -0.12281503528356552, 0.02982284128665924, 0.06332483142614365, -0.05124947428703308, 0.004213565960526466, 0.05839945375919342, 0.1387341022491455, 0.039438631385564804, -0.07732643932104111, 0.11901428550481796, -0.0635095089673996, 0.06668639183044434, 0.15541701018810272, -0.04893676936626434, -0.13186554610729218, 0.03706846386194229, 0.08605990558862686, -0.20225566625595093, -0.11350183933973312, 0.06946399807929993, -6.101589533500373e-4, -0.0847146287560463, -0.04157274588942528, 0.0049753207713365555, -0.015184903517365456, -0.16200494766235352, 0.04296805337071419, ...],
        ...
      ]
    >
  },
  "decoder.blocks.3.self_attention.key" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.09291772544384003, -0.04217880591750145, -0.020356424152851105, 0.04034513235092163, -0.013874712400138378, -0.11893210560083389, -0.05924098193645477, -0.027365664020180702, 0.05266174301505089, 0.008417871780693531, -0.0861397385597229, -0.019344178959727287, -0.06396947801113129, 0.046350132673978806, 0.05676386505365372, -5.086865712655708e-6, 0.10372238606214523, -0.21678526699543, -0.09768351167440414, -0.009167345240712166, 0.08033648878335953, 0.0944061353802681, 0.08438227325677872, 0.22630733251571655, -0.14517657458782196, -0.007395234890282154, -0.13697558641433716, -0.059589944779872894, 0.1007838323712349, 0.006305344868451357, -0.16571001708507538, 0.011307205073535442, 0.07076525688171387, -0.007749190088361502, -0.03437885269522667, 0.11359181255102158, -0.0808584913611412, 0.004542372655123472, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.07191380113363266, -0.08377204835414886, 0.15178976953029633, 0.2953050136566162, 0.2969595789909363, 0.017495974898338318, -0.2548777759075165, -0.10065208375453949, 0.23874078691005707, -0.013963986188173294, 0.016722165048122406, 0.29725781083106995, -0.05463255196809769, -0.26017990708351135, 0.042508628219366074, 0.14374642074108124, 0.3041999638080597, 0.3108743131160736, -0.03671329841017723, -0.14828334748744965, -0.36135849356651306, 0.13315536081790924, 0.002669523935765028, 0.342965304851532, -0.00918323453515768, 0.5080976486206055, -0.33983874320983887, 0.34413501620292664, 0.34606099128723145, -0.09834522753953934, 0.09321682155132294, 0.027324317023158073, -0.029650583863258362, -0.07398593425750732, 0.28067755699157715, 0.086429163813591, -0.3594377636909485, ...],
        ...
      ]
    >
  },
  "decoder.blocks.4.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.05897606164216995, 0.15732403099536896, -0.08401917666196823, -0.04356243461370468, -5.054371431469917e-4, -0.15412287414073944, -0.017296316102147102, 0.013370297849178314, 0.18167349696159363, 0.14976365864276886, -0.033649567514657974, -0.12340114265680313, 0.11543204635381699, 0.1179199367761612, -0.16670309007167816, 0.03197641670703888, -0.09725514054298401, -0.13734258711338043, 0.05221153423190117, -0.19777816534042358, -0.016869299113750458, -0.09240839630365372, -0.041709307581186295, 0.04953768476843834, -0.00948934257030487, -0.061544936150312424, 0.055073339492082596, 0.10132436454296112, 0.09503704309463501, 0.2767695188522339, -0.04178015887737274, -0.09199224412441254, -0.13981661200523376, -0.18259884417057037, -0.10355409979820251, -0.1465340107679367, 0.18813207745552063, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [-0.016589239239692688, 0.033356476575136185, -0.025944791734218597, -0.06458347290754318, -0.008617261424660683, -0.06050972640514374, 0.026528408750891685, -0.018711822107434273, 0.06334453076124191, -0.07016438990831375, -0.040474455803632736, -0.08900412917137146, -0.05415213108062744, 0.00924048200249672, -0.09849966317415237, 0.017269760370254517, 0.03402997925877571, -0.0723804235458374, 0.06440147757530212, 0.11115777492523193, -0.04786305129528046, 0.013474693521857262, 0.18404416739940643, -0.037795290350914, -0.1344190388917923, -0.041169364005327225, -0.02169586904346943, -0.09167039394378662, -0.047637294977903366, -0.010682673193514347, -0.017419137060642242, -0.09429912269115448, 0.09434686601161957, 0.10471426695585251, -0.0364200696349144, -0.13448311388492584, ...],
        ...
      ]
    >
  },
  "decoder.blocks.3.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.02195732854306698, 0.035705193877220154, 0.04064951837062836, -0.025932105258107185, 0.014330854639410973, -0.039580412209033966, -0.1155749261379242, 0.028433658182621002, 0.030217133462429047, -0.007304288912564516, 0.002263169502839446, 0.048969969153404236, 0.05460219457745552, 0.004443574231117964, -0.07364215701818466, 0.029881520196795464, -0.01567690446972847, -0.01119413785636425, 0.010527174919843674, -0.014802615158259869, 0.0057569448836147785, -0.020488295704126358, 0.024373553693294525, -0.02502831257879734, 0.019044460728764534, 0.02883622609078884, 0.056766364723443985, 0.021289346739649773, 0.003697927575558424, 0.03079884685575962, 0.0382881686091423, 0.011968711391091347, -0.011580651625990868, -0.015300154685974121, -6.058550206944346e-4, -0.08841429650783539, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.28222545981407166, 0.3131215572357178, 0.3034003973007202, 0.3393371105194092, 0.3154321610927582, 0.3309932053089142, 0.30264604091644287, 0.13720403611660004, 0.3059060573577881, 0.3251676857471466, 0.30764150619506836, 0.3226858973503113, 0.3249446153640747, 0.307298868894577, 0.28870701789855957, 0.3071806728839874, 0.2878659665584564, 0.3310546875, 0.29914483428001404, 0.2799876928329468, 0.350590318441391, 0.31099772453308105, 0.33301326632499695, 0.29779165983200073, 0.30957064032554626, 0.3072805106639862, 0.32713234424591064, 0.3132126033306122, 0.3168300688266754, 0.3154486119747162, 0.3116898834705353, 0.2880820631980896, 0.32278135418891907, 0.3193095326423645, 0.3212226927280426, ...]
    >
  },
  "decoder.blocks.9.self_attention.key" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.1255052387714386, 0.2206331193447113, 0.02155466005206108, -0.06883163750171661, -0.1582368165254593, -0.03266475349664688, 0.08761484920978546, -0.06262274831533432, 0.08110350370407104, -0.06443772464990616, -0.01988857425749302, -0.14905107021331787, 0.026666700839996338, 0.11045331507921219, 0.12034416198730469, -0.03482687100768089, -0.06882551312446594, -0.06385401636362076, 0.10434503108263016, -0.06925361603498459, 0.004499129019677639, 0.029036764055490494, 0.12643346190452576, 0.06261695176362991, 0.007506237365305424, -0.03166307136416435, 0.005090613849461079, 0.12981073558330536, 0.042653005570173264, -0.0895552709698677, -0.05183161050081253, -0.06341342628002167, 0.02652036026120186, -0.0019174101762473583, 0.06918130069971085, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.09414546191692352, -0.05349555239081383, -0.09394310414791107, 0.0680985152721405, 0.05514451488852501, 0.10338378697633743, -0.027529548853635788, 0.03738649934530258, 0.12017187476158142, 0.10639435052871704, -0.21103039383888245, 0.09546253830194473, 0.06053944304585457, 0.010116199031472206, 0.13734182715415955, -0.06583978235721588, -0.04140675812959671, 0.1631747931241989, -0.20215000212192535, 0.045266926288604736, 0.09429733455181122, 0.09323552995920181, -0.008121903985738754, 0.0823206678032875, 0.0871056392788887, -0.15799662470817566, 0.13381333649158478, -0.03152550011873245, -0.23124545812606812, 0.09309210628271103, -0.07485833019018173, 0.10970597714185715, 0.13201606273651123, 0.060820356011390686, ...],
        ...
      ]
    >
  },
  "decoder.blocks.8.ffn.intermediate" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.11210184544324875, -0.12243035435676575, -0.18696080148220062, -0.09358573704957962, -0.07174187898635864, -0.09109503775835037, -0.22353070974349976, -0.23647837340831757, -0.009258066304028034, -0.07567327469587326, -0.020744603127241135, 0.015013153664767742, -0.05726751312613487, -0.037477463483810425, -0.02229940891265869, -0.14848092198371887, -0.038814179599285126, 0.02283428981900215, -0.17796941101551056, -0.03959912434220314, 0.026215065270662308, -0.024661418050527573, -0.16912145912647247, -0.022491857409477234, -0.0367843434214592, -0.03149944916367531, -0.07467391341924667, -0.23281385004520416, -0.15074151754379272, -0.09347876906394958, 0.043657492846250534, -0.10828288644552231, 0.03156153857707977, -0.14233221113681793, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [0.07281597703695297, 0.23076365888118744, 0.056254733353853226, -0.007391960825771093, 0.37602195143699646, 0.026575006544589996, -0.0559421107172966, 0.24213534593582153, 0.17022763192653656, 0.008780462667346, 0.008517129346728325, -0.004474610555917025, 0.08918315172195435, -0.0734836533665657, -0.2731582224369049, -0.0711795762181282, -0.09748861193656921, 0.05456811189651489, -0.017072979360818863, 0.02093295007944107, -0.0573551245033741, -0.2868148684501648, -0.0835830420255661, -0.07690536230802536, -0.01711919531226158, -0.1358148157596588, -0.09262753278017044, -0.16049614548683167, 0.05776229128241539, -0.10948118567466736, -0.009235521778464317, -0.19156120717525482, -0.026457132771611214, ...],
        ...
      ]
    >
  },
  "decoder.blocks.9.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.05707442760467529, 0.035462647676467896, -0.0991213247179985, 0.13527709245681763, -0.028422338888049126, -0.025333665311336517, 0.009800232946872711, 0.0038071630988270044, 0.2587926983833313, -0.2757236957550049, -0.20594574511051178, -0.10678356885910034, 0.07663259655237198, 0.10082900524139404, 0.09312675893306732, 0.02456960454583168, -0.02361578680574894, 0.027035534381866455, -0.17550106346607208, -0.019421570003032684, 0.17456819117069244, -0.27237164974212646, -0.06578843295574188, 0.7469214200973511, -0.16401293873786926, -0.040414296090602875, 0.0813840925693512, 0.22788934409618378, 0.0898127555847168, -0.12558601796627045, -0.10621572285890579, -0.16889676451683044, 0.37665021419525146, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.07092145085334778, -0.1273377686738968, 0.18363283574581146, 0.0030052068177610636, -0.11931692808866501, 0.09142619371414185, 0.09066392481327057, -0.11962020397186279, -0.11546863615512848, -0.2486332207918167, -0.01928895153105259, -0.13211512565612793, 0.012719957157969475, -0.058824148029088974, -0.024515893310308456, -0.06520464271306992, -0.1823996901512146, 0.08205809444189072, -0.14308585226535797, -0.02228454314172268, 0.07376360148191452, 0.1657320261001587, 0.12012247741222382, -0.05113732069730759, 0.09645874798297882, -0.05014803633093834, 0.16555407643318176, -0.09963696449995041, -0.04661593958735466, -0.04942096397280693, 0.2155635803937912, -0.0188114196062088, ...],
        ...
      ]
    >
  },
  "decoder.blocks.6.self_attention.key" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.1274530589580536, -0.12059053033590317, -0.007807011716067791, 0.10612983256578445, -0.15291491150856018, 0.02204704098403454, -0.018509021028876305, -0.052053213119506836, -0.12142118811607361, 0.09243933856487274, -0.07746855169534683, -0.2270767092704773, 0.06773963570594788, 0.20025405287742615, 0.08527366816997528, 0.17828650772571564, -0.23929782211780548, 0.08052106946706772, -0.04288453236222267, -0.10904651135206223, -0.07623006403446198, 0.12980422377586365, -0.022937437519431114, 0.1890099197626114, 0.033789169043302536, -0.01818382926285267, -0.09066783636808395, 0.01310903113335371, -0.032641101628541946, 0.022848786786198616, -0.04326312616467476, -0.10271038115024567, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.23245984315872192, -0.051963526755571365, 0.03376476839184761, -0.17533107101917267, -0.0015645339153707027, -0.21227991580963135, -0.20261846482753754, 0.2514711320400238, 0.11193699389696121, 0.07263650000095367, 0.1105310469865799, 0.12566789984703064, -0.029312685132026672, -0.18712376058101654, -0.18661442399024963, -0.20233362913131714, -0.25359874963760376, -0.06487743556499481, -0.08297345787286758, 0.03883111849427223, 0.027809424325823784, 0.05575752630829811, 0.0707707554101944, -0.04197012260556221, -8.136677788570523e-4, 0.01103792805224657, -0.23597127199172974, 0.043430693447589874, 0.005747525487095118, -0.2847156822681427, -0.007437974214553833, ...],
        ...
      ]
    >
  },
  "decoder.blocks.10.self_attention_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.030510496348142624, 0.006001842673867941, 0.048768121749162674, 0.01165001466870308, 0.005879620555788279, 0.010598479770123959, -0.0672491267323494, 0.025556962937116623, 0.013904724270105362, 6.435069954022765e-4, 0.007581857033073902, 0.012246792204678059, 0.026725852862000465, 0.02789079211652279, 0.04797949641942978, -0.004149633459746838, 0.0409918837249279, 0.034725841134786606, -0.0010700526181608438, 0.01325390674173832, 0.03363116458058357, 0.023286549374461174, 0.02105352096259594, 0.025586696341633797, 0.020323842763900757, 0.006301491055637598, -0.0016431305557489395, 0.015145816840231419, 0.02847486175596714, 0.006544305011630058, 0.010149458423256874, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.426937460899353, 0.3798982501029968, 0.38925623893737793, 0.36285722255706787, 0.3721630275249481, 0.3487066626548767, 0.5020099878311157, 0.35842880606651306, 0.3820701539516449, 0.32065778970718384, 0.4148608148097992, 0.35868051648139954, 0.3574266731739044, 0.3837886452674866, 0.44628968834877014, 0.397741436958313, 0.3976704776287079, 0.3701581656932831, 0.3779294490814209, 0.39748460054397583, 0.3310501277446747, 0.3544250428676605, 0.38002175092697144, 0.41307657957077026, 0.379832923412323, 0.3876934051513672, 0.4171038269996643, 0.3594702482223511, 0.36118414998054504, 0.36034277081489563, ...]
    >
  },
  "decoder.blocks.4.ffn.intermediate" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.06427578628063202, -0.29860806465148926, -0.11274097114801407, -0.1695093810558319, -0.029283864423632622, 0.015383693389594555, -0.059798579663038254, -0.21982493996620178, -0.11109853535890579, -0.025041809305548668, -0.17913004755973816, -0.04023921862244606, -0.15255160629749298, -0.14760905504226685, -0.05613582208752632, -0.08194784075021744, -0.08518823236227036, -0.10574408620595932, -0.09536109864711761, 0.04650188237428665, -0.24487002193927765, -0.10550493001937866, -0.07493823766708374, 0.3118955194950104, -0.015493771992623806, 0.0036832429468631744, -0.0269177183508873, -0.16599039733409882, -0.09511981904506683, 0.006142540369182825, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-4.077995545230806e-4, -0.1200379952788353, -0.012310190126299858, -0.24376054108142853, 0.1328510195016861, 0.13179974257946014, 0.02635245770215988, 0.057356610894203186, -0.06828179955482483, -0.01686907559633255, 0.049044106155633926, -0.3784016966819763, -0.03531080484390259, 0.43567171692848206, 0.02976839430630207, -0.06014109030365944, 0.18706151843070984, -0.050236742943525314, 0.11668948084115982, 0.05957753583788872, -0.14054043591022491, -0.013522407039999962, -0.06838822364807129, 0.06603340059518814, 0.1704917997121811, -0.04796731844544411, 0.13489077985286713, 0.09338540583848953, -0.4719774127006531, ...],
        ...
      ]
    >
  },
  "decoder.blocks.0.self_attention.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.1502915918827057, -0.15426146984100342, -0.1466306895017624, -0.09912773221731186, 0.03380264714360237, -0.03444754704833031, -0.0706353709101677, -0.0936073362827301, 0.08110949397087097, 0.031160973012447357, -0.19926859438419342, -0.037245020270347595, 0.0030495061073452234, 0.04989158734679222, -0.0535597987473011, 0.0374077744781971, -0.19088934361934662, -0.08153925091028214, 0.0491158701479435, 0.14187365770339966, -0.11211564391851425, -0.09672139585018158, 0.05310625955462456, -0.21880236268043518, 0.09357757121324539, 0.012167288921773434, 0.012163301929831505, 0.08575751632452011, 0.04579753801226616, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.3127181828022003, -0.18741346895694733, 0.0980248749256134, -0.030339280143380165, -0.021640362218022346, -0.021060511469841003, -0.17938977479934692, -0.3298454284667969, 0.29462477564811707, 0.016695095226168633, -0.17451533675193787, -0.01856379583477974, 0.015662474557757378, 0.019278528168797493, 0.007864857092499733, 0.19694651663303375, -0.10614901781082153, -0.013053175993263721, 0.014888686127960682, 0.39647337794303894, -0.022225921973586082, 0.03043077513575554, -0.04131042957305908, -0.0028614092152565718, -0.04620683565735817, -0.0031064024660736322, -0.24714353680610657, -0.020426509901881218, ...],
        ...
      ]
    >
  },
  "decoder.blocks.4.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.030168335884809494, 0.10533975064754486, 0.15787462890148163, 0.024759536609053612, -0.006538494024425745, 0.060618676245212555, 0.08839995414018631, -0.13495968282222748, -0.020193777978420258, 0.12939752638339996, -0.08156191557645798, 0.34852737188339233, -0.25281667709350586, 0.0726567953824997, -0.00879716593772173, 0.46894973516464233, -0.20609457790851593, 0.09965331107378006, 0.16798031330108643, 0.21133244037628174, 0.16016334295272827, 0.023227756842970848, 0.2617943286895752, -0.2517806887626648, -0.27660995721817017, -0.010905023664236069, 0.04526042193174362, -0.01339393388479948, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.05423329025506973, 0.10368772596120834, -0.12026900798082352, 0.2241293340921402, 0.07213354855775833, -0.10475268959999084, -0.098123699426651, -0.03631935641169548, 0.3400250971317291, 0.11668488383293152, 0.3846721351146698, -0.04552015662193298, 0.03707210719585419, 0.060659412294626236, -0.13420982658863068, -0.06594616919755936, 0.17946170270442963, -0.10135768353939056, -8.781739161349833e-4, 0.03128969296813011, 0.10800420492887497, -0.20376046001911163, 0.036751069128513336, -0.06389692425727844, -0.15011754631996155, 0.2712489664554596, 0.14473587274551392, ...],
        ...
      ]
    >
  },
  "decoder.blocks.8.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.026478445157408714, 0.03058476373553276, 0.03277738764882088, 0.027175817638635635, -0.03420574590563774, -0.04291548207402229, -0.02005288004875183, 0.06033172458410263, 0.042357590049505234, 0.015196949243545532, -0.006372882053256035, 0.014285472221672535, 0.02985285595059395, 0.006489758379757404, -0.03699921444058418, 0.02677823044359684, -0.01978689804673195, 0.029895320534706116, 0.001117934938520193, -0.03488800302147865, 0.03796650841832161, -0.028332462534308434, -0.013404340483248234, 0.014422249048948288, -0.010553745552897453, 0.002345289569348097, -0.02829832397401333, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.256835401058197, 0.2554992139339447, 0.26428794860839844, 0.2509766221046448, 0.25293001532554626, 0.260743111371994, 0.3623034656047821, 0.2026355266571045, 0.2724631726741791, 0.23681232333183289, 0.26465314626693726, 0.2565948963165283, 0.2509452700614929, 0.26839280128479004, 0.28613126277923584, 0.25516241788864136, 0.26855340600013733, 0.2548842430114746, 0.25679123401641846, 0.25487300753593445, 0.23641103506088257, 0.2568313181400299, 0.2627745568752289, 0.25056999921798706, 0.25682833790779114, 0.2529260516166687, ...]
    >
  },
  "decoder.blocks.6.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.008094793185591698, -0.012891687452793121, 0.012476509436964989, 0.021194303408265114, 0.04015762358903885, -0.02663881704211235, -0.023752382025122643, -0.001071950071491301, 0.03309759870171547, 0.010285280644893646, -0.006889031268656254, -0.1257532835006714, 9.324409766122699e-4, 0.01171852182596922, -0.043873947113752365, -0.008764470927417278, -0.07170451432466507, -0.012096060439944267, 0.011502828449010849, 0.008124460466206074, -0.015051798895001411, 0.0037668568547815084, 0.028139527887105942, -0.04549719765782356, 0.01731819286942482, -0.04555899649858475, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.018813427537679672, 0.19313490390777588, 0.05417853593826294, 0.023874085396528244, -0.1469319462776184, -0.03817090019583702, -0.08952698111534119, 0.053204476833343506, 0.03557916358113289, -0.07985074818134308, -0.001789747504517436, -0.14399871230125427, 0.18803372979164124, 0.01738598197698593, 0.038275156170129776, 0.03143403306603432, -0.04185841605067253, 0.07224993407726288, 0.2724269926548004, 0.02323106862604618, 0.04056814685463905, 0.05048747733235359, 0.1290781944990158, -0.13238899409770966, -0.04559126868844032, ...],
        ...
      ]
    >
  },
  "decoder.blocks.3.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0232847947627306, 0.10563581436872482, -0.0531383752822876, -0.09454120695590973, -0.015325434505939484, -0.08506536483764648, -0.07138027250766754, 0.05388682335615158, 0.14111925661563873, 0.1348050981760025, -0.07036365568637848, -0.0036540981382131577, 0.041501376777887344, 0.07138317078351974, -0.1433209776878357, 0.033328648656606674, -0.025909310206770897, -0.15808121860027313, 0.058489684015512466, -0.1005123108625412, 0.12790188193321228, -0.05867855250835419, 0.05914581939578056, 0.037388164550065994, 0.04343689605593681, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.07660453766584396, 0.005029582418501377, -0.08185179531574249, -0.04126457870006561, 0.06943788379430771, -0.04354039579629898, 0.13279001414775848, 0.057913508266210556, -0.04257570207118988, -0.0920814573764801, -0.2037155032157898, -0.0021956718992441893, 0.06303318589925766, 0.0779477059841156, -0.028110455721616745, -0.22464191913604736, 0.08300850540399551, -0.037264134734869, -0.26715990900993347, 0.09652064740657806, 0.0018457196420058608, -7.214118377305567e-4, -0.09362386912107468, 0.05358770489692688, ...],
        ...
      ]
    >
  },
  "decoder.blocks.8.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.016666432842612267, -0.3908640146255493, -0.14189036190509796, -0.09550853818655014, 0.028818266466259956, -0.03564365208148956, 0.39664602279663086, -0.1003335490822792, -0.12153777480125427, 0.25413286685943604, -0.25467512011528015, 0.2728674113750458, -0.14818769693374634, 0.03464064374566078, -0.01881992444396019, -0.11672705411911011, 0.26542341709136963, -0.31608355045318604, -0.06016623601317406, -0.32036855816841125, 0.10017024725675583, 0.12484857439994812, -0.18118615448474884, -0.22473324835300446, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.04895172640681267, -0.13535946607589722, 0.08867701888084412, 0.03470441326498985, 0.16805638372898102, -0.04044034704566002, 0.02029840461909771, 0.12461696565151215, 0.038629818707704544, 0.0026324870996177197, 0.1111854836344719, -0.10957732796669006, 0.008841103874146938, 0.22030411660671234, 0.1880180835723877, -0.1760788857936859, 0.11645969748497009, 0.024209320545196533, 0.0647314041852951, 0.06860420852899551, -0.12365900725126266, -0.22124595940113068, -0.0675320103764534, ...],
        ...
      ]
    >
  },
  "decoder.blocks.9.ffn.intermediate" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.15674668550491333, -0.05856061354279518, -0.1862010508775711, -0.05574220418930054, 0.05308351293206215, -0.23877263069152832, -0.1725977510213852, -0.16667550802230835, -0.1204790323972702, 0.10092921555042267, -0.1463088095188141, -0.010244892910122871, -0.024020275101065636, -0.020772719755768776, -0.015557598322629929, -0.05010974407196045, -0.18073450028896332, 0.015407576225697994, -0.10474839061498642, -0.09000752866268158, -0.06217072531580925, -0.024300342425704002, -0.17624199390411377, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.1263384222984314, 0.10032892227172852, 0.04315217211842537, -0.34024444222450256, 0.03990277275443077, -0.1816384494304657, -0.08722347021102905, -0.09134645760059357, 0.08901876211166382, 0.2486005276441574, 0.14004628360271454, -0.02703682705760002, -0.0864061713218689, -0.011990259401500225, -0.07122871279716492, 0.14166268706321716, 0.13745108246803284, 0.09496646374464035, 0.19813163578510284, -0.06252028793096542, -0.08641546964645386, -0.19076146185398102, ...],
        ...
      ]
    >
  },
  "decoder.blocks.11.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.22222381830215454, 0.054906319826841354, 0.03307081758975983, 0.1580498218536377, 0.0302886962890625, -0.2349749058485031, -0.3064831495285034, -0.11609630286693573, 0.11942905932664871, -0.18294057250022888, 0.09135612845420837, 0.1770371049642563, -0.044169217348098755, -0.23199965059757233, -0.2659616470336914, -0.06414084881544113, 0.17821991443634033, 0.23306676745414734, 0.02620626427233219, -0.10767792165279388, 0.028602538630366325, 0.07664165645837784, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.2175447642803192, 0.060769811272621155, -0.06370636075735092, -0.010534803383052349, 0.03727848827838898, 0.038609497249126434, 0.0018609011312946677, 0.07256438583135605, -0.016396360471844673, -0.06539854407310486, 0.13125160336494446, -0.14517498016357422, 0.020752470940351486, 0.06105264276266098, 0.10187797248363495, -0.09049602597951889, 0.10861889272928238, 0.03869340196251869, 0.1436213254928589, 0.0036275412421673536, -0.13537921011447906, ...],
        ...
      ]
    >
  },
  "decoder.blocks.1.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.023383773863315582, 0.019925691187381744, -0.010859258472919464, 0.6510787010192871, -0.007127456832677126, 3.0829233583062887e-4, -0.0018755021737888455, 0.06549560278654099, -0.024767201393842697, 0.009818159975111485, 0.024395234882831573, 0.003666318953037262, 0.027201134711503983, -0.04317772388458252, -0.04861711338162422, 0.29467853903770447, 0.021457983180880547, 0.004829582292586565, -0.006358654238283634, 0.005354858003556728, -0.010436966083943844, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.009291327558457851, 0.030902646481990814, 0.077150858938694, 0.03817647323012352, -0.11329539120197296, -0.028751656413078308, -0.11431775987148285, 0.0115272281691432, -0.14699126780033112, -0.08838532119989395, -0.040005024522542953, -0.05282926931977272, -0.009505202062427998, -0.11530014872550964, 0.051036346703767776, -0.1743350625038147, -0.15030527114868164, -0.02026422880589962, -3.135435690637678e-4, 0.07486337423324585, ...],
        ...
      ]
    >
  },
  "decoder.blocks.11.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.05557013303041458, -0.023522645235061646, 0.020958656445145607, 0.023626001551747322, -0.012116808444261551, 0.10495120286941528, 0.0023480362724512815, -0.02615499310195446, 0.015563438646495342, -0.009612268768250942, -0.00985608622431755, -0.058930542320013046, -0.09239070862531662, 0.041288528591394424, 0.06829941272735596, -0.013641005381941795, 0.04428235813975334, 0.016503283753991127, -0.029918698593974113, -0.011992884799838066, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.09672704339027405, 0.02168218418955803, -0.04748263582587242, 0.16117015480995178, -0.09717914462089539, 0.07982426881790161, -0.025253381580114365, -0.07965120673179626, 0.10104228556156158, -0.03366006910800934, 0.04504484683275223, -0.14655034244060516, -0.04964844509959221, -0.021307766437530518, 0.10284372419118881, -0.035687148571014404, 0.1570502370595932, -0.09166368842124939, -0.009216838516294956, ...],
        ...
      ]
    >
  },
  "embedder.token_embedding" => %{
    "kernel" => #Nx.Tensor<
      f32[50257][768]
      EXLA.Backend
      [
        [-0.11010301113128662, -0.03926672413945198, 0.03310750797390938, 0.13382644951343536, -0.04847569391131401, -0.07891767472028732, -0.2397741675376892, -0.08947388082742691, 0.02525496669113636, -0.10739682614803314, -0.18114538490772247, -0.06715374439954758, 0.07391443103551865, -0.016131391748785973, 0.011662482284009457, 0.12449593096971512, -0.001963014481589198, -0.08150256425142288, 0.03377755731344223, ...],
        ...
      ]
    >
  },
  "decoder.blocks.4.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.03581799939274788, 0.0013955077156424522, 0.028343360871076584, 0.008547329343855381, 0.0017066916916519403, -0.015421940945088863, -0.10036948323249817, 0.007698511239141226, 0.02324959821999073, 0.009619422256946564, -0.007669593673199415, 0.025091584771871567, 0.04427342116832733, -0.020369691774249077, -0.05964217334985733, 0.019331300631165504, -0.01223753858357668, -0.01770605705678463, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.2568347752094269, 0.2607281506061554, 0.27832046151161194, 0.28114670515060425, 0.28516432642936707, 0.2770099639892578, 0.27500423789024353, 0.1547410935163498, 0.283526748418808, 0.28981050848960876, 0.2724590003490448, 0.27831974625587463, 0.28609699010849, 0.27203458547592163, 0.2650811970233917, 0.2724769115447998, 0.26392078399658203, ...]
    >
  },
  "decoder.blocks.10.self_attention.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.012379296123981476, 0.014348221942782402, -0.004452732391655445, -0.04507273808121681, -0.012405290268361568, 0.08091054856777191, 0.15559299290180206, 0.05403776094317436, 0.0675089880824089, 0.030687397345900536, 0.3024734854698181, -0.009818164631724358, 0.08117414265871048, -0.1720537543296814, 0.05229543149471283, 0.13546159863471985, -0.21003645658493042, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.007740591652691364, -0.5188016891479492, 0.07828357070684433, 0.16056594252586365, -0.345084011554718, 0.19546931982040405, -0.03769585117697716, 0.42103129625320435, 0.25166407227516174, -0.2463306337594986, 0.11552479863166809, -0.05760181322693825, 0.1186075359582901, -0.20914454758167267, -0.2030973732471466, -0.22620870172977448, ...],
        ...
      ]
    >
  },
  "decoder.blocks.0.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.04247827082872391, 0.03262673318386078, 0.004488067235797644, 0.015759378671646118, 0.005960179027169943, -0.026806578040122986, 0.01171385683119297, 0.05137127265334129, 0.013217959553003311, -0.004240347538143396, -0.011667933315038681, 0.0021937862038612366, -0.013367529958486557, -0.002064076717942953, -0.026502149179577827, 0.01191623043268919, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.13096605241298676, 0.20933687686920166, 0.20659242570400238, 1.254226803779602, 1.2637722492218018, 1.2695313692092896, 0.0935104489326477, 0.07934456318616867, 0.2260090857744217, 1.3007793426513672, 0.23237723112106323, 1.1525323390960693, 1.2761300802230835, 1.26953125, 0.7215696573257446, ...]
    >
  },
  "decoder.blocks.6.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.032783910632133484, 0.10676194727420807, -0.020738141611218452, 0.032967459410429, -0.015854258090257645, -0.11419513821601868, 0.018057847395539284, -0.058376479893922806, 0.19097566604614258, 0.11427830904722214, -0.06308886408805847, -0.047898005694150925, 0.17753615975379944, 0.06485942751169205, -0.11543454974889755, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.0969826877117157, -0.09540646523237228, 0.06500814855098724, 0.09873990714550018, -0.10905405879020691, 0.08737366646528244, 0.013654707930982113, -0.014028162695467472, 0.0809120461344719, -0.014392320066690445, -0.12919580936431885, -0.08840550482273102, 0.044951293617486954, -0.10458621382713318, ...],
        ...
      ]
    >
  },
  "decoder.blocks.7.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.37788376212120056, 0.07670742273330688, 0.0018559127347543836, -0.10767138749361038, -0.008797960355877876, -0.05189952999353409, -3.922691976185888e-4, 0.47561389207839966, 0.07581450045108795, -0.08861846476793289, 0.08389809727668762, 0.01737620122730732, 0.046795982867479324, 0.3026520609855652, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.09505656361579895, 0.2815342843532562, 0.07592008262872696, 0.05813293159008026, -0.08546139299869537, -0.12361529469490051, -0.1401478499174118, 0.058990344405174255, -0.2838239073753357, 0.11386465281248093, -0.1583637148141861, -0.050222549587488174, 0.10052284598350525, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.self_attention.key" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.1164763867855072, -0.10782725363969803, -0.0508851557970047, 0.006003555841743946, 0.08386743813753128, 0.16977491974830627, 0.07476542890071869, -0.03264027461409569, -0.02714858204126358, 0.025255294516682625, 0.07389767467975616, 0.14429283142089844, 0.056997258216142654, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [-0.009812610223889351, -0.17040055990219116, -0.034252941608428955, -0.2796790301799774, -0.015102369710803032, -0.04087914526462555, 0.007516289129853249, -0.021858975291252136, 0.020004870370030403, -0.01545007899403572, -0.008888366632163525, 0.013375785201787949, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.output_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.043191004544496536, 0.0245783943682909, 0.004377962090075016, -0.02795151062309742, 0.022012922912836075, -0.06287944316864014, -0.07565537840127945, 0.03899669274687767, 0.022937333211302757, -1.1479156455607153e-5, -0.014939102344214916, 0.022962642833590508, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.26464512944221497, 0.30566200613975525, 0.28360238671302795, 0.31934303045272827, 0.30372193455696106, 0.3557177186012268, 0.2701093852519989, 0.12454181909561157, 0.2917430102825165, 0.32704973220825195, 0.28982165455818176, ...]
    >
  },
  "decoder.blocks.7.self_attention_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.03026599995791912, 0.020864378660917282, 0.0317896232008934, 0.01288677379488945, 0.0020049354061484337, 0.007990947924554348, -0.026534350588917732, 0.02696160413324833, 0.011594080366194248, 0.013581011444330215, 0.009942992590367794, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.3551267981529236, 0.3564145565032959, 0.3758927881717682, 0.3657114505767822, 0.3599787950515747, 0.34165623784065247, 0.3870469033718109, 0.2665950655937195, 0.36816349625587463, 0.352022260427475, ...]
    >
  },
  "decoder.blocks.10.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.03012319840490818, 0.13602332770824432, -0.3841572105884552, 0.33693796396255493, -0.17340926826000214, -0.15652933716773987, 0.3452204167842865, -0.013852311298251152, -0.13725152611732483, -0.19260846078395844, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.21999560296535492, 0.2637237310409546, 0.04224083572626114, -0.06340529769659042, 0.12883184850215912, -0.01630588248372078, 0.10259885340929031, -0.038481637835502625, -0.07162255793809891, ...],
        ...
      ]
    >
  },
  "decoder.blocks.8.self_attention_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.026580655947327614, 0.014082363806664944, 0.03609975427389145, 0.014848507009446621, -0.004328620154410601, 0.009019272401928902, -0.0453052818775177, 0.022013459354639053, 0.006949983071535826, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.3517552316188812, 0.3381420969963074, 0.35644039511680603, 0.3355329632759094, 0.33988144993782043, 0.3217993676662445, 0.3803149163722992, 0.28222477436065674, ...]
    >
  },
  "decoder.blocks.11.ffn.intermediate" => %{
    "bias" => #Nx.Tensor<
      f32[3072]
      EXLA.Backend
      [-0.038226597011089325, -0.1637330800294876, 0.28211548924446106, -0.10262239724397659, 0.024077357724308968, -0.0832146555185318, -0.10605514049530029, -0.13395379483699799, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][3072]
      EXLA.Backend
      [
        [-0.14395315945148468, 0.21767818927764893, -0.08208920806646347, -0.08232471346855164, -0.0785398781299591, 0.2741723358631134, -0.1334928274154663, ...],
        ...
      ]
    >
  },
  "decoder.blocks.8.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.009447249583899975, -0.0026832041330635548, 0.03221642225980759, 0.012611282989382744, -0.0379859022796154, -0.11279989778995514, -0.008031549863517284, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.10462091118097305, 0.03825182095170021, -0.15017639100551605, 0.07689894735813141, -0.060953639447689056, -0.017569424584507942, ...],
        ...
      ]
    >
  },
  "decoder.blocks.10.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.03814335912466049, 0.025599177926778793, -0.20103679597377777, 0.14348195493221283, 0.17296327650547028, -0.1044880598783493, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[3072][768]
      EXLA.Backend
      [
        [0.05263395607471466, 0.2036895900964737, -0.22806401550769806, 0.0914945974946022, 0.10977628082036972, ...],
        ...
      ]
    >
  },
  "decoder.blocks.4.self_attention.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.13346901535987854, -0.12023670226335526, -0.01996723935008049, -0.017336692661046982, 0.06325311213731766, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.12204195559024811, -9.472257806919515e-4, 0.020139841362833977, -0.041395459324121475, ...],
        ...
      ]
    >
  },
  "decoder.blocks.5.self_attention.query" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.04358132556080818, 0.029470941051840782, 0.08498750627040863, -0.09758780896663666, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.18605762720108032, -0.19759854674339294, 0.0348566509783268, ...],
        ...
      ]
    >
  },
  "dense_72" => %{
    "bias" => #Nx.Tensor<
      f32[2]
      EXLA.Backend
      [0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][2]
      EXLA.Backend
      [
        [0.08618152141571045, -0.0649411529302597],
        ...
      ]
    >
  },
  "decoder.blocks.0.self_attention.value" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [-0.0034378955606371164, 0.006358778104186058, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[768][768]
      EXLA.Backend
      [
        [0.14213383197784424, ...],
        ...
      ]
    >
  },
  "decoder.blocks.2.self_attention_norm" => %{
    "beta" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [0.004916004370898008, ...]
    >,
    "gamma" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [...]
    >
  },
  "decoder.blocks.1.ffn.output" => %{
    "bias" => #Nx.Tensor<
      f32[768]
      EXLA.Backend
      [...]
    >,
    ...
  },
  "decoder.blocks.9.self_attention.output" => %{...},
  ...
}Map.keys(trainable_params) |> Enum.each(&IO.inspect(&1))"decoder.blocks.7.self_attention.value"
"decoder.blocks.0.ffn.output"
"decoder.blocks.2.self_attention.output"
"decoder.blocks.6.output_norm"
"decoder.blocks.11.ffn.output"
"decoder.blocks.3.self_attention.value"
"decoder.blocks.3.self_attention.output"
"decoder.blocks.7.ffn.output"
"decoder.blocks.2.self_attention.value"
"decoder.blocks.2.ffn.output"
"decoder.blocks.3.self_attention.key"
"decoder.blocks.4.ffn.output"
"decoder.blocks.3.output_norm"
"decoder.blocks.9.self_attention.key"
"decoder.blocks.8.ffn.intermediate"
"decoder.blocks.9.self_attention.query"
"decoder.blocks.6.self_attention.key"
"decoder.blocks.10.self_attention_norm"
"decoder.blocks.4.ffn.intermediate"
"decoder.blocks.0.self_attention.output"
"decoder.blocks.4.self_attention.query"
"decoder.blocks.8.output_norm"
"decoder.blocks.6.self_attention.value"
"decoder.blocks.3.ffn.output"
"decoder.blocks.8.self_attention.query"
"decoder.blocks.9.ffn.intermediate"
"decoder.blocks.11.self_attention.query"
"decoder.blocks.1.self_attention.value"
"decoder.blocks.11.self_attention.value"
"embedder.token_embedding"
"decoder.blocks.4.output_norm"
"decoder.blocks.10.self_attention.output"
"decoder.blocks.0.output_norm"
"decoder.blocks.6.ffn.output"
"decoder.blocks.7.self_attention.query"
"decoder.blocks.2.self_attention.key"
"decoder.blocks.2.output_norm"
"decoder.blocks.7.self_attention_norm"
"decoder.blocks.10.self_attention.query"
"decoder.blocks.8.self_attention_norm"
"decoder.blocks.11.ffn.intermediate"
"decoder.blocks.8.self_attention.value"
"decoder.blocks.10.ffn.output"
"decoder.blocks.4.self_attention.output"
"decoder.blocks.5.self_attention.query"
"dense_72"
"decoder.blocks.0.self_attention.value"
"decoder.blocks.2.self_attention_norm"
"decoder.blocks.1.ffn.output"
"decoder.blocks.9.self_attention.output"
"decoder.blocks.5.self_attention_norm"
"decoder.blocks.0.self_attention.key"
"decoder.blocks.6.self_attention.query"
"decoder.blocks.0.ffn.intermediate"
"decoder.blocks.5.ffn.output"
"norm"
"decoder.blocks.5.self_attention.output"
"decoder.blocks.5.output_norm"
"decoder.blocks.11.self_attention.key"
"decoder.blocks.7.output_norm"
"decoder.blocks.9.ffn.output"
"decoder.blocks.4.self_attention_norm"
"decoder.blocks.1.ffn.intermediate"
"embedder.position_embedding"
"decoder.blocks.3.ffn.intermediate"
"decoder.blocks.4.self_attention.value"
"decoder.blocks.7.ffn.intermediate"
"decoder.blocks.8.self_attention.output"
"decoder.blocks.4.self_attention.key"
"decoder.blocks.3.self_attention_norm"
"decoder.blocks.0.self_attention_norm"
"decoder.blocks.7.self_attention.output"
"decoder.blocks.1.self_attention_norm"
"decoder.blocks.8.self_attention.key"
"decoder.blocks.9.self_attention_norm"
"decoder.blocks.9.self_attention.value"
"decoder.blocks.5.self_attention.value"
"decoder.blocks.3.self_attention.query"
"decoder.blocks.6.ffn.intermediate"
"decoder.blocks.6.self_attention_norm"
"decoder.blocks.5.self_attention.key"
"decoder.blocks.11.output_norm"
"decoder.blocks.2.self_attention.query"
"decoder.blocks.5.ffn.intermediate"
"decoder.blocks.1.output_norm"
"decoder.blocks.1.self_attention.output"
"decoder.blocks.7.self_attention.key"
"decoder.blocks.10.self_attention.value"
"decoder.blocks.9.output_norm"
"decoder.blocks.1.self_attention.key"
"decoder.blocks.11.self_attention_norm"
"decoder.blocks.2.ffn.intermediate"
"decoder.blocks.11.self_attention.output"
"decoder.blocks.10.ffn.intermediate"
"decoder.blocks.8.ffn.output"
"decoder.blocks.6.self_attention.output"
"decoder.blocks.1.self_attention.query"
"decoder.blocks.0.self_attention.query"
"decoder.blocks.10.output_norm"
"decoder.blocks.10.self_attention.key":okfreeze_fn = fn
  ["decoder.blocks.11" <> _internal, _parameter] -> false
  ["dense_72", _parameter] -> false
  ["norm", _parameter] -> false
  _ -> true
end
classifier_params = Axon.ModelState.freeze(class_params, freeze_fn)#Axon.ModelState<
  Parameters: 124441346 (497.77 MB)
  Trainable Parameters: 7090946 (28.36 MB)
  Trainable State: 0, (0 B)
>classifier_params
|> Axon.ModelState.trainable_parameters()
|> Map.keys() 
|> Enum.each(&IO.inspect(&1))"decoder.blocks.11.ffn.intermediate"
"decoder.blocks.11.ffn.output"
"decoder.blocks.11.output_norm"
"decoder.blocks.11.self_attention.key"
"decoder.blocks.11.self_attention.output"
"decoder.blocks.11.self_attention.query"
"decoder.blocks.11.self_attention.value"
"decoder.blocks.11.self_attention_norm"
"dense_72"
"norm":okinput = Bumblebee.apply_tokenizer(tokenizer, "Do you have time") |> dbg
result = predict_fn.(classifier_params, input)%{
  "attention_mask" => #Nx.Tensor<
    u32[1][4]
    EXLA.Backend
    [
      [1, 1, 1, 1]
    ]
  >,
  "input_ids" => #Nx.Tensor<
    u32[1][4]
    EXLA.Backend
    [
      [5211, 345, 423, 640]
    ]
  >,
  "length" => #Nx.Tensor<
    s32[1]
    EXLA.Backend
    [4]
  >
}#Nx.Tensor<
  f32[1][4][2]
  EXLA.Backend
  [
    [
      [-0.7038883566856384, 3.8006432056427],
      [3.971379518508911, 5.477177619934082],
      [5.736504554748535, 0.6289554834365845],
      [1.911087155342102, 4.422253131866455]
    ]
  ]
>We will focus on the last row, which corresponds to the last output token. The causal attention mask restricts a token’s focus to its current position and those before it, ensuring that each token can only be influenced by itself and the preceding tokens.
The last token in a sequence accumulates the most information since it is the only token with access to data from all the previous tokens.
result[[.., -1]]#Nx.Tensor<
  f32[1][2]
  EXLA.Backend
  [
    [1.911087155342102, 4.422253131866455]
  ]
>6.6 Calculating the classification loss and accuracy
result[[.., -1]]
|> Axon.Activations.softmax(axis: -1) 
|> Nx.argmax(axis: -1)#Nx.Tensor<
  s32[1]
  EXLA.Backend
  [1]
>Using the softmax function here is optional because the largest outputs directly correspond to the highest probability scores.
Axon.Losses.categorical_cross_entropy(Nx.tensor([[0]]), result[[.., -1, ..]], reduction: :mean,
    from_logits: true,
    sparse: true)#Nx.Tensor<
  f32
  EXLA.Backend
  2.5892131328582764
>loss_fn = fn y_true, y_pred ->
  Axon.Losses.categorical_cross_entropy(y_true, y_pred[[.., -1]],
    reduction: :mean,
    from_logits: true,
    sparse: true
  )
end
{loss_fn.(Nx.tensor([0]), result), loss_fn.(Nx.tensor([1]), result)}{#Nx.Tensor<
   f32
   EXLA.Backend
   2.5892131328582764
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend
   0.07804705947637558
 >}input_ids = input["input_ids"] |> Nx.reshape({4})
stacked_input = Nx.stack([input_ids, input_ids]) |> dbg 
res = predict_fn.(class_params, %{"input_ids" => stacked_input})#Nx.Tensor<
  u32[2][4]
  EXLA.Backend
  [
    [5211, 345, 423, 640],
    [5211, 345, 423, 640]
  ]
>#Nx.Tensor<
  f32[2][4][2]
  EXLA.Backend
  [
    [
      [-0.7038887739181519, 3.800645589828491],
      [3.9713797569274902, 5.477175712585449],
      [5.7365031242370605, 0.6289563775062561],
      [1.9110867977142334, 4.42225456237793]
    ],
    [
      [-0.7038887739181519, 3.800645589828491],
      [3.9713797569274902, 5.477175712585449],
      [5.736503601074219, 0.6289564371109009],
      [1.9110854864120483, 4.422255039215088]
    ]
  ]
>multi_loss_fn = fn y_true, y_pred ->
  logits_flat = Nx.flatten(y_pred, axes: [0, 1])
  targets_flat = Nx.flatten(y_true)
  Axon.Losses.categorical_cross_entropy(targets_flat, logits_flat,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )
end#Function<41.105768164/2 in :erl_eval.expr/6>loss_fn.(Nx.tensor([[1],[1]]), res)#Nx.Tensor<
  f32
  EXLA.Backend
  0.07804689556360245
>6.7 Fine-tuning the model on supervised data
# AdamW optimizer for fine-tuning: learning rate 5.0e-5, weight decay 0.1.
optimizer = Polaris.Optimizers.adamw(learning_rate: 5.00e-5, decay: 0.1)
# Presumably an example of the {input, label} items the loop consumes (kept for reference):
#data = [{input, Nx.tensor([1])}, {input, Nx.tensor([1])}]
# Supervised training loop: optimizes `loss_fn` over `training_dataset`, also runs
# `classifier_model` against `validation_dataset`, and starts from `classifier_params`.
trained_model_state =
  classifier_model
  |> Axon.Loop.trainer(loss_fn, optimizer)
  |> Axon.Loop.validate(classifier_model, validation_dataset)
  #|> Axon.Loop.run(training_dataset, classifier_params, epochs: 1, compiler: EXLA)
  # Single epoch, with EXLA forwarded as the JIT compiler.
  |> Axon.Loop.run(training_dataset, classifier_params, epochs: 1, compiler: EXLA)
22:19:32.281 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 100, loss: 0.6848841
Batch: 16, loss: 0.4179240#Axon.ModelState<
  Parameters: 124441420 (497.77 MB)
  Trainable Parameters: 124441346 (497.77 MB)
  Trainable State: 74, (296 B)
>File.cd(__DIR__)
# The serialized state from Nx.serialize/1 is wrapped once more with
# :erlang.term_to_binary/1 before being written. Loading must undo both steps in
# reverse order (:erlang.binary_to_term/1, then Nx.deserialize/1).
model_bin = Nx.serialize(trained_model_state)
model_erlang_binary = :erlang.term_to_binary(model_bin)
# Relative path: resolved against the current working directory.
File.write!("classifier_gpt.axon", model_erlang_binary)
#Nx.deserialize(model_bin):okChoosing the number of epochs
The number of epochs depends on the dataset and the task’s difficulty, and there is no universal solution or recommendation, although an epoch number of five is usually a good starting point. If a loss plot shows that the model overfits after the first few epochs, you may need to reduce the number of epochs. Conversely, if the trendline suggests that the validation loss could improve with further training, you should increase the number of epochs. In this concrete case, five epochs is a reasonable number as there are no signs of early overfitting, and the validation loss is close to 0.
# Continue fine-tuning for five more epochs. This run starts from the state
# produced by the previous run (`trained_model_state`), not from `classifier_params`.
trained_model_state =
  classifier_model
  |> Axon.Loop.trainer(loss_fn, optimizer)
  |> Axon.Loop.validate(classifier_model, validation_dataset)
  # Earlier single-epoch run from the initial parameters, kept for reference:
  #|> Axon.Loop.run(training_dataset, classifier_params, epochs: 1, compiler: EXLA)
  |> Axon.Loop.run(training_dataset, trained_model_state, epochs: 5, compiler: EXLA)
23:58:45.286 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 100, loss: 0.2878236
Epoch: 1, Batch: 100, loss: 0.1859646
Epoch: 2, Batch: 100, loss: 0.1639413
Epoch: 3, Batch: 100, loss: 0.1299669
Epoch: 4, Batch: 100, loss: 0.1079900
Batch: 16, loss: 0.1688834#Axon.ModelState<
  Parameters: 124441420 (497.77 MB)
  Trainable Parameters: 124441346 (497.77 MB)
  Trainable State: 74, (296 B)
>File.cd(__DIR__)
# Overwrites "classifier_gpt.axon" with the latest state. The Nx.serialize/1
# output is wrapped with :erlang.term_to_binary/1, so loading must apply
# :erlang.binary_to_term/1 first and Nx.deserialize/1 second.
model_bin = Nx.serialize(trained_model_state)
model_erlang_binary = :erlang.term_to_binary(model_bin)
# Relative path: resolved against the current working directory.
File.write!("classifier_gpt.axon", model_erlang_binary)
#Nx.deserialize(model_bin):ok{input, expected} = Enum.at(test_dataset, 10)
res = predict_fn.(trained_model_state, input)
res =
res[[.., -1]]
|> Axon.Activations.softmax(axis: -1) 
|> Nx.argmax(axis: -1)
|> Nx.new_axis(1)
{res, expected}{#Nx.Tensor<
   s32[8][1]
   EXLA.Backend
   [
     [0],
     [0],
     [0],
     [0],
     [0],
     [1],
     [1],
     [0]
   ]
 >,
 #Nx.Tensor<
   s32[8][1]
   EXLA.Backend
   [
     [0],
     [0],
     [0],
     [1],
     [0],
     [1],
     [1],
     [0]
   ]
 >}classifier_accuracy_fn = fn(y_true, prediction) -> 
  y_pred =
    prediction[[.., -1]]
    |> Axon.Activations.softmax(axis: -1) 
    |> Nx.argmax(axis: -1)
    |> Nx.new_axis(1)
  Axon.Metrics.accuracy(y_true, y_pred)
end#Function<41.105768164/2 in :erl_eval.expr/6>{test, value} = Enum.at(test_dataset, 10)
res = predict_fn.(trained_model_state, test)#Nx.Tensor<
  f32[8][205][2]
  EXLA.Backend
  [
    [
      [-0.39963483810424805, 2.7791032791137695],
      [-0.41462546586990356, 3.319265127182007],
      [0.046850644052028656, 2.7746663093566895],
      [0.9752874374389648, 1.9637089967727661],
      [1.0571796894073486, 1.8108587265014648],
      [0.7696126103401184, 1.2826043367385864],
      [1.2491906881332397, 0.1654994636774063],
      [0.8033085465431213, 2.034764528274536],
      [0.9411128163337708, 1.8336278200149536],
      [1.1351237297058105, 1.302860140800476],
      [0.77309650182724, 0.5626922845840454],
      [1.7103430032730103, 1.4338515996932983],
      [1.5167274475097656, 1.2508774995803833],
      [0.9470923542976379, 2.6235716342926025],
      [1.192779302597046, 1.0270487070083618],
      [1.3506262302398682, 0.07253775745630264],
      [1.2538689374923706, 1.1279950141906738],
      [1.4230097532272339, 0.9065091013908386],
      [0.41239750385284424, 3.0621018409729004],
      [1.9388291835784912, 0.23360873758792877],
      [0.9791933298110962, 0.7733169198036194],
      [1.378157138824463, 0.9789188504219055],
      [1.6392107009887695, 0.20650364458560944],
      [1.8036633729934692, 0.4553448259830475],
      [1.094157099723816, 0.5027266144752502],
      ...
    ],
    ...
  ]
>Axon.Metrics.accuracy(value, Nx.tensor([[0], [0], [0], [0], [0], [1], [1], [0]]))#Nx.Tensor<
  f32
  EXLA.Backend
  0.875
>classifier_accuracy_fn.(value, res)#Nx.Tensor<
  f32
  EXLA.Backend
  0.875
># Test new parameters
# Evaluation-only loop: reports `classifier_accuracy_fn` under the name
# "GPT Classifier accuracy" over `test_dataset`, using `trained_model_state`.
classifier_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(classifier_accuracy_fn, "GPT Classifier accuracy")
|> Axon.Loop.run(test_dataset, trained_model_state, compiler: EXLA)
14:30:53.927 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 32, GPT Classifier accuracy: 0.9545454%{
  0 => %{
    "GPT Classifier accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.9545454382896423
    >
  }
}# Validation accuracy
# Same evaluation loop as for the test set, but over `validation_dataset`,
# still using `trained_model_state`.
classifier_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(classifier_accuracy_fn, "GPT Classifier accuracy")
|> Axon.Loop.run(validation_dataset, trained_model_state, compiler: EXLA)
15:44:18.846 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 16, GPT Classifier accuracy: 0.9705882%{
  0 => %{
    "GPT Classifier accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.970588207244873
    >
  }
}Typically, the validation set accuracy is somewhat higher than the test set accuracy because the model development often involves tuning hyperparameters to perform well on the validation set, which might not generalize as effectively to the test set.
This situation is common, but the gap could potentially be minimized by adjusting the model’s settings, such as increasing the dropout rate or the weight decay (the `decay` option in the optimizer configuration).
# Test old parameters: baseline accuracy over `test_dataset` with
# `classifier_params` (the state the first training run started from), for
# comparison with the accuracy obtained with `trained_model_state`.
classifier_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(classifier_accuracy_fn, "GPT Classifier accuracy")
|> Axon.Loop.run(test_dataset, classifier_params, compiler: EXLA)
14:44:05.138 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 32, GPT Classifier accuracy: 0.5037879%{
  0 => %{
    "GPT Classifier accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.5037878751754761
    >
  }
}6.8 Using the LLM as a spam classifier
File.cd(__DIR__)
{:ok, model_bin} = File.read("classifier_gpt.axon") 
classifier_parameters = 
  model_bin
  |> :erlang.binary_to_term()
  |> Nx.deserialize()#Axon.ModelState<
  Parameters: 124441420 (497.77 MB)
  Trainable Parameters: 124441346 (497.77 MB)
  Trainable State: 74, (296 B)
>defmodule MyGPT.Classifier do
  @doc """
  Classifies `text` as `"spam"` or `"not spam"`.

  `model` is a 2-arity prediction function invoked as
  `model.(model_parameters, %{"input_ids" => ids})`, where `ids` is the
  tokenized text right-padded with the pad token to the target length.
  Class `1` maps to `"spam"` and class `0` to `"not spam"`.

  ## Options

    * `:tokenizer` (required) - tokenizer given to `Bumblebee.apply_tokenizer/2`.
    * `:pad_token_id` - token id used for padding. Defaults to `50256`.
    * `:max_length` - the padded length is `max_length + 1`; when omitted (or
      `nil`) it is the tokenized input length plus one.

  Raises `KeyError` when `:tokenizer` is not given.
  """
  @spec classify_text(function(), term(), String.t(), keyword()) :: String.t()
  def classify_text(model, model_parameters, text, opt \\ []) when is_function(model, 2) do
    # Fail fast with a clear KeyError instead of handing `nil` to Bumblebee.
    tokenizer = Keyword.fetch!(opt, :tokenizer)
    pad_token_id = Keyword.get(opt, :pad_token_id, 50256)

    tokenizer_output = Bumblebee.apply_tokenizer(tokenizer, text)
    max_length = compute_max_length(tokenizer_output, Keyword.get(opt, :max_length))

    # None of these steps returns a tagged tuple, so a plain pipeline replaces
    # the former `with` whose clauses could never fail to match.
    tokenizer_output
    |> pad_sample(pad_token_id, max_length)
    |> then(&model.(model_parameters, &1))
    |> postprocessing()
    |> parse_classifier()
  end
  # Target padded length: the requested maximum (or, when none was requested,
  # the tokenized input length) plus one.
  defp compute_max_length(tokenizer_output, max_length_arg) do
    base_length =
      case max_length_arg do
        nil -> get_input_length(tokenizer_output)
        requested -> requested
      end

    base_length + 1
  end
  # Number of elements in the tokenizer's "input_ids" tensor.
  # Raises KeyError when the key is absent.
  defp get_input_length(tokenizer_output) do
    input_ids = Map.fetch!(tokenizer_output, "input_ids")
    Nx.size(input_ids)
  end
  # Right-pads the "input_ids" tensor with `pad_token_id` up to `max_length` and
  # wraps it as the %{"input_ids" => tensor} map that is fed to the model.
  defp pad_sample(tokenizer_output, pad_token_id, max_length) do
    # Reuse the shared helper so the length is computed, and a missing key is
    # reported, in exactly one way.
    length_diff = max_length - get_input_length(tokenizer_output)

    padded =
      tokenizer_output
      |> Map.fetch!("input_ids")
      # One {low, high, interior} entry per axis: axis 0 is left untouched,
      # `length_diff` pad tokens are appended at the end of axis 1.
      |> Nx.pad(pad_token_id, [{0, 0, 0}, {0, length_diff, 0}])
      |> Nx.reshape({1, max_length})

    %{"input_ids" => padded}
  end
  # Converts the model output (logits) into the predicted class id as an integer.
  defp postprocessing(prediction) do
    # Only the last sequence position is scored. Softmax is skipped on purpose:
    # it is monotonic, so argmax over the raw logits picks the same class.
    class_ids = Nx.argmax(prediction[[.., -1]], axis: -1)

    # The input is a single-sample batch, so the first entry is the prediction.
    Nx.to_number(class_ids[0])
  end
  # Maps the predicted class id to its human-readable label.
  defp parse_classifier(0) do
    "not spam"
  end

  defp parse_classifier(1) do
    "spam"
  end
end
alias MyGPT.ClassifierMyGPT.Classifiertext = "You are a winner you have been specially selected to receive $1000 cash or a $2000 award."
Classifier.classify_text(predict_fn, classifier_parameters, text, tokenizer: tokenizer)"spam"text = "Hey, just wanted to check if we're still on for dinner tonight? Let me know!"
Classifier.classify_text(predict_fn, classifier_parameters, text, tokenizer: tokenizer)"not spam"Summary
- There are different strategies for fine-tuning LLMs, including classification fine-tuning and instruction fine-tuning.
- Classification fine-tuning involves replacing the output layer of an LLM with a small classification layer.
- In the case of classifying text messages as “spam” or “not spam”, the new classification layer consists of only two output nodes. Previously, the number of output nodes was equal to the number of unique tokens in the vocabulary (i.e., 50257).
- Instead of predicting the next token in the text as in pretraining, classification fine-tuning trains the model to output a correct class label—for example, “spam” or “not spam”.
- The model input for fine-tuning is text converted into token IDs, similar to pretraining.
- Before fine-tuning an LLM, we load the pretrained model as a base model.
- Fine-tuning a classification model uses the same cross entropy loss function as when pretraining the LLM.