Chapter 7: Fine-tuning to follow instructions
Mix.install([
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:axon, "~> 0.5"},
{:table_rex, "~> 3.1.1"},
{:bumblebee, "~> 0.6.0"},
{:explorer, "~> 0.7.1"},
{:req, "~> 0.4.5"},
{:kino_vega_lite, "~> 0.1.11"},
{:httpoison, "~> 2.0"}
])
Nx.global_default_backend(EXLA.Backend)
Introduction
{:ok, gpt2} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})
serving = Bumblebee.Text.generation(gpt2, tokenizer, generation_config)
text_input = Kino.Input.text("Text", default: "Yesterday, I was reading a book and")
text = Kino.Input.read(text_input)
Nx.Serving.run(serving, text)
%{
results: [
%{
text: " I was thinking, \"What's going on here?\" I was thinking, \"What's going on",
token_summary: %{input: 8, output: 20, padding: 0}
}
]
}
%{model: model, params: params} = gpt2
tokenizer =
Bumblebee.configure(tokenizer,
length: nil,
pad_direction: :left,
return_token_type_ids: false,
return_length: true
)
input = Bumblebee.apply_tokenizer(tokenizer, "I want to")
gpt2_model = Axon.nx(model, & &1.logits)
{_init_fn, predict_fn} = Axon.build(gpt2_model)
result = predict_fn.(params, input)
#Nx.Tensor<
f32[1][3][50257]
EXLA.Backend
[
[
[-39.308448791503906, -39.010066986083984, -41.837467193603516, -41.781246185302734, -40.84248352050781, -40.89142990112305, -38.62623596191406, -40.154056549072266, -38.097896575927734, -41.04249954223633, -40.9429931640625, -36.262168884277344, -37.39033889770508, -36.03800964355469, -38.52249526977539, -40.54604721069336, -39.718971252441406, -39.7431640625, -40.27290344238281, -40.314857482910156, -40.54868698120117, -41.00197219848633, -40.9098014831543, -40.914119720458984, -41.297733306884766, -37.69235610961914, -39.106632232666016, -41.460182189941406, -40.526241302490234, -40.43655014038086, -38.97370147705078, -41.32615661621094, -39.90999984741211, -40.565555572509766, -40.7227897644043, -40.8016471862793, -40.875083923339844, -40.86553955078125, -40.39710998535156, -40.221649169921875, -38.78817367553711, -40.58393096923828, -40.43303298950195, -40.767242431640625, -40.72999572753906, -40.78556442260742, -40.461753845214844, -41.084720611572266, -41.600372314453125, -41.25688552856445, ...],
...
]
]
>
7.1 Introduction to instruction fine-tuning
Pre-trained LLMs are trained to continue text, so they often struggle to follow specific instructions, as the GPT-2 completion above shows. Here, we focus on improving the LLM’s ability to follow such instructions and generate a desired response. Preparing the dataset is a key aspect of instruction fine-tuning.
7.2 Preparing a dataset for supervised instruction fine-tuning
File.cd!(__DIR__)
instructions_data = File.read!("instruction-response.json")
{:ok, dataset} = Jason.decode(instructions_data)
{:ok,
[
%{
"input" => "freind --> friend",
"instruction" => "Evaluate the following phrase by transforming it into the spelling given.",
"output" => "The spelling of the given phrase \"freind\" is incorrect, the correct spelling is \"friend\"."
},
%{
"input" => "He go to the park every day.",
"instruction" => "Edit the following sentence for grammar.",
"output" => "He goes to the park every day."
},
%{
"input" => "",
"instruction" => "Convert 45 kilometers to meters.",
"output" => "45 kilometers is 45000 meters."
},
%{
"input" => "",
"instruction" => "Rewrite this sentence to start with 'Although': Despite the rain, they went for a walk.",
"output" => "Although it was raining, they went for a walk."
},
%{
"input" => "",
"instruction" => "What are the first 10 square numbers?",
"output" => "1, 4, 9, 16, 25, 36, 49, 64, 81, 100."
},
%{
"input" => "",
"instruction" => "Suggest a more formal synonym for \"happy.\"",
"output" => "A more formal synonym for \"happy\" is \"content.\""
},
%{
"input" => "Where is the nearest restaurant?",
"instruction" => "Translate the following sentence into French.",
"output" => "Où est le restaurant le plus proche?"
},
%{
"input" => "",
"instruction" => "Look up the melting point of iron.",
"output" => "The melting point of iron is approximately 1538 degrees Celsius or 2800 degrees Fahrenheit."
},
%{
"input" => "The cake was baked by Sarah.",
"instruction" => "Rewrite the following sentence so that it is in active voice.",
"output" => "Sarah baked the cake."
},
%{
"input" => "",
"instruction" => "Convert the active sentence to passive: 'The manager approved the project.'",
"output" => "The project was approved by the manager."
},
%{
"input" => "",
"instruction" => "What is the contraction for \"will not\"?",
"output" => "The contraction for \"will not\" is \"won't\"."
},
%{
"input" => "The actual fact is that he was late.",
"instruction" => "Rewrite the following sentence to remove redundancy.",
"output" => "The fact is that he was late."
},
%{
"input" => "",
"instruction" => "What is the capital of Indonesia?",
"output" => "The capital of Indonesia is Jakarta."
},
%{
"input" => "",
"instruction" => "Provide a synonym for 'bright'.",
"output" => "A synonym for 'bright' is 'radiant'."
},
%{
"input" => "",
"instruction" => "Suggest an alternate word for 'happy'.",
"output" => "joyful."
},
%{
"input" => "",
"instruction" => "Name the device that measures atmospheric pressure.",
"output" => "The device that measures atmospheric pressure is a barometer."
},
%{
"input" => "",
"instruction" => "List two antonyms for 'intelligent.'",
"output" => "Two antonyms for 'intelligent' are 'stupid' and 'dumb.'"
},
%{
"input" => "",
"instruction" => "What is the pH of a neutral solution at 25°C?",
"output" => "The pH of a neutral solution at 25°C is 7."
},
%{
"input" => "'Her eyes are like stars.'",
"instruction" => "Rewrite the following sentence to use a metaphor",
"output" => "Her eyes are stars in the night sky."
},
%{
"input" => "",
"instruction" => "Name three types of biomes.",
"output" => "Forest, desert, and grassland."
},
%{
"input" => "",
"instruction" => "Identify the climax in 'Romeo and Juliet'.",
"output" => "The climax of 'Romeo and Juliet' occurs when Romeo kills Tybalt, leading to his banishment and the subsequent tragic events."
},
%{
"input" => "The dog is very loyal.",
"instruction" => "Rewrite the sentence using a simile.",
"output" => "The dog is as loyal as a soldier."
},
%{
"input" => "The storm started suddenly.",
"instruction" => "Rewrite the sentence using an idiom.",
"output" => "The storm came out of the blue."
},
%{
"input" => "",
"instruction" => "Convert 10 feet to meters.",
"output" => "10 feet is equal to approximately 3.048 meters."
},
%{
"input" => "The chef cooked a delicious meal.",
"instruction" => "Convert this sentence to passive voice",
"output" => "A delicious meal was cooked by the chef."
},
%{
"input" => "",
"instruction" => "What is the plural form of 'mouse'?",
"output" => "The plural form of 'mouse' is 'mice'."
},
%{
"input" => "",
"instruction" => "What is the opposite of 'thin'?",
"output" => "The opposite of 'thin' is 'thick'."
},
%{
"input" => "",
"instruction" => "What is the state capital of California?",
"output" => "The state capital of California is Sacramento."
},
%{
"input" => "",
"instruction" => "Translate 'family' into Italian.",
"output" => "The Italian word for 'family' is 'famiglia'."
},
%{
"input" => "",
"instruction" => "What is the past tense of 'draw'?",
"output" => "The past tense of 'draw' is 'drew.'"
},
%{
"input" => "The ship sails at dawn.",
"instruction" => "Alter the content of the sentence to use the past tense.",
"output" => "The ship sailed at dawn."
},
%{
"input" => "",
"instruction" => "Generate a sentence that follows the pattern: 'Rarely do I _____'",
"output" => "Rarely do I eat fast food."
},
%{
"input" => "",
"instruction" => "What is the atomic number of oxygen?",
"output" => "The atomic number of oxygen is 8."
},
%{
"input" => "",
"instruction" => "What is a synonym of 'luminous'?",
"output" => "A synonym for 'luminous' is 'bright'."
},
%{
"input" => "",
"instruction" => "What is the capital of Brazil?",
"output" => "The capital of Brazil is Brasília."
},
%{
"input" => "",
"instruction" => "List the ingredients required to make a basic salad.",
"output" => "Ingredients for a basic salad typically include lettuce, tomatoes, cucumbers, and dressing."
},
%{
"input" => "It's hot outside.",
"instruction" => "Rewrite the sentence using a synonym.",
"output" => "It's scorching outside."
},
%{
"input" => "",
"instruction" => "Provide a synonym for 'happy'.",
"output" => "A synonym for 'happy' is 'content'."
},
%{
"input" => "",
"instruction" => "What is the derivative of x^3 + 5x^2 - 3x + 7 with respect to x?",
"output" => "3x^2 + 10x - 3."
},
%{
"input" => "She felt a sense of relief.",
"instruction" => "Identify the point of view used in the sentence.",
"output" => "The point of view used in the sentence is third person."
},
%{
"input" => "",
"instruction" => "Generate a sentence using the word 'gregarious'.",
"output" => "He was gregarious, making friends wherever he went."
},
%{
"input" => "",
"instruction" => "What is the atomic number of carbon?",
"output" => "The atomic number of carbon is 6."
},
%{
"input" => "",
"instruction" => "What is the formula for area of a triangle?",
"output" => "The formula for the area of a triangle is 1/2 * base * height."
},
%{
"input" => "Zebra, Elephant, Crocodile",
"instruction" => "Sort the following list in alphabetical order.",
"output" => "Crocodile, Elephant, Zebra"
},
%{
"input" => "Moon and stars",
"instruction" => "Reverse the order of the given phrase.",
"output" => "Stars and moon"
},
%{"input" => "", "instruction" => "What is a synonym for 'begin'?", ...},
%{"input" => "", ...},
%{...},
...
]}
data_length = Enum.count(dataset)
1100
Enum.at(dataset, 50)
%{
"input" => "Ocassion",
"instruction" => "Identify the correct spelling of the following word.",
"output" => "The correct spelling is 'Occasion.'"
}
Instruction fine-tuning involves training a model on a dataset in which the input-output pairs, like those we extracted from the JSON file, are provided explicitly. These entries can be formatted for an LLM in various ways; two common prompt styles are:
- Alpaca prompt style template.
The Alpaca style uses a structured format with defined sections for instruction, input, and response.
**Below is an instruction that describes a task. Write a response that appropriately completes the request.**

### Instruction:
Identify the correct spelling of the following word.

### Input:
Ocassion

### Response:
The correct spelling is 'Occasion.'
- Phi-3 prompt style template.
The Phi-3 style employs a simpler format with designated <|user|> and <|assistant|> tokens.
<|user|>
Identify the correct spelling of the following word: 'Ocassion'
<|assistant|>
The correct spelling is 'Occasion'.
Alpaca was one of the early LLMs to publicly detail its instruction fine-tuning process. Phi-3, developed by Microsoft, is included to demonstrate the diversity in prompt styles.
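To make the contrast with the Alpaca template concrete, here is a minimal sketch of a Phi-3 style formatter. The module name `PromptStyleSketch.Phi3` is hypothetical, and joining the instruction and the input with a newline is an assumption made for illustration; the rest of this chapter uses the Alpaca style.

```elixir
defmodule PromptStyleSketch.Phi3 do
  # Hypothetical module for illustration; not part of the MyGPT namespace.

  # Builds the user turn. With an empty "input", only the instruction is used.
  def build_instruction(%{"instruction" => instruction, "input" => ""}) do
    "<|user|>\n#{instruction}"
  end

  def build_instruction(%{"instruction" => instruction, "input" => input}) do
    # Assumption: instruction and input are joined with a newline.
    "<|user|>\n#{instruction}\n#{input}"
  end

  # Builds the assistant turn that follows the user turn.
  def build_response(%{"output" => output}), do: "\n<|assistant|>\n#{output}"
end

entry = %{
  "instruction" => "Identify the correct spelling of the following word.",
  "input" => "Ocassion",
  "output" => "The correct spelling is 'Occasion.'"
}

phi3_prompt =
  PromptStyleSketch.Phi3.build_instruction(entry) <>
    PromptStyleSketch.Phi3.build_response(entry)

IO.puts(phi3_prompt)
```

Because the Phi-3 style has no fixed preamble, each formatted entry is shorter than its Alpaca counterpart, which may reduce the number of tokens per training example.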
defmodule MyGPT.Assistant.Formatter do
  # Formats the instruction part of an entry in the Alpaca prompt style.
  # The "### Input:" section is appended only when the entry has a non-empty input.
  def build_instruction(json) do
    instruction =
      """
      Below is an instruction that describes a task. Write a response that appropriately completes the request.

      ### Instruction:
      #{json["instruction"]}\
      """

    instruction <> build_input(json)
  end

  # Formats the expected answer as the "### Response:" section.
  def build_response(json), do: "\n\n### Response:\n#{json["output"]}"

  defp build_input(%{"input" => ""}), do: ""

  defp build_input(%{"input" => input_text}) do
    "\n\n### Input:\n#{input_text}"
  end
end
alias MyGPT.Assistant.Formatter
MyGPT.Assistant.Formatter
prompt =
dataset
|> Enum.at(50)
|> Formatter.build_instruction()
IO.puts(prompt)
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Identify the correct spelling of the following word.

### Input:
Ocassion
:ok
response =
dataset
|> Enum.at(50)
|> Formatter.build_response()
IO.puts(response)
### Response:
The correct spelling is 'Occasion.'
:ok
(prompt <> response) |> IO.puts
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Identify the correct spelling of the following word.

### Input:
Ocassion

### Response:
The correct spelling is 'Occasion.'
:ok
prompt =
dataset
|> Enum.at(999)
|> Formatter.build_instruction()
response =
dataset
|> Enum.at(999)
|> Formatter.build_response()
(prompt <> response) |> IO.puts
Enum.at(dataset, 999)
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is an antonym of 'complicated'?

### Response:
An antonym of 'complicated' is 'simple'.
%{
"input" => "",
"instruction" => "What is an antonym of 'complicated'?",
"output" => "An antonym of 'complicated' is 'simple'."
}
train_length = floor(data_length * 0.85)
test_length = floor(data_length * 0.1)
dataset = Enum.shuffle(dataset)
{train_dataset, dataset_tail} = Enum.split(dataset, train_length)
{test_dataset, validation_dataset} = Enum.split(dataset_tail, test_length)
IO.puts("Training set length: #{Enum.count(train_dataset)}")
IO.puts("Validation set length: #{Enum.count(validation_dataset)}")
IO.puts("Test set length: #{Enum.count(test_dataset)}")
Training set length: 935
Validation set length: 55
Test set length: 110
:ok
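The split above relies on `Enum.shuffle/1`, which draws from the current process's `:rand` state, so the three subsets differ from run to run. The following is a minimal sketch of a reproducible variant, assuming a stand-in list of integers instead of the decoded JSON entries and an arbitrary seed value; neither is part of the original notebook.

```elixir
# Stand-in for the decoded JSON entries (the real dataset has 1100 maps).
entries = Enum.to_list(1..1100)

# Assumption: any fixed seed works; this particular value is arbitrary.
:rand.seed(:exsss, {7, 7, 7})

shuffled = Enum.shuffle(entries)
total = length(shuffled)
train_count = floor(total * 0.85)
test_count = floor(total * 0.1)

# The first 85% go to training, the next 10% to testing, the rest to validation.
{train, rest} = Enum.split(shuffled, train_count)
{test, validation} = Enum.split(rest, test_count)

IO.inspect({length(train), length(test), length(validation)})
```

Seeding once before the shuffle keeps the subset sizes unchanged and makes the membership of each subset repeatable, which can help when comparing losses across training runs.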
7.3 Organizing data into training batches
A collate function is responsible for taking a list of individual data samples and merging them into a single batch that can be processed efficiently by the model during training.
defmodule MyGPT.Assistant.Dataset do
  import MyGPT.Assistant.Formatter, only: [build_instruction: 1, build_response: 1]

  # Builds a list of {input, label} batches from the decoded JSON entries.
  # Options: :max_length (fixed sequence length; defaults to the longest prompt
  # in each batch), :pad_token_id, :exclude_token_id and :batch_size.
  def build(json_dataset, tokenizer, opt \\ []) do
    max_length = Keyword.get(opt, :max_length)
    pad_token_id = Keyword.get(opt, :pad_token_id, 50256)
    exclude_token_id = Keyword.get(opt, :exclude_token_id, -100)
    batch_size = Keyword.get(opt, :batch_size, 2)

    for json_batch <- Enum.chunk_every(json_dataset, batch_size) do
      # Tokenize eagerly so that each prompt is encoded only once per batch.
      prompt_batch =
        Enum.map(json_batch, fn json ->
          json |> build_prompt() |> encode_prompt(tokenizer)
        end)

      max_length_in_batch =
        prompt_batch
        |> Enum.map(fn encoded_prompt -> Nx.size(encoded_prompt["input_ids"]) end)
        |> Enum.max()
        |> then(&(max_length || &1))

      input =
        prompt_batch
        |> Enum.map(&build_input(&1, pad_token_id, max_length_in_batch))
        |> Nx.stack()
        |> then(&(%{"input_ids" => &1}))

      label =
        prompt_batch
        |> Enum.map(&build_label(&1, pad_token_id, exclude_token_id, max_length_in_batch))
        |> Nx.stack()

      {input, label}
    end
  end

  def build_prompt(json), do: build_instruction(json) <> build_response(json)

  def encode_prompt(prompt, tokenizer), do: Bumblebee.apply_tokenizer(tokenizer, prompt)

  # Right-pads the token ids with pad_token_id up to max_length.
  def build_input(%{"input_ids" => encoded_prompt}, pad_token_id, max_length) do
    length_diff = max_length - Nx.size(encoded_prompt)

    encoded_prompt
    |> Nx.pad(pad_token_id, [{0, 0, 0}, {0, length_diff, 0}])
    |> Nx.reshape({max_length})
  end

  # Shifts the token ids one position to the left, appends one pad token as the
  # final target, and fills the remaining positions with exclude_token_id.
  def build_label(%{"input_ids" => encoded_prompt}, pad_token_id, exclude_token_id, max_length) do
    input_length = Nx.size(encoded_prompt)
    length_diff = max_length - input_length

    encoded_prompt[[.., 1..(input_length - 1)]]
    |> Nx.pad(pad_token_id, [{0, 0, 0}, {0, 1, 0}])
    |> Nx.pad(exclude_token_id, [{0, 0, 0}, {0, length_diff, 0}])
    |> Nx.reshape({max_length})
  end
end
alias MyGPT.Assistant.Dataset
MyGPT.Assistant.Dataset
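To see what `build_input/3` and `build_label/4` produce for a single sample, the following sketch reproduces the same padding and shifting on plain lists, without Nx. The module `CollateSketch` and the token ids are made up for illustration, and the sketch assumes the sample is no longer than `max_length`.

```elixir
defmodule CollateSketch do
  # Hypothetical helper that mirrors build_input/3 and build_label/4 on lists.

  # Right-pads the token ids with pad_token_id up to max_length.
  def pad_input(ids, pad_token_id, max_length) do
    ids ++ List.duplicate(pad_token_id, max_length - length(ids))
  end

  # Shifts the ids one step to the left, appends a single pad token as the
  # final target, and fills the remainder with exclude_token_id.
  def shift_label(ids, pad_token_id, exclude_token_id, max_length) do
    shifted = tl(ids) ++ [pad_token_id]
    shifted ++ List.duplicate(exclude_token_id, max_length - length(ids))
  end
end

ids = [0, 1, 2, 3, 4]
input = CollateSketch.pad_input(ids, 50256, 7)
label = CollateSketch.shift_label(ids, 50256, -100, 7)

IO.inspect(input, label: "input")
IO.inspect(label, label: "label")
```

Here `input` is `[0, 1, 2, 3, 4, 50256, 50256]` and `label` is `[1, 2, 3, 4, 50256, -100, -100]`: each label position holds the token that follows the corresponding input position, one pad token serves as the final target, and the remaining positions carry the `exclude_token_id` placeholder, which by its name is presumably meant to be skipped when the loss is computed.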
res = Dataset.build(train_dataset, tokenizer, max_length: 92)
[
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 13061, 1573, 1262, 366, 2070, 526, 198, 198, 21017, 18261, 25, 198, 32, 13061, 1573, 1262, 366, 2070, 1, 318, 366, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 13061, 1573, 1262, 366, 2070, 526, 198, 198, 21017, 18261, 25, 198, 32, 13061, 1573, 1262, 366, 2070, 1, 318, 366, 2070, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, 543, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8567, 510, 262, 24372, 966, 286, 28829, 13, 198, 198, 21017, 18261, 25, 198, 464, 24372, 966, 286, 28829, 318, 6702, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8567, 510, 262, 24372, 966, 286, 28829, 13, 198, 198, 21017, 18261, 25, 198, 464, 24372, 966, 286, 28829, 318, 6702, 8699, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 2347, 422, 37075, 284, 16379, 13, 198, 198, 21017, 23412, 25, 198, 19, 37075, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 2347, 422, 37075, 284, 16379, 13, 198, 198, 21017, 23412, 25, 198, 19, 37075, 198, 198, 21017, 18261, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 1542, 48829, 284, 10700, 13, 198, 198, 21017, 18261, 25, 198, 1270, 48829, 318, 657, 13, 18, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 1542, 48829, 284, 10700, 13, 198, 198, 21017, 18261, 25, 198, 1270, 48829, 318, 657, 13, 18, 10700, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 4215, 30960, 198, 198, 21017, 18261, 25, 198, 464, 6697, 286, 705, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 4215, 30960, 198, 198, 21017, 18261, 25, 198, 464, 6697, 286, 705, 4215, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 6827, 284, 779, 281, 35639, 28636, 13, 198, 198, 21017, 23412, 25, 198, 28211, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 6827, 284, 779, 281, 35639, 28636, 13, 198, 198, 21017, 23412, 25, 198, 28211, 1364, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 18438, 391, 262, 1429, 286, 33607, 13, 198, 198, 21017, 18261, 25, 198, 20575, 436, 295, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 18438, 391, 262, 1429, 286, 33607, 13, 198, 198, 21017, 18261, 25, 198, 20575, 436, 295, 318, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 281, 281, 1122, 4948, 286, 705, 4215, 30960, 198, 198, 21017, 18261, 25, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 281, 281, 1122, 4948, 286, 705, 4215, 30960, 198, 198, 21017, 18261, 25, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 262, 1708, 6827, 656, 4141, 13, 198, 198, 21017, 23412, 25, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 262, 1708, 6827, 656, 4141, 13, 198, 198, 21017, 23412, 25, 198, 8496, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 389, 262, 717, 838, 6616, 3146, 30, 198, 198, 21017, 18261, 25, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 389, 262, 717, 838, 6616, 3146, 30, 198, 198, 21017, 18261, 25, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 18378, 262, 1708, 6827, 284, 787, 340, 517, 8766, 13, 198, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 18378, 262, 1708, 6827, 284, 787, 340, 517, 8766, 13, 198, 198, 21017, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 3139, 286, 2520, 4969, 30, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 3139, 286, 2520, 4969, 30, 198, 198, 21017, 18261, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 281, 281, 1122, 4948, 329, 705, 14261, 2637, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 281, 281, 1122, 4948, 329, 705, 14261, 2637, 198, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, 10248, 1755, 6, 656, 2679, 13, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, 10248, 1755, 6, 656, 2679, 13, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 22918, 1241, 286, 5899, 1660, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 22918, 1241, 286, 5899, 1660, 30, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 44402, 257, 11080, 43441, 284, 1844, 262, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 44402, 257, 11080, 43441, 284, 1844, 262, 6827, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16742, 257, 6171, 5177, 329, 262, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16742, 257, 6171, 5177, 329, 262, 1573, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 4387, 9151, 319, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 4387, 9151, 319, 3668, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24372, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24372, 966, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 1708, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
...
]
>},
{%{...}, ...},
{...},
...
]
{input, label} = Enum.at(res, 1)
# With Nx.Batch.stack
#input = Nx.Defn.jit_apply(&Function.identity/1, [input])
#label = Nx.Defn.jit_apply(&Function.identity/1, [label])
IO.inspect({input, label}, limit: :infinity)
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 318, 1444, 1007, 10514, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, 286, 705, 417, 515, 30960, 198, 198, 21017, 18261, 25, 198, 32, 6171, 5177, 329, 705, 417, 515, 6, 318, 705, 2502, 2633, 276, 4458, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 318, 1444, 1007, 10514, 13, 50256, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, 286, 705, 417, 515, 30960, 198, 198, 21017, 18261, 25, 198, 32, 6171, 5177, 329, 705, 417, 515, 6, 318, 705, 2502, 2633, 276, 4458, 50256, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]
]
>}
{%{
"input_ids" => #Nx.Tensor<
u32[2][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, 543, 6134, ...],
...
]
>
},
#Nx.Tensor<
s64[2][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 262, 1429, 416, 543, 6134, 4425, 1660, 20199, 832, 511, 5667, 13, 198, 198, 21017, 18261, 25, 198, 464, 1429, 416, 543, 6134, 4425, ...],
...
]
>}
We assign a -100 placeholder value to all padding tokens. This special value allows us to exclude these padding tokens from contributing to the training loss calculation, ensuring that only meaningful data influences model learning.
However, we retain one end-of-text token, ID 50256. Retaining it allows the LLM to learn when to generate an end-of-text token in response to instructions, which we use as an indicator that the generated response is complete.
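A minimal sketch of how such targets could be derived from a single token list, using plain Elixir lists rather than the notebook's Dataset module (whose actual implementation may differ). The names TargetSketch and build/2 are hypothetical. The sketch assumes the sequence is no longer than max_length and that 50256 doubles as the padding token.

```elixir
defmodule TargetSketch do
  # Hypothetical helper, not part of the notebook's Dataset module.
  @pad_id 50_256
  @ignore_index -100

  # Returns {inputs, targets}, both of length `max_length`.
  # Assumes length(ids) <= max_length.
  def build(ids, max_length) do
    # Pad to max_length + 1 so that inputs and shifted targets have equal length.
    padded = ids ++ List.duplicate(@pad_id, max_length + 1 - length(ids))
    inputs = Enum.take(padded, max_length)
    # Targets are the inputs shifted one position to the left.
    shifted = Enum.drop(padded, 1)

    # Keep the first end-of-text token and replace every later one with -100.
    {targets, _seen_eos?} =
      Enum.map_reduce(shifted, false, fn
        @pad_id, false -> {@pad_id, true}
        @pad_id, true -> {@ignore_index, true}
        id, seen_eos? -> {id, seen_eos?}
      end)

    {inputs, targets}
  end
end
```

For example, TargetSketch.build([1, 2, 3], 5) returns {[1, 2, 3, 50256, 50256], [2, 3, 50256, -100, -100]}: the inputs are padded with 50256, while the targets keep only the first 50256.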
The default setting of the cross entropy function in PyTorch is cross_entropy(…, ignore_index=-100). This means that it ignores targets labeled with -100. We take advantage of this ignore_index to ignore the additional end-of-text (padding) tokens that we used to pad the training examples to have the same length in each batch.
In addition to masking out padding tokens, it is also common to mask out the target token IDs that correspond to the instruction. With the instruction targets masked, the cross entropy loss is computed only over the response target IDs. Thus, the model is trained to focus on generating accurate responses rather than memorizing instructions, which can help reduce overfitting.
As of this writing, researchers are divided on whether masking the instructions is universally beneficial during instruction fine-tuning.
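One way to reproduce the effect of ignore_index in Nx is to build the mask explicitly. The following is only a sketch, assuming Nx is available as in the notebook's setup cell. MaskedLossSketch and cross_entropy/3 are hypothetical names and not part of Nx or Axon. It expects integer targets of shape {..., n} and logits of shape {..., n, vocab}.

```elixir
defmodule MaskedLossSketch do
  # Hypothetical helper: mean cross entropy over the positions whose
  # target differs from `ignore_index`.
  def cross_entropy(targets, logits, ignore_index \\ -100) do
    # 1 where the target counts towards the loss, 0 where it is ignored.
    mask = Nx.not_equal(targets, ignore_index)
    # Swap ignored targets for a valid index so the gather stays in bounds.
    safe_targets = Nx.select(mask, targets, 0)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = Nx.subtract(logits, Nx.reduce_max(logits, axes: [-1], keep_axes: true))
    log_sum = Nx.log(Nx.sum(Nx.exp(shifted), axes: [-1], keep_axes: true))
    log_probs = Nx.subtract(shifted, log_sum)

    # Log-probability assigned to each target token.
    picked =
      log_probs
      |> Nx.take_along_axis(Nx.new_axis(safe_targets, -1), axis: -1)
      |> Nx.squeeze(axes: [-1])

    # Average over the kept positions only; guard against an all-ignored batch.
    kept = Nx.max(Nx.sum(mask), 1)

    picked
    |> Nx.multiply(mask)
    |> Nx.sum()
    |> Nx.negate()
    |> Nx.divide(kept)
  end
end
```

The same function also covers instruction masking: setting the instruction positions of the targets to -100 excludes them from the loss in exactly the same way as the padding positions.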
7.4 Creating data loaders for an instruction dataset
train_stream = Dataset.build(train_dataset, tokenizer, batch_size: 5, max_length: 92)
validation_stream = Dataset.build(validation_dataset, tokenizer, batch_size: 5, max_length: 92)
test_stream = Dataset.build(test_dataset, tokenizer, batch_size: 5, max_length: 92)
[
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 10451, 329, 23202, 34186, 284, 35935, 30, 198, 198, 21017, 18261, 25, 198, 464, 10451, 329, 23202, 34186, 284, 35935, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 10451, 329, 23202, 34186, 284, 35935, 30, 198, 198, 21017, 18261, 25, 198, 464, 10451, 329, 23202, 34186, 284, 35935, 318, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 5931, 10451, 329, 25006, 30, 198, 198, 21017, 18261, 25, 198, 464, 5931, 10451, 329, 25006, 318, 5870, 19, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 5931, 10451, 329, 25006, 30, 198, 198, 21017, 18261, 25, 198, 464, 5931, 10451, 329, 25006, 318, 5870, 19, 13, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, 329, 705, 562, 396, 30960, 198, 198, 21017, 18261, 25, 198, 32, 6171, 5177, 329, 705, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, 329, 705, 562, 396, 30960, 198, 198, 21017, 18261, 25, 198, 32, 6171, 5177, 329, 705, 562, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 257, 5337, 3194, 416, 12091, 2517, 268, 13, 198, 198, 21017, 18261, 25, 198, 3198, 286, 262, 16122, 3194, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 257, 5337, 3194, 416, 12091, 2517, 268, 13, 198, 198, 21017, 18261, 25, 198, 3198, 286, 262, 16122, 3194, 416, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 428, 6827, 284, 11005, 262, 14513, 3809, 13, 198, 198, 21017, 23412, 25, 198, 464, 12187, 373, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 428, 6827, 284, 11005, 262, 14513, 3809, 13, 198, 198, 21017, 23412, 25, 198, 464, 12187, 373, 925, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 4075, 6827, 284, 14513, 25, 705, 16980, 544, 12542, 262, 2613, 2637, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 3103, 1851, 262, 4075, 6827, 284, 14513, 25, 705, 16980, 544, 12542, 262, 2613, 2637, 198, 198, 21017, 18261, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 12187, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 30003, 6525, 262, 6827, 1262, 257, 985, 576, 13, 198, 198, 21017, 23412, 25, 198, 464, 12187, 318, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, 40, 1842, 345, 6, 656, 7897, 13, 198, 198, 21017, 18261, 25, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, 40, 1842, 345, 6, 656, 7897, 13, 198, 198, 21017, 18261, 25, 198, 6767, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 7527, 30960, 198, 198, 21017, 18261, 25, 198, 464, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 7527, 30960, 198, 198, 21017, 18261, 25, 198, 464, 6697, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 257, 21025, 2288, 1317, 973, 284, 7603, 257, 1808, 13, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 5376, 257, 21025, 2288, 1317, 973, 284, 7603, 257, 1808, 13, 198, 198, 21017, 18261, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 24988, 589, 30960, 198, 198, 21017, 18261, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 6697, 286, 705, 24988, 589, 30960, 198, 198, 21017, 18261, 25, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 6827, 1262, 262, 1573, 705, 2545, 1010, 7277, 4458, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 6827, 1262, 262, 1573, 705, 2545, 1010, 7277, 4458, 198, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 42779, 262, 21025, 2288, 287, 262, 6827, 13, 198, 198, 21017, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 42779, 262, 21025, 2288, 287, 262, 6827, 13, 198, 198, 21017, 23412, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24372, 966, 286, 27394, 287, 34186, 30, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24372, 966, 286, 27394, 287, 34186, 30, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 7469, 500, 262, 3721, 286, 705, 46453, 4458, 198, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 7469, 500, 262, 3721, 286, 705, 46453, 4458, 198, 198, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 2866, 286, 2128, 287, 1633, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 2866, 286, 2128, 287, 1633, 30, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 42758, 262, 1708, 1351, 287, 24830, 605, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 42758, 262, 1708, 1351, 287, 24830, 605, 1502, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24203, 966, 286, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, 24203, 966, 286, 6953, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 257, 6171, 5177, 286, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 34, 47467, 1096, 262, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 34, 47467, 1096, 262, 1708, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 8291, 17660, 705, 40, ...],
...
]
>},
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 2061, 318, 262, ...],
...
]
>}
]
{input, label} = Enum.at(train_stream, 0)
#input = Nx.Defn.jit_apply(&Function.identity/1, [input])
#label = Nx.Defn.jit_apply(&Function.identity/1, [label])
{input, label}
{%{
"input_ids" => #Nx.Tensor<
u32[5][92]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 13061, 1573, 1262, 366, 2070, 526, 198, 198, 21017, 18261, 25, 198, 32, 13061, 1573, 1262, 366, 2070, 1, 318, 366, 2070, ...],
...
]
>
},
#Nx.Tensor<
s64[5][92]
EXLA.Backend
[
[318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 16447, 257, 13061, 1573, 1262, 366, 2070, 526, 198, 198, 21017, 18261, 25, 198, 32, 13061, 1573, 1262, 366, 2070, 1, 318, 366, 2070, 7091, ...],
...
]
>}
7.5 Loading a pretrained LLM
Instead of using the smallest 124-million-parameter model as before, we load the medium-sized model with 355 million parameters. The reason for this choice is that the 124-million-parameter model is too limited in capacity to achieve satisfactory results via instruction fine-tuning.
Specifically, smaller models lack the necessary capacity to learn and retain the intricate patterns and nuanced behaviors required for high-quality instruction-following tasks.
{:ok, gpt2_med} = Bumblebee.load_model({:hf, "openai-community/gpt2-medium"})
{:ok, tokenizer_med} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2-medium"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2-medium"})
serving = Bumblebee.Text.generation(gpt2_med, tokenizer_med, generation_config)
text_input = Kino.Input.text("Text", default: "Yesterday, I was reading a book and")
text = Kino.Input.read(text_input)
Nx.Serving.run(serving, text)
%{
results: [
%{
text: " I was thinking, \"I wonder if this is the same person who wrote the book about the '",
token_summary: %{input: 8, output: 20, padding: 0}
}
]
}
tokenizer_med
%Bumblebee.Text.PreTrainedTokenizer{
native_tokenizer: #Tokenizers.Tokenizer<[
vocab_size: 50257,
byte_fallback: false,
continuing_subword_prefix: "",
dropout: nil,
end_of_word_suffix: "",
fuse_unk: false,
model_type: "bpe",
unk_token: nil
]>,
type: :gpt2,
special_tokens: %{bos: "<|endoftext|>", eos: "<|endoftext|>", unk: "<|endoftext|>"},
additional_special_tokens: [],
add_special_tokens: true,
length: nil,
pad_direction: :right,
truncate_direction: :right,
return_attention_mask: true,
return_token_type_ids: true,
return_special_tokens_mask: false,
return_offsets: false,
return_length: false,
template_options: []
}
%{model: model_med, params: params_med} = gpt2_med
tokenizer =
Bumblebee.configure(tokenizer_med,
length: nil,
pad_direction: :left,
return_token_type_ids: false,
return_length: true
)
input = Bumblebee.apply_tokenizer(tokenizer, "I want to")
gpt2_model_med = Axon.nx(model_med, & &1.logits)
{_init_fn, predict_fn_med} = Axon.build(gpt2_model_med)
result = predict_fn_med.(params_med, input)
#Nx.Tensor<
f32[1][3][50257]
EXLA.Backend
[
[
[-78.32323455810547, -77.60397338867188, -80.8931655883789, -81.26132202148438, -81.70814514160156, -78.69278717041016, -75.68329620361328, -80.34996032714844, -76.3667221069336, -79.60734558105469, -79.98731231689453, -75.5054702758789, -76.88710021972656, -76.00344848632812, -77.07390594482422, -81.04473876953125, -80.06044006347656, -79.47236633300781, -79.81079864501953, -80.26091003417969, -79.70281982421875, -81.00440216064453, -80.54666900634766, -81.11274719238281, -81.23478698730469, -77.00039672851562, -79.41563415527344, -81.40238189697266, -79.31681060791016, -80.09330749511719, -77.53793334960938, -79.27008056640625, -79.0863037109375, -79.61581420898438, -80.37324523925781, -79.28718566894531, -80.15486145019531, -80.31566619873047, -79.55355072021484, -78.17181396484375, -77.77738952636719, -78.36002349853516, -78.64955139160156, -79.65756225585938, -79.86961364746094, -78.8159408569336, -78.95235443115234, -80.20915985107422, -79.78424835205078, -80.28970336914062, ...],
...
]
]
>
gpt2_med
%{
spec: %Bumblebee.Text.Gpt2{
architecture: :for_causal_language_modeling,
vocab_size: 50257,
max_positions: 1024,
hidden_size: 1024,
num_blocks: 24,
num_attention_heads: 16,
intermediate_size: nil,
activation: :gelu_approx_tanh,
scale_attention_weights: true,
dropout_rate: 0.1,
embeddings_dropout_rate: 0.1,
attention_dropout_rate: 0.1,
classifier_dropout_rate: 0.1,
layer_norm_epsilon: 1.0e-5,
initializer_scale: 0.02,
num_labels: 2,
id_to_label: %{},
use_cross_attention: false,
pad_token_id: 50256
},
params: #Axon.ModelState<
Parameters: 406286336 (1.63 GB)
Trainable Parameters: 406286336 (1.63 GB)
Trainable State: 0, (0 B)
>,
model: #Axon<
inputs: %{"attention_head_mask" => {24, 16}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 1024}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
outputs: "container_52"
nodes: 1365
>
}
val_0 =
validation_dataset
|> Enum.at(0)
|> Formatter.build_instruction()
IO.puts(val_0)
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Identify the tense used in the sentence.
### Input:
They are watching a movie.
:ok
defmodule MyGPT.Assistant do
def text_to_token_ids(tokenizer, text), do: Bumblebee.apply_tokenizer(tokenizer, text)
def token_ids_to_text(tokenizer, tokens_ids),
do: Bumblebee.Tokenizer.decode(tokenizer, tokens_ids)
def generate_text(
tokenizer,
predict_fn,
model_params,
text,
max_new_token,
k \\ 0,
temperature \\ 1,
eof \\ 50256
) do
token_ids = text_to_token_ids(tokenizer, text)
predicted_tokens_ids =
generate_tokens_impl(predict_fn, model_params, token_ids, max_new_token, k, temperature)
trimmed_predicted_tokens_ids = try_to_trim_prediction(predicted_tokens_ids["input_ids"], eof)
predicted_text = token_ids_to_text(tokenizer, trimmed_predicted_tokens_ids)
{predicted_text, predicted_tokens_ids}
end
defp generate_tokens_impl(predict_fn, model_params, input, max_new_token, k, temperature) do
for _new_token_index <- 1..max_new_token, reduce: input do
%{"input_ids" => input_ids} = input_acc ->
logit = predict_fn.(model_params, input_acc)
# Get last element of the vector.
predicted_new_token =
logit[[.., -1]]
|> top_k(k)
|> softmax_with_temperature(temperature)
|> Nx.new_axis(0)
%{"input_ids" => Nx.concatenate([input_ids, predicted_new_token], axis: 1)}
end
end
defp multinomial(probabilities, num_samples, max_random_number \\ 1000) do
seed = :rand.uniform(max_random_number)
key = Nx.Random.key(seed)
{random_values, _new_key} = Nx.Random.uniform(key, shape: {num_samples})
cumulative_probs = Nx.cumulative_sum(probabilities, axis: -1)
Enum.map(Nx.to_flat_list(random_values), fn value ->
Enum.find_index(
Nx.to_flat_list(cumulative_probs),
fn prob -> prob >= value end
)
end)
end
# Greedy decoding: a non-positive temperature always picks the most likely token.
defp softmax_with_temperature(logits, temperature) when temperature <= 0,
do: Axon.Layers.softmax(logits, axis: -1) |> Nx.argmax(axis: -1)
defp softmax_with_temperature(logits, temperature) when temperature > 0 do
scaled_logits = Nx.divide(logits, temperature)
Axon.Layers.softmax(scaled_logits, axis: -1)
|> multinomial(1)
|> Nx.tensor()
end
defp top_k(logits, k) when k == 0, do: logits
defp top_k(logits, k) do
{top_logits, _top_pos} = Nx.top_k(logits, k: k)
# Smallest logit that is still among the top k; everything below it is masked out.
min_top_logit = Nx.reduce_min(top_logits)
neg_inf_tensor = Nx.broadcast(Nx.Constants.neg_infinity(), logits.shape)
Nx.select(Nx.less(logits, min_top_logit), neg_inf_tensor, logits)
end
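defp try_to_trim_prediction(prediction_token_ids, eof) do
# True when the end-of-text token appears anywhere in the sequence.
has_eof? = prediction_token_ids |> Nx.equal(eof) |> Nx.any() |> Nx.to_number() == 1
# With the default eof of 50256 (the largest GPT-2 token ID), argmax returns
# the position of its first occurrence.
eof_index = Nx.argmax(prediction_token_ids) |> Nx.to_flat_list() |> Enum.at(0)
trim_prediction(has_eof?, prediction_token_ids, eof_index)
end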
defp trim_prediction(false, prediction_token_ids, _eof_index), do: prediction_token_ids
defp trim_prediction(_has_eof?, prediction_token_ids, eof_index),
do: prediction_token_ids[[.., 0..eof_index]]
end
alias MyGPT.Assistant
MyGPT.Assistant
token_ids = Bumblebee.apply_tokenizer(tokenizer_med, "Esto es una oración")
%{
"attention_mask" => #Nx.Tensor<
u32[1][8]
EXLA.Backend
[
[1, 1, 1, 1, 1, 1, 1, 1]
]
>,
"input_ids" => #Nx.Tensor<
u32[1][8]
EXLA.Backend
[
[22362, 78, 1658, 555, 64, 393, 32009, 18840]
]
>,
"token_type_ids" => #Nx.Tensor<
u32[1][8]
EXLA.Backend
[
[0, 0, 0, 0, 0, 0, 0, 0]
]
>
}
Bumblebee.Tokenizer.decode(tokenizer_med, token_ids["input_ids"])
["Esto es una oración"]
text = "Esto es una oración"
Assistant.generate_text(tokenizer_med, predict_fn_med, params_med, text, 5, 0, 0.1)
{["Esto es una oración de la vida de"],
%{
"input_ids" => #Nx.Tensor<
s64[1][13]
EXLA.Backend
[
[22362, 78, 1658, 555, 64, 393, 32009, 18840, 390, 8591, 410, 3755, 390]
]
>
}}
{[predicted_text], prediction} = Assistant.generate_text(tokenizer_med, predict_fn_med, params_med, val_0, 20, 0, 0.1)
IO.inspect(prediction, limit: :infinity)
IO.puts(predicted_text)
%{
"input_ids" => #Nx.Tensor<
s64[1][65]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 33234, 1958, 262, 20170, 973, 287, 262, 6827, 13, 198, 198, 21017, 23412, 25, 198, 2990, 389, 4964, 257, 3807, 13, 198, 198, 21017, 25235, 25, 198, 198, 2990, 389, 4964, 257, 3807, 13, 198, 198, 21017, 46486, 25, 198, 198]
]
>
}
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Identify the tense used in the sentence.
### Input:
They are watching a movie.
### Output:
They are watching a movie.
### Instruction:
:ok
7.6 Fine-tuning the LLM on instruction data (GPT 2)
train_stream |> Enum.map(fn {inputs, _labels} -> Nx.size(inputs["input_ids"])/8 end) |> Enum.max
57.5
loss_fn = fn y_true, y_pred ->
logits_flat = Nx.flatten(y_pred, axes: [0, 1])
targets_flat = Nx.flatten(y_true)
Axon.Losses.categorical_cross_entropy(targets_flat, logits_flat,
reduction: :mean,
from_logits: true,
sparse: true
)
end
optimizer = Polaris.Optimizers.adam(learning_rate: 0.004)
# Errors when using gpt2-medium:
# W0000 00:00:1732415505.238798 986449 onednn_matmul.cc:319] [Perf]: MatMul reference implementation being executed
# It seems Axon doesn't support multi-length datasets
trained_model_state =
gpt2_model
|> Axon.Loop.trainer(loss_fn, optimizer)
|> Axon.Loop.validate(gpt2_model, validation_stream)
|> Axon.Loop.run(train_stream, params, epochs: 2, compiler: EXLA)
23:52:15.100 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
File.cd(__DIR__)
model_bin = Nx.serialize(trained_model_state)
model_erlang_binary = :erlang.term_to_binary(model_bin)
File.write!("assistant_gpt.axon", model_erlang_binary)
:ok
7.6 Fine-tuning the LLM on instruction data (GPT 2 Med)
File.cd(__DIR__)
{:ok, model_bin} = File.read("assistant_gpt.axon")
assistant_parameters =
model_bin
|> :erlang.binary_to_term()
|> Nx.deserialize()
#Axon.ModelState<
Parameters: 163037258 (652.15 MB)
Trainable Parameters: 163037184 (652.15 MB)
Trainable State: 74, (296 B)
>
test_sample = train_dataset |> Enum.at(10) |> Formatter.build_instruction()
IO.puts(test_sample)
test_sample
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Define the term 'oxymoron'.
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDefine the term 'oxymoron'."
{[predicted_text], prediction} = Assistant.generate_text(tokenizer, predict_fn, params, test_sample, 20, 0, 0.1)
IO.inspect(prediction, limit: :infinity)
IO.puts(predicted_text)
%{
"input_ids" => #Nx.Tensor<
s64[1][53]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 7469, 500, 262, 3381, 705, 23536, 4491, 261, 4458, 198, 198, 21017, 18261, 25, 198, 198, 7469, 500, 262, 3381, 705, 23536, 4491, 261, 4458, 198, 198, 21017, 18261]
]
>
}
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Define the term 'oxymoron'.
### Response:
Define the term 'oxymoron'.
### Response
:ok
{[predicted_text], prediction} = Assistant.generate_text(tokenizer, predict_fn, assistant_parameters, test_sample, 20, 0, 0.01)
IO.inspect(prediction, limit: :infinity)
IO.puts(predicted_text)
%{
"input_ids" => #Nx.Tensor<
s64[1][53]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 7469, 500, 262, 3381, 705, 23536, 4491, 261, 4458, 198, 198, 21017, 18261, 25, 198, 2025, 281, 1122, 4948, 329, 705, 11576, 4458, 50256, 198, 21017, 18261, 25, 198]
]
>
}
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Define the term 'oxymoron'.
### Response:
An antonym for 'strong'.
:ok
7.7 Extracting and saving responses
To evaluate the LLM, we first extract and collect the model's responses on the held-out test dataset for further analysis, and then score those responses to quantify the performance of the instruction-fine-tuned LLM.
In practice, instruction-fine-tuned LLMs such as chatbots are evaluated via multiple approaches:
- Short-answer and multiple-choice benchmarks, such as Measuring Massive Multitask Language Understanding (MMLU), which test the general knowledge of a model.
- Human preference comparison to other LLMs, such as the LMSYS Chatbot Arena.
- Automated conversational benchmarks, where another LLM like GPT-4 is used to evaluate the responses, such as AlpacaEval.
Conversational performance
Conversational performance of LLMs refers to their ability to engage in human-like communication by understanding context, nuance, and intent. It encompasses skills such as providing relevant and coherent responses, maintaining consistency, and adapting to different topics and styles of interaction.
Human evaluation, while providing valuable insights, can be relatively laborious and time-consuming, especially when dealing with a large number of responses. For instance, reading and assigning ratings to all 1,100 responses would require a significant amount of effort.
File.cd(__DIR__)
{:ok, model_bin} = File.read("assistant_gpt.axon")
assistant_parameters =
model_bin
|> :erlang.binary_to_term()
|> Nx.deserialize()
test_sample = test_dataset |> Enum.at(10) |> Formatter.build_instruction()
{[predicted_text], prediction} = Assistant.generate_text(tokenizer, predict_fn, assistant_parameters, test_sample, 20, 0, 0.01)
IO.inspect(prediction, limit: :infinity)
IO.puts(predicted_text)
%{
"input_ids" => #Nx.Tensor<
s64[1][52]
EXLA.Backend
[
[21106, 318, 281, 12064, 326, 8477, 257, 4876, 13, 19430, 257, 2882, 326, 20431, 32543, 262, 2581, 13, 198, 198, 21017, 46486, 25, 198, 18438, 391, 644, 257, 387, 28643, 318, 13, 198, 198, 21017, 18261, 25, 198, 464, 10451, 318, 257, 1767, 318, 257, 2163, 286, 262, 1767, 13, 50256, 198]
]
>
}
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Explain what a haiku is.
### Response:
The formula is a body is a function of the body.
:ok
# Remove instruction and input.
model_output = String.replace(predicted_text, test_sample, "")
model_response = String.replace(model_output, "\n\n### Response:\n", "")
test = test_dataset |> Enum.at(10)
Map.put(test, "model_response", model_response)
%{
"input" => "",
"instruction" => "Explain what a haiku is.",
"model_response" => "The formula is a body is a function of the body.",
"output" => "A haiku is a form of traditional Japanese poetry that consists of three lines with a syllable pattern of 5-7-5."
}
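The single-sample steps above can be repeated for every entry of the test set. The following is a minimal sketch, not the notebook's own implementation: the module name ResponseCollector is hypothetical, and the prompt builder and text generator are passed in as functions so that the block stays self-contained. In this notebook they would presumably wrap Formatter.build_instruction and Assistant.generate_text. The extra String.trim call is also an assumption, added to drop stray whitespace around the response.

```elixir
defmodule ResponseCollector do
  # Hypothetical helper (not part of the notebook): applies the
  # single-sample extraction step to every entry of a dataset.

  @response_header "\n\n### Response:\n"

  # Strips the prompt and the response header from the generated text.
  def extract_response(predicted_text, prompt) do
    predicted_text
    |> String.replace(prompt, "")
    |> String.replace(@response_header, "")
    |> String.trim()
  end

  # build_prompt_fn: sample -> prompt string
  # generate_fn: prompt -> full generated text (prompt plus completion)
  def collect(samples, build_prompt_fn, generate_fn) do
    Enum.map(samples, fn sample ->
      prompt = build_prompt_fn.(sample)
      response = prompt |> generate_fn.() |> extract_response(prompt)
      Map.put(sample, "model_response", response)
    end)
  end
end

# Possible usage inside the notebook (names taken from the cells above;
# the file name "test_responses.bin" is an assumption):
#
#   generate = fn prompt ->
#     {[text], _} =
#       Assistant.generate_text(tokenizer, predict_fn, assistant_parameters, prompt, 35, 0, 0.01)
#     text
#   end
#
#   responses = ResponseCollector.collect(test_dataset, &Formatter.build_instruction/1, generate)
#   File.write!("test_responses.bin", :erlang.term_to_binary(responses))
```

Saving the collected maps with :erlang.term_to_binary mirrors how the model parameters are stored in this chapter, so the responses can be reloaded later without regenerating them.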
7.8 Evaluating the fine-tuned LLM
To evaluate test set responses in an automated fashion, we utilize an existing instruction-fine-tuned 8-billion-parameter Llama 3 model developed by Meta AI. This model can be run locally using the open source Ollama application.
Ollama is an efficient application for running LLMs on a laptop. It serves as a wrapper around the open source llama.cpp library.
The llama3 in the ollama run llama3 command refers to the instruction-fine-tuned 8-billion-parameter Llama 3 model. Using Ollama with the llama3 model requires approximately 16 GB of RAM. If your machine does not have sufficient RAM, you can try using a smaller model, such as the 3.8-billion-parameter phi3 model via ollama run phi3, which only requires around 8 GB of RAM.
For more powerful computers, you can also use the larger 70-billion-parameter Llama 3 model by replacing llama3 with llama3:70b. However, this model requires significantly more computational resources.
You can end this ollama run llama3 session using the input /bye.
File.cd(__DIR__)
{:ok, model_bin} = File.read("assistant_gpt.axon")
assistant_parameters =
model_bin
|> :erlang.binary_to_term()
|> Nx.deserialize()
#Axon.ModelState<
Parameters: 163037258 (652.15 MB)
Trainable Parameters: 163037184 (652.15 MB)
Trainable State: 74, (296 B)
>
ollama_url = "http://localhost:11434/api/chat"
headers = ["Content-Type": "application/json"]
prompt = "What do Llamas eat?"
body = %{
"model" => "llama3",
"messages" => [%{"role" => "user", "content" => prompt}],
"options" => %{
"seed" => 123,
"temperature" => 0,
"num_ctx" => 2048
}
}
{:ok, body_json} = Jason.encode(body)
{:ok, ollama_response} = HTTPoison.post(ollama_url, body_json, headers, recv_timeout: :infinity)
{:ok,
%HTTPoison.Response{
status_code: 200,
body: "{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.340846446Z\",\"message\":{\"role\":\"assistant\",\"content\":\"L\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.410165625Z\",\"message\":{\"role\":\"assistant\",\"content\":\"lam\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.484055232Z\",\"message\":{\"role\":\"assistant\",\"content\":\"as\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.546432916Z\",\"message\":{\"role\":\"assistant\",\"content\":\" are\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.604145024Z\",\"message\":{\"role\":\"assistant\",\"content\":\" r\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.65929093Z\",\"message\":{\"role\":\"assistant\",\"content\":\"umin\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.715184187Z\",\"message\":{\"role\":\"assistant\",\"content\":\"ant\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.76897579Z\",\"message\":{\"role\":\"assistant\",\"content\":\" animals\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.824458578Z\",\"message\":{\"role\":\"assistant\",\"content\":\",\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.902698013Z\",\"message\":{\"role\":\"assistant\",\"content\":\" which\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:36.960362255Z\",\"message\":{\"role\":\"assistant\",\"content\":\" means\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.015690314Z\",\"message\":{\"role\":\"assistant\",\"content\":\" they\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.071490526Z\",\"message\":{\"role\":\"assistant\",\"content\":\" 
have\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.126636571Z\",\"message\":{\"role\":\"assistant\",\"content\":\" a\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.18116063Z\",\"message\":{\"role\":\"assistant\",\"content\":\" four\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.237955789Z\",\"message\":{\"role\":\"assistant\",\"content\":\"-ch\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.292647997Z\",\"message\":{\"role\":\"assistant\",\"content\":\"amber\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.345313106Z\",\"message\":{\"role\":\"assistant\",\"content\":\"ed\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.408216972Z\",\"message\":{\"role\":\"assistant\",\"content\":\" stomach\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.475306769Z\",\"message\":{\"role\":\"assistant\",\"content\":\" and\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.533319456Z\",\"message\":{\"role\":\"assistant\",\"content\":\" primarily\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.588378203Z\",\"message\":{\"role\":\"assistant\",\"content\":\" feed\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.642008873Z\",\"message\":{\"role\":\"assistant\",\"content\":\" on\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.705372981Z\",\"message\":{\"role\":\"assistant\",\"content\":\" plant\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.765359586Z\",\"message\":{\"role\":\"assistant\",\"content\":\"-based\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.824403447Z\",\"message\":{\"role\":\"assistant\",\"content\":\" 
foods\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.887032651Z\",\"message\":{\"role\":\"assistant\",\"content\":\".\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:37.944882773Z\",\"message\":{\"role\":\"assistant\",\"content\":\" Their\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:38.003132349Z\",\"message\":{\"role\":\"assistant\",\"content\":\" diet\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:38.059937205Z\",\"message\":{\"role\":\"assistant\",\"content\":\" typically\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:38.119712052Z\",\"message\":{\"role\":\"assistant\",\"content\":\" consists\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:06:38.176621656Z\",\"message\":{\"role\":\"assistant\",\"content\":\" of\"},\"done\":false}\n{\"model\":\"llama3\",\"created_at\":\"2024-11-24T21:" <> ...,
headers: [
{"Content-Type", "application/x-ndjson"},
{"Date", "Sun, 24 Nov 2024 21:06:36 GMT"},
{"Transfer-Encoding", "chunked"}
],
request_url: "http://localhost:11434/api/chat",
request: %HTTPoison.Request{
method: :post,
url: "http://localhost:11434/api/chat",
headers: ["Content-Type": "application/json"],
body: "{\"messages\":[{\"content\":\"What do Llamas eat?\",\"role\":\"user\"}],\"model\":\"llama3\",\"options\":{\"num_ctx\":2048,\"seed\":123,\"temperature\":0}}",
params: %{},
options: [recv_timeout: :infinity]
}
}}
# Parse Ollama response
ollama_response_str = String.replace(ollama_response.body, ",\"done\":false}\n", ",\"done\":false},")
{:ok, ollama_response_maps} = "[#{ollama_response_str}]" |> Jason.decode()
ollama_text =
for ollama_response_map <- ollama_response_maps, reduce: "" do
acc -> acc <> get_in(ollama_response_map, ["message", "content"])
end
IO.puts(ollama_text)
Llamas are ruminant animals, which means they have a four-chambered stomach and primarily feed on plant-based foods. Their diet typically consists of:
1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and even weeds.
2. Hay: High-quality hay, such as alfalfa or timothy hay, is a staple in many llama diets.
3. Grains: Some llamas may receive grains like oats, barley, or corn as part of their diet, but this should not be the primary source of nutrition.
4. Fruits and vegetables: Llamas enjoy fruits and veggies like apples, carrots, sweet potatoes, and leafy greens like kale or spinach.
5. Minerals: Llamas need access to mineral supplements, such as salt licks or loose minerals, to ensure they're getting essential nutrients.
In the wild, llamas might also eat:
1. Leaves: They'll munch on leaves from trees and shrubs.
2. Bark: In some cases, llamas may eat the bark of certain tree species.
3. Mushrooms: Some llamas have been known to enjoy a variety of mushrooms.
As pets or in captivity, llama owners typically provide a balanced diet that includes a mix of hay, grains, and fruits/vegetables. It's essential to consult with a veterinarian or experienced llama breeder to determine the best diet for your specific llama.
:ok
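The response above arrives as newline-delimited JSON because Ollama streams chat replies by default, which is why the parsing cell has to stitch the fragments together. One possible alternative, assuming token-by-token output is not needed, is to set "stream" => false in the request body so that the reply comes back as a single JSON object. The sketch below only builds that body; the module name ChatRequest is hypothetical.

```elixir
defmodule ChatRequest do
  # Hypothetical helper: builds an /api/chat request body with streaming
  # disabled. The defaults mirror the values used in the cell above.
  def build_body(prompt, opts \\ []) do
    %{
      "model" => Keyword.get(opts, :model, "llama3"),
      "messages" => [%{"role" => "user", "content" => prompt}],
      "stream" => false,
      "options" => %{
        "seed" => Keyword.get(opts, :seed, 123),
        "temperature" => Keyword.get(opts, :temperature, 0),
        "num_ctx" => Keyword.get(opts, :num_ctx, 2048)
      }
    }
  end
end

body = ChatRequest.build_body("What do Llamas eat?")
```

With a body like this, the full answer should be available under ["message", "content"] after a single Jason.decode of the response body, without the string replacement step.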
defmodule MyGPT.Assistant.Evaluator do
@ollama_url "http://localhost:11434/api/chat"
@headers ["Content-Type": "application/json"]
def build_ollama_prompt(sample, opts \\ []) do
with tokenizer <- Keyword.get(opts, :tokenizer),
predict_fn <- Keyword.get(opts, :predict_fn),
model_params <- Keyword.get(opts, :model_params),
max_new_tokens <- Keyword.get(opts, :max_new_tokens, 35),
k <- Keyword.get(opts, :k, 0),
temperature <- Keyword.get(opts, :temperature, 0.01),
eof <- Keyword.get(opts, :eof, 50256),
sample_text <- Formatter.build_instruction(sample),
{[predicted_text], _prediction} <-
Assistant.generate_text(
tokenizer,
predict_fn,
model_params,
sample_text,
max_new_tokens,
k,
temperature,
eof
),
{:ok, sample_with_model_output} <-
add_model_response_to_sample(predicted_text, sample, sample_text) do
"""
Given the input `#{sample_text}`
and the correct output `#{sample["output"]}`,
score the model response `#{sample_with_model_output["model_response"]}`
on a scale from 0 to 100, where 100 is the best score.
"""
end
end
def add_model_response_to_sample(predicted_text, sample, sample_text) do
model_response =
predicted_text
|> String.replace(sample_text, "")
|> String.replace("\n\n### Response:\n", "")
{:ok, Map.put(sample, "model_response", model_response)}
end
def query_ollama(prompt, opts \\ []) do
with ollama_url <- Keyword.get(opts, :url, @ollama_url),
headers <- Keyword.get(opts, :headers, @headers),
model <- Keyword.get(opts, :model, "llama3"),
seed <- Keyword.get(opts, :seed, 123),
temperature <- Keyword.get(opts, :temperature, 0),
num_ctx <- Keyword.get(opts, :num_ctx, 2048) do
body = %{
"model" => model,
"messages" => [%{"role" => "user", "content" => prompt}],
"options" => %{
"seed" => seed,
"temperature" => temperature,
"num_ctx" => num_ctx
}
}
{:ok, body_json} = Jason.encode(body)
HTTPoison.post(ollama_url, body_json, headers, recv_timeout: :infinity)
end
end
def decode_ollama_response({:ok, ollama_response}) do
ollama_response_str =
String.replace(ollama_response.body, ",\"done\":false}\n", ",\"done\":false},")
{:ok, ollama_response_maps} = "[#{ollama_response_str}]" |> Jason.decode()
ollama_text =
for ollama_response_map <- ollama_response_maps, reduce: "" do
acc -> acc <> get_in(ollama_response_map, ["message", "content"])
end
{:ok, ollama_text}
end
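# Error tuples and structs cannot be interpolated into a string directly, so use inspect/1.
def decode_ollama_response(ollama_response),
do: {:error, "Ollama request failed: #{inspect(ollama_response)}"}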
end
alias MyGPT.Assistant.Evaluator
MyGPT.Assistant.Evaluator
{:ok, ollama_text} = Evaluator.decode_ollama_response({:ok, ollama_response})
IO.puts(ollama_text)
Llamas are ruminant animals, which means they have a four-chambered stomach and primarily feed on plant-based foods. Their diet typically consists of:
1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and even weeds.
2. Hay: High-quality hay, such as alfalfa or timothy hay, is a staple in many llama diets.
3. Grains: Some llamas may receive grains like oats, barley, or corn as part of their diet, but this should not be the primary source of nutrition.
4. Fruits and vegetables: Llamas enjoy fruits and veggies like apples, carrots, sweet potatoes, and leafy greens like kale or spinach.
5. Minerals: Llamas need access to mineral supplements, such as salt licks or loose minerals, to ensure they're getting essential nutrients.
In the wild, llamas might also eat:
1. Leaves: They'll munch on leaves from trees and shrubs.
2. Bark: In some cases, llamas may eat the bark of certain tree species.
3. Mushrooms: Some llamas have been known to enjoy a variety of mushrooms.
As pets or in captivity, llama owners typically provide a balanced diet that includes a mix of hay, grains, and fruits/vegetables. It's essential to consult with a veterinarian or experienced llama breeder to determine the best diet for your specific llama.
:ok
prompt =
train_dataset
|> Enum.at(304)
|> Evaluator.build_ollama_prompt(tokenizer: tokenizer, model_params: assistant_parameters, predict_fn: predict_fn)
IO.puts(prompt)
Given the input `Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Find a rhyming word for the word 'house.'`
and the correct output `A rhyming word for the word 'house' is 'mouse.'`,
score the model response `The literary sentence.`
on a scale from 0 to 100, where 100 is the best score.
:ok
ollama_evaluation =
prompt
|> Evaluator.query_ollama()
|> Evaluator.decode_ollama_response()
|> elem(1)
IO.puts(ollama_evaluation)
I'd rate the model response "The literary sentence." as 20 out of 100.
The reason for this low score is that the response does not directly answer the instruction to find a rhyming word for the word 'house.' Instead, it provides a vague and unrelated statement about a literary sentence. The correct output "A rhyming word for the word 'house' is 'mouse.'" is much more relevant and accurate in completing the request.
To achieve a higher score, the model should provide a response that directly addresses the instruction and provides a clear answer to the question of finding a rhyming word for the word 'house.'
:ok
A full evaluation runs this procedure over the whole test_dataset: generate a response for each sample, query Ollama (or another evaluator model) for a score, and average the scores across all samples.
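The averaging step can be sketched as follows. This is a rough illustration under stated assumptions, not a definitive implementation: the module name ScoreSummary is hypothetical, and the score is taken to be the first integer between 0 and 100 that appears in the evaluator's reply. That heuristic works for a reply such as the one shown above, but it can pick the wrong number if the evaluator restates the scale first, so asking the evaluator to answer with the integer only would make the parsing more reliable.

```elixir
defmodule ScoreSummary do
  # Hypothetical helper: pulls the first integer between 0 and 100 out of
  # the evaluator's free-text reply.
  def extract_score(text) do
    ~r/\b\d{1,3}\b/
    |> Regex.scan(text)
    |> List.flatten()
    |> Enum.map(&String.to_integer/1)
    |> Enum.find(&(&1 in 0..100))
    |> case do
      nil -> :error
      score -> {:ok, score}
    end
  end

  # Averages the scores that could be parsed; replies without a usable
  # score are skipped rather than counted as zero.
  def average(replies) do
    scores = for reply <- replies, {:ok, score} <- [extract_score(reply)], do: score

    case scores do
      [] -> {:error, :no_scores}
      _ -> {:ok, Enum.sum(scores) / length(scores)}
    end
  end
end
```

In the notebook, the list of replies would come from mapping Evaluator.build_ollama_prompt, Evaluator.query_ollama, and Evaluator.decode_ollama_response over test_dataset.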
To further improve our model’s performance, we can explore various strategies, such as:
- Adjusting the hyperparameters during fine-tuning, such as the learning rate, batch size, or number of epochs.
- Increasing the size of the training dataset or diversifying the examples to cover a broader range of topics and styles.
- Experimenting with different prompts or instruction formats to guide the model’s responses more effectively.
- Using a larger pretrained model, which may have greater capacity to capture complex patterns and generate more accurate responses.
7.9 Conclusion
7.9.1 What’s next?
Preference fine-tuning is particularly useful for customizing a model to better align with specific user preferences.
7.9.2 Staying up to date in a fast-moving field
One way to keep up with the latest advancements is to explore recent research papers on arXiv.
Additionally, many researchers and practitioners are very active in sharing and discussing the latest developments on social media platforms like X (formerly Twitter) and Reddit. The subreddit r/LocalLLaMA, in particular, is a good resource for connecting with the community and staying informed about the latest tools and trends. I also regularly share insights and write about the latest in LLM research on this blog.
7.9.3 Final words
Explore popular tools such as Axolotl or LitGPT.
Summary
- The instruction-fine-tuning process adapts a pretrained LLM to follow human instructions and generate desired responses.
- Preparing the dataset involves downloading an instruction-response dataset, formatting the entries, and splitting it into train, validation, and test sets.
- Training batches are constructed using a custom collate function that pads sequences, creates target token IDs, and masks padding tokens.
- We load a pretrained GPT-2 model (the openai-community/gpt2 checkpoint, about 163 million parameters as loaded in this notebook) to serve as the starting point for instruction fine-tuning.
- The pretrained model is fine-tuned on the instruction dataset using a training loop similar to pretraining.
- Evaluation involves extracting model responses on a test set and scoring them (for example, using another LLM).
- The Ollama application with an 8-billion-parameter Llama model can be used to automatically score the fine-tuned model’s responses on the test set, providing an average score to quantify performance.