Powered by AppSignal & Oban Pro

LLM from Scratch 5.3 Losses

llm_from_scratch_5_3_losses.livemd

LLM from Scratch 5.3 Losses

# Parent directory of this notebook's directory. It is used both as the path
# of the local :llm_scratch dependency and, later, to locate the training text.
project_path = Path.expand("..", __DIR__)

Mix.install(
  [
    # Local project loaded from disk via `path:` instead of a published package.
    {:llm_scratch, path: project_path},
    # Notebook rendering and VegaLite charting, used by the loss plot below.
    {:kino, "~> 0.12"},
    {:kino_vega_lite, "~> 0.1.11"},
    {:vega_lite, "~> 0.1.9"}
  ],
  # NOTE(review): presumably disabled so the Nx.Container implementation for
  # GPTModel asserted further down resolves correctly at runtime - confirm.
  consolidate_protocols: false
)

Train model

alias LlmScratch.{
  GPTConfig,
  GPTModel,
  GptDatasetV1,
  Training
}

# Fail fast: raise here, before any expensive work, if the training module or
# the Nx.Container implementation for GPTModel cannot be loaded.
Code.ensure_loaded!(Training)
Code.ensure_loaded!(Nx.Container.LlmScratch.GPTModel)
Protocol.assert_impl!(Nx.Container, GPTModel)

# Nx.default_backend/1 sets the default backend for the current process and
# returns the backend that was active *before* the call. Capture that return
# value as the backend to restore later, then read the now-active backend back
# so `device` really refers to EXLA rather than to the previous backend.
previous_backend = Nx.default_backend(EXLA.Backend)
device = Nx.default_backend()

# Split the corpus 90/10 (by grapheme position) into training and validation text.
file_content = File.read!(Path.join(project_path, "the-verdict.txt"))
train_ratio = 0.90
split_idx = trunc(train_ratio * String.length(file_content))
{train_data, val_data} = String.split_at(file_content, split_idx)

gpt_config_124m = %GPTConfig{
  vocab_size: 50_257,
  context_length: 256,
  emb_dim: 768,
  n_heads: 12,
  n_layers: 12,
  drop_rate: 0.0,
  qkv_bias: false
}

# Both loaders use a stride equal to the window length. The training loader
# shuffles and drops the last batch; the validation loader does neither.
train_loader =
  GptDatasetV1.create_dataloader_v1(
    raw_text: train_data,
    batch_size: 2,
    max_length: gpt_config_124m.context_length,
    stride: gpt_config_124m.context_length,
    drop_last: true,
    shuffle: true,
    num_workers: 0
  )

val_loader =
  GptDatasetV1.create_dataloader_v1(
    raw_text: val_data,
    batch_size: 2,
    max_length: gpt_config_124m.context_length,
    stride: gpt_config_124m.context_length,
    drop_last: false,
    shuffle: false,
    num_workers: 0
  )

model = GPTModel.new(gpt_config_124m, seed: 123)
optimizer = Training.adamw(0.0004, weight_decay: 0.1)
num_epochs = 10

# NOTE(review): the two `5` arguments and the trailing strings are positional;
# the last one looks like a tokenizer/model name - confirm it matches the
# 50_257-token vocabulary configured above.
#
# try/after guarantees the previous default backend is restored even when
# training raises, so a failed run does not leave EXLA as the process default.
{_trained_model, train_losses, val_losses, tokens_seen} =
  try do
    Training.train_model_simple(
      model,
      train_loader,
      val_loader,
      optimizer,
      device,
      num_epochs,
      5,
      5,
      "Every effort moves you",
      "code-davinci-002"
    )
  after
    Nx.default_backend(previous_backend)
  end

# Cell result: a compact summary of what was recorded during training.
%{
  train_loss_points: length(train_losses),
  validation_loss_points: length(val_losses),
  tokens_seen: tokens_seen
}

Plot losses

alias VegaLite, as: Vl

defmodule LossPlot do
  @moduledoc """
  Renders training and validation losses as a layered VegaLite chart.
  """

  @doc """
  Builds a `Kino.VegaLite` chart of the recorded losses.

  The chart has two layers:

    * a visible layer with one line per series ("Training loss" and
      "Validation loss") plotted against `epochs_seen`;
    * a fully transparent layer of the training losses plotted against
      `tokens_seen`, which only serves to add a second x axis on top.

  The inputs are zipped position by position, so every series is truncated to
  the shortest of `epochs_seen`, `tokens_seen` and that series' losses.
  """
  def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses) do
    labelled_series = [
      {"Training loss", train_losses},
      {"Validation loss", val_losses}
    ]

    series_rows =
      for {label, values} <- labelled_series,
          {epoch, tokens, loss} <- Enum.zip([epochs_seen, tokens_seen, values]) do
        %{series: label, epoch: epoch, tokens_seen: tokens, loss: loss}
      end

    token_rows =
      for {epoch, tokens, loss} <- Enum.zip([epochs_seen, tokens_seen, train_losses]) do
        %{epoch: epoch, tokens_seen: tokens, loss: loss}
      end

    layers = [epoch_layer(series_rows), tokens_layer(token_rows)]

    chart =
      Vl.new(width: 500, height: 300)
      |> Vl.layers(layers)
      # Each layer keeps its own x scale (epochs vs. tokens); the loss axis is shared.
      |> Vl.resolve(:scale, x: :independent, y: :shared)

    Kino.VegaLite.new(chart)
  end

  # Visible layer: loss against epochs, coloured and dashed per series.
  defp epoch_layer(rows) do
    x_opts = [type: :quantitative, title: "Epochs", axis: [tick_min_step: 1]]
    series_opts = [type: :nominal, title: nil]

    Vl.new()
    |> Vl.data_from_values(rows)
    |> Vl.mark(:line)
    |> Vl.encode_field(:x, "epoch", x_opts)
    |> Vl.encode_field(:y, "loss", type: :quantitative, title: "Loss")
    |> Vl.encode_field(:color, "series", series_opts)
    |> Vl.encode_field(:stroke_dash, "series", series_opts)
  end

  # Invisible layer (opacity 0): exists only to draw the "Tokens seen" axis on top.
  defp tokens_layer(rows) do
    x_opts = [type: :quantitative, title: "Tokens seen", axis: [orient: "top"]]

    Vl.new()
    |> Vl.data_from_values(rows)
    |> Vl.mark(:line, opacity: 0)
    |> Vl.encode_field(:x, "tokens_seen", x_opts)
    |> Vl.encode_field(:y, "loss", type: :quantitative, title: "Loss")
  end
end

# Evenly spaced epoch positions from 0 to num_epochs, one per recorded training
# loss, so the losses can be plotted against an epoch axis.
epochs_seen =
  case length(train_losses) do
    # No recorded losses: without this clause `0..(count - 1)` would be the
    # decreasing range 0..-1 and yield two bogus positions.
    0 ->
      []

    # A single point cannot be spaced; avoid dividing by `count - 1 == 0`.
    1 ->
      [0.0]

    count ->
      last_index = count - 1
      Enum.map(0..last_index, fn step -> step * num_epochs / last_index end)
  end

LossPlot.plot_losses(epochs_seen, tokens_seen, train_losses, val_losses)