
Fine-tuning

Mix.install([
  {:bumblebee, "~> 0.3.0"},
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.1"},
  {:exla, "~> 0.5.1"},
  {:explorer, "~> 0.5.0"}
])
Resolving Hex dependencies...
Resolution completed in 0.174s
New:
  axon 0.5.1
  bumblebee 0.3.0
  castore 1.0.2
  complex 0.5.0
  decimal 2.1.1
  elixir_make 0.7.6
  exla 0.5.3
  explorer 0.5.7
  jason 1.4.0
  nx 0.5.3
  nx_image 0.1.1
  nx_signal 0.1.0
  progress_bar 2.0.1
  rustler_precompiled 0.6.1
  table 0.1.2
  table_rex 3.1.1
  telemetry 1.2.1
  tokenizers 0.3.2
  unpickler 0.1.0
  unzip 0.8.0
  xla 0.4.4
* Getting bumblebee (Hex package)
* Getting axon (Hex package)
* Getting nx (Hex package)
* Getting exla (Hex package)
* Getting explorer (Hex package)
* Getting rustler_precompiled (Hex package)
* Getting table (Hex package)
* Getting table_rex (Hex package)
* Getting castore (Hex package)
* Getting elixir_make (Hex package)
* Getting telemetry (Hex package)
* Getting xla (Hex package)
* Getting complex (Hex package)
* Getting jason (Hex package)
* Getting nx_image (Hex package)
* Getting nx_signal (Hex package)
* Getting progress_bar (Hex package)
* Getting tokenizers (Hex package)
* Getting unpickler (Hex package)
* Getting unzip (Hex package)
* Getting decimal (Hex package)
==> unzip
Compiling 5 files (.ex)
Generated unzip app
==> decimal
Compiling 4 files (.ex)
Generated decimal app
==> progress_bar
Compiling 10 files (.ex)
Generated progress_bar app
==> table
Compiling 5 files (.ex)
Generated table app
===> Analyzing applications...
===> Compiling telemetry
==> jason
Compiling 10 files (.ex)
Generated jason app
==> unpickler
Compiling 3 files (.ex)
Generated unpickler app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> nx_image
Compiling 1 file (.ex)
Generated nx_image app
==> nx_signal
Compiling 3 files (.ex)
Generated nx_signal app
==> table_rex
Compiling 7 files (.ex)
Generated table_rex app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> rustler_precompiled
Compiling 4 files (.ex)
Generated rustler_precompiled app
==> explorer
Compiling 19 files (.ex)

19:52:00.047 [debug] Copying NIF from cache and extracting to /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/_build/dev/lib/explorer/priv/native/libexplorer-v0.5.7-nif-2.16-aarch64-apple-darwin.so
Generated explorer app
==> tokenizers
Compiling 7 files (.ex)

19:52:01.034 [debug] Copying NIF from cache and extracting to /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/_build/dev/lib/tokenizers/priv/native/libex_tokenizers-v0.3.2-nif-2.16-aarch64-apple-darwin.so
Generated tokenizers app
==> bumblebee
Compiling 89 files (.ex)
Generated bumblebee app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/sn/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/deps/exla/cache
Using libexla.so from /Users/sn/Library/Caches/xla/exla/elixir-1.14.4-erts-13.2.1-xla-0.4.4-exla-0.5.3-yywvdsvg3ykdeecnbffiqglm6i/libexla.so
Compiling 21 files (.ex)
Generated exla app
:ok

Introduction

Nx.default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}

Fine-tuning is the process of specializing the parameters in a pre-trained model for a specific task. Large language models such as BERT are trained on a generic language-modeling task, which makes them powerful at extracting features from text. Despite their power, you often still need to train them on a downstream task.

This example demonstrates how to use Bumblebee and Axon to fine-tune a pre-trained BERT model to classify Yelp reviews into classes of 1-5 stars. It is based on the Fine-tune a pretrained model guide from Hugging Face.

You’ll first need to download the Yelp Reviews dataset.

Once downloaded, extract it to a directory of your choosing and you’re ready to go!
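
If you want to sanity-check the download before going any further, you can peek at the CSV with Explorer. A minimal sketch, assuming you extracted the archive to ~/Downloads/yelp_review_full_csv (adjust the path to wherever you put it):

# Hypothetical path - point this at your extracted dataset
path = Path.expand("~/Downloads/yelp_review_full_csv/train.csv")

# The CSV has no header row: the first column is the star rating (1-5),
# the second is the review text
df = Explorer.DataFrame.from_csv!(path, header: false)

Explorer.DataFrame.head(df, 3)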

Load a model

We’ll start by loading a pre-trained model and tokenizer; however, we’ll initialize the model to have an untrained sequence classification head.

Reviews in the dataset can have anywhere from 1 to 5 stars, which means we need 5 labels in our sequence classification head. We can change the default configuration by loading the model spec with Bumblebee.load_spec/2 and making changes to spec properties with Bumblebee.configure/2.

The pre-trained model we’ll be using is bert-base-cased; however, you can use any of the supported models from the Hugging Face Hub.

{:ok, spec} =
  Bumblebee.load_spec({:hf, "bert-base-cased"},
    architecture: :for_sequence_classification
  )

spec = Bumblebee.configure(spec, num_labels: 5)

{:ok, model} = Bumblebee.load_model({:hf, "bert-base-cased"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"})
|=============================================================| 100% (435.77 MB)

23:46:09.618 [info] TfrtCpuClient created.

23:46:10.748 [debug] the following parameters were missing:

  * sequence_classification_head.output.kernel
  * sequence_classification_head.output.bias


23:46:10.748 [debug] the following PyTorch parameters were unused:

  * cls.predictions.bias
  * cls.predictions.decoder.weight
  * cls.predictions.transform.LayerNorm.beta
  * cls.predictions.transform.LayerNorm.gamma
  * cls.predictions.transform.dense.bias
  * cls.predictions.transform.dense.weight
  * cls.seq_relationship.bias
  * cls.seq_relationship.weight

{:ok,
 %Bumblebee.Text.BertTokenizer{
   tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 28996,
     continuing_subword_prefix: "##",
     max_input_chars_per_word: 100,
     model_type: "bpe",
     unk_token: "[UNK]"
   ]>,
   special_tokens: %{cls: "[CLS]", mask: "[MASK]", pad: "[PAD]", sep: "[SEP]", unk: "[UNK]"}
 }}

Prepare a dataset

With the models downloaded and ready to go, you need to prepare the dataset. The downloaded dataset is a CSV. You can use the Explorer library to quickly load the CSV into a DataFrame.

Once the data is loaded, you need to convert the raw text to tokens and the raw labels to tensors. Additionally, you need to convert the DataFrame into a Stream of {tokenized, labels} tuples - the form expected by Axon’s training API.

defmodule Yelp do
  # Loads the CSV at `path` and returns a stream of {tokenized, labels} batches
  def load(path, tokenizer, opts \\ []) do
    path
    |> Explorer.DataFrame.from_csv!(header: false)
    |> Explorer.DataFrame.rename(["label", "text"])
    |> stream()
    |> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
  end

  # Zips the text and label columns into a lazy stream of {text, label} pairs
  def stream(df) do
    xs = df["text"]
    ys = df["label"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  # Chunks the stream into batches, tokenizes each batch of text, and stacks
  # the labels into a single tensor per batch
  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {text, labels} = Enum.unzip(batch)
      tokenized = Bumblebee.apply_tokenizer(tokenizer, text, length: sequence_length)
      {tokenized, Nx.stack(labels)}
    end)
  end
end
{:module, Yelp, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:tokenize_and_batch, 4}}

Now you can use the Yelp.load/3 function to load a training set and a testing set:

batch_size = 32
sequence_length = 64

train_data =
  Yelp.load("/Users/sn/Downloads/yelp_review_full_csv/train.csv", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )

test_data =
  Yelp.load("/Users/sn/Downloads/yelp_review_full_csv/test.csv", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )
#Stream<[
  enum: #Stream<[
    enum: #Function<73.57817549/2 in Stream.zip_with/2>,
    funs: [#Function<3.57817549/1 in Stream.chunk_while/4>]
  ]>,
  funs: [#Function<48.57817549/1 in Stream.map/2>]
]>

You can see what a single batch looks like by taking one batch from the stream:

Enum.take(train_data, 1)
[
  {%{
     "attention_mask" => #Nx.Tensor<
       u32[32][64]
       EXLA.Backend
       [
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...],
         ...
       ]
     >,
     "input_ids" => #Nx.Tensor<
       u32[32][64]
       EXLA.Backend
       [
         [101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111, 1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037, 1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107, 3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, ...],
         ...
       ]
     >,
     "token_type_ids" => #Nx.Tensor<
       u32[32][64]
       EXLA.Backend
       [
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[32]
     EXLA.Backend
     [5, 2, 4, 4, 1, 5, 5, 1, 2, 3, 1, 1, 4, 2, 5, 5, 5, 5, 5, 5, 4, 3, 2, 5, 1, 1, 1, 2, 2, 4, 2, 2]
   >}
]

The dataset is rather large for CPU training, so we’ll train on just a small subset (250 training batches and 50 testing batches):

train_data = Enum.take(train_data, 250)
test_data = Enum.take(test_data, 50)
:ok
:ok

Train the model

Now we can go about training the model! First, we need to extract the Axon model and parameters from the Bumblebee model map:

%{model: model, params: params} = model

model
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 790
>

The Axon model actually outputs a map with :logits, :hidden_states, and :attentions. You can see this by using Axon.get_output_shape/2 with an input. This function symbolically executes the graph and returns the resulting shapes:

[{input, _}] = Enum.take(train_data, 1)
Axon.get_output_shape(model, input)
%{attentions: #Axon.None<...>, hidden_states: #Axon.None<...>, logits: {32, 5}}

For training, we only care about the :logits key, so we’ll extract that by attaching an Axon.nx/2 layer to the model:

logits_model = Axon.nx(model, & &1.logits)
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
  outputs: "nx_0"
  nodes: 791
>

Now we can declare our training loop. You can construct Axon training loops using the Axon.Loop.trainer/3 factory function with a model, loss function, and optimizer. We’ll also adjust the log settings to log metrics to standard out more frequently:

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

optimizer = Axon.Optimizers.adam(5.0e-5)

loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

The call to trainer just returns a data structure. In Axon, we manipulate this data structure to control different parts of the loop. For example, you can attach metrics:

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

loop = Axon.Loop.metric(loop, accuracy, "accuracy")
#Axon.Loop<
  metrics: %{
    "accuracy" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>},
    "loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

And you can attach event handlers to do certain things, such as serializing the loop state at regular intervals so you don’t lose your progress:

loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
#Axon.Loop<
  metrics: %{
    "accuracy" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>},
    "loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<17.24288319/1 in Axon.Loop.checkpoint/2>,
       #Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>},
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.24288319/1 in Axon.Loop.log/3>,
       #Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
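
If a run is interrupted, you can restore a saved checkpoint and pull the parameters back out of it. This is only a sketch: the file name below is hypothetical (it depends on Axon’s checkpoint defaults), so point it at whichever file appears under your checkpoint directory.

# Hypothetical checkpoint file name - use one actually written by the handler above
serialized = File.read!("checkpoint/checkpoint_0_249.ckpt")

# Deserializing gives back the loop state; for trainer loops the model
# parameters live under the step state's :model_state key
state = Axon.Loop.deserialize_state(serialized)
restored_params = state.step_state[:model_state]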

To run the loop, you just need to call Axon.Loop.run/4. It takes a loop, input data, any initial state (in this case the initial parameters), and options. You can think of Axon.Loop.run/4 much like Enum.reduce/3: the reduce’s data, accumulator, and function correspond to the loop’s input data, initial state, and the loop data structure itself.
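
To make that analogy concrete, here is the shape of both calls side by side; the reduce is just an illustration of the correspondence, not part of the training code:

# Enum.reduce/3: data, an accumulator, and a reducing function
Enum.reduce(1..3, 0, fn x, acc -> acc + x end)

# Axon.Loop.run/4 has the same shape: train_data is the data, params is the
# accumulator threaded through every iteration, and the loop itself plays the
# role of the reducing function:
#
#   Axon.Loop.run(loop, train_data, params, epochs: 3, compiler: EXLA)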

You’ll commonly see loops written out in long chains using Elixir’s |> operator, like this:

trained_model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA, strict?: false)

:ok

00:04:39.518 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 249, accuracy: 0.3548750 loss: 1.2023211
Epoch: 1, Batch: 77, accuracy: 0.4431090 loss: 1.1546737

Evaluating the model

The training loop returns the final model state after training over your dataset for the given number of epochs. Axon uses the same Axon.Loop API to create evaluation loops as well. You can create one with the Axon.Loop.evaluator/1 factory, instrument it with metrics, and run it on your data with your trained model state:

logits_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(accuracy, "accuracy")
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
Batch: 49, accuracy: 0.3675000
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.36750003695487976
    >
  }
}
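
With the trained state in hand, you can also run one-off predictions outside of a loop. A minimal sketch, assuming the logits_model, tokenizer, sequence_length, and trained_model_state bound above (the review text is made up):

# Tokenize a single review the same way the training data was prepared
review = "The food was decent but the service was painfully slow."
inputs = Bumblebee.apply_tokenizer(tokenizer, [review], length: sequence_length)

# Run the model forward and take the index of the highest-scoring class
logits = Axon.predict(logits_model, trained_model_state, inputs, compiler: EXLA)
logits |> Nx.argmax(axis: -1) |> Nx.to_flat_list()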