
Softmax Shootout: Softmax vs SSMax vs Softpick

notebooks/softmax_shootout.livemd

Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server). Uncomment the EXLA lines for GPU acceleration.

edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:kino, "~> 0.14"}
])

# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh. See the Architecture Zoo notebook for full setup instructions.

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")

Introduction

Every transformer model uses attention to decide which parts of a sequence matter most for the current prediction. At the core of attention is a normalization function that turns raw relevance scores into weights. The function everybody knows is softmax — but it’s not the only option, and it has a well-known weakness: as sequences get longer, softmax spreads attention thinner and thinner, making it harder for the model to focus.

This notebook puts 3 normalization functions through the same gauntlet — same model, same data, same training loop — so you can see exactly how they differ in behavior, learning speed, and final quality.

What you’ll learn:

  • What softmax, SSMax, and Softpick actually compute — with formulas, intuition, and interactive visualizations you can poke at
  • Why standard softmax struggles with long sequences (the “dilution problem”) and how SSMax was designed to fix it
  • How Softpick takes a completely different approach — no exponentials at all — and what tradeoffs that creates
  • How to run a fair architecture comparison where the only variable is the normalization function
  • When you might choose one over another in real projects

The 3 contenders:

Function   Formula                                        Key Property
Softmax    exp(x_i) / sum(exp(x_j))                       The standard — all positive, sums to 1
SSMax      exp(s*log(n)*x_i) / sum(exp(s*log(n)*x_j))     Adapts to sequence length n
Softpick   x_i / (1 + sum(|x_j|))                         No exponentials, preserves sign

What Are Attention Scores?

Before we compare the normalization functions, let’s understand what they’re normalizing. In a transformer, attention works in three steps:

  1. Compute scores: For each position in the sequence, compute a “relevance score” against every other position. High score = “this position is relevant to me.” These scores are just dot products of learned query and key vectors.

  2. Normalize scores: Turn raw scores into weights that the model uses to combine information. This is where softmax/SSMax/Softpick come in.

  3. Weighted sum: Use the weights to combine value vectors — positions with higher weights contribute more to the output.

The normalization step is critical because raw dot-product scores can be any real number (positive or negative, large or small). The normalizer must turn them into something useful. Different normalizers make different choices about what “useful” means:

  • Softmax: “Weights must be positive and sum to 1” (a probability distribution)
  • SSMax: “Same as softmax, but adjust sharpness based on how long the sequence is”
  • Softpick: “Weights should be bounded but can be negative, and don’t need to sum to 1”
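Before reaching for the library versions, the three formulas can be sketched in plain Elixir on a list of floats. This is a toy sketch, not the Edifice implementations — those operate on Nx tensors and handle batching and masking:

```elixir
# Minimal sketch of the three normalizers on plain lists — formulas only.
defmodule NormalizerSketch do
  # Softmax: exp(x_i) / sum(exp(x_j)). Subtracting the max first is the
  # standard numerical-stability trick; it does not change the result.
  def softmax(xs) do
    m = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # SSMax: softmax applied to scores scaled by s * log(n),
  # where n is the sequence length. Larger n => larger scale => sharper.
  def ssmax(xs, s \\ 1.0) do
    scale = s * :math.log(length(xs))
    softmax(Enum.map(xs, &(&1 * scale)))
  end

  # Softpick: x_i / (1 + sum(|x_j|)) — no exponentials, sign preserved,
  # and the weights need not sum to 1.
  def softpick(xs) do
    denom = 1.0 + Enum.sum(Enum.map(xs, &abs/1))
    Enum.map(xs, &(&1 / denom))
  end
end

scores = [0.5, -0.3, 2.0]
IO.inspect(NormalizerSketch.softmax(scores), label: "softmax")
IO.inspect(NormalizerSketch.ssmax(scores), label: "ssmax")
IO.inspect(NormalizerSketch.softpick(scores), label: "softpick")
```

Note that softpick keeps the -0.3 score negative in its output — something neither softmax variant can do.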

Visualize the Functions

Let’s start by seeing what these functions actually do to a vector of scores. No training needed here — just raw computation and plotting.

We’ll create a vector of scores (as if a position in a sequence has computed how relevant every other position is) and pass it through all three functions.

What to look for: Pay attention to how each function distributes weight. Softmax is “winner-take-most” — the highest score gets a disproportionately large weight because of the exponential. SSMax is similar but can be sharper or softer depending on its learned parameter. Softpick is more linear — differences in scores map more proportionally to differences in weights.

# Create a sample attention score vector — imagine these are
# the relevance scores from one query position to 10 key positions.
# Most scores are low (noise), but positions 3 and 7 have high scores (signal).
raw_scores = Nx.tensor([0.5, -0.3, 0.1, 2.5, -0.1, 0.3, 0.0, 2.0, -0.5, 0.2])
seq_len = 10

IO.puts("Raw attention scores: #{inspect(Nx.to_list(raw_scores))}")
IO.puts("These represent how relevant 10 positions are to one query position.")
IO.puts("Positions 3 and 7 have high scores (2.5 and 2.0) — they're the signal.\n")

# 1. Standard softmax: exp(x_i) / sum(exp(x_j))
# This is what every vanilla transformer uses. It turns scores into a
# probability distribution — all positive values that sum to exactly 1.
softmax_weights = Axon.Activations.softmax(raw_scores)

# 2. SSMax with s=1.0 (default initialization)
# SSMax scales every score by s*log(n) before applying softmax.
# (Subtracting a constant would change nothing — softmax is shift-invariant.)
# As n grows, the scale factor grows, sharpening the distribution and
# counteracting the dilution that comes with many positions.
# Here n=10 (sequence length), so the scale factor is 1.0 * log(10) ≈ 2.3
ssmax_weights = Edifice.Blocks.SSMax.compute(raw_scores, 1.0, seq_len)

# 3. Softpick: x_i / (1 + sum(|x_j|))
# No exponentials at all! Just divide each score by the total magnitude.
# This means negative scores stay negative — a fundamentally different
# approach where attention weights can say "this position is anti-relevant."
softpick_weights = Edifice.Blocks.Softpick.compute(raw_scores)

IO.puts("Softmax weights:  #{inspect(Nx.to_list(softmax_weights) |> Enum.map(&Float.round(&1, 4)))}")
IO.puts("Sum = #{Nx.sum(softmax_weights) |> Nx.to_number() |> Float.round(4)}")
IO.puts("")
IO.puts("SSMax weights:    #{inspect(Nx.to_list(ssmax_weights) |> Enum.map(&Float.round(&1, 4)))}")
IO.puts("Sum = #{Nx.sum(ssmax_weights) |> Nx.to_number() |> Float.round(4)}")
IO.puts("")
IO.puts("Softpick weights: #{inspect(Nx.to_list(softpick_weights) |> Enum.map(&Float.round(&1, 4)))}")
IO.puts("Sum = #{Nx.sum(softpick_weights) |> Nx.to_number() |> Float.round(4)}")
IO.puts("")
IO.puts("Notice: Softmax and SSMax sum to 1.0 (probability distributions).")
IO.puts("Softpick does NOT sum to 1 — and it can have negative values!")

Now let’s visualize these side by side. A chart makes the differences immediately obvious.

# Build visualization data — one row per (function, position) pair
position_labels = Enum.map(0..9, &"pos #{&1}")

viz_data =
  Enum.flat_map(0..9, fn i ->
    [
      %{
        "Position" => Enum.at(position_labels, i),
        "Weight" => Nx.to_number(softmax_weights[i]) |> Float.round(4),
        "Function" => "Softmax",
        "Raw Score" => Nx.to_number(raw_scores[i]) |> Float.round(2)
      },
      %{
        "Position" => Enum.at(position_labels, i),
        "Weight" => Nx.to_number(ssmax_weights[i]) |> Float.round(4),
        "Function" => "SSMax (s=1.0)",
        "Raw Score" => Nx.to_number(raw_scores[i]) |> Float.round(2)
      },
      %{
        "Position" => Enum.at(position_labels, i),
        "Weight" => Nx.to_number(softpick_weights[i]) |> Float.round(4),
        "Function" => "Softpick",
        "Raw Score" => Nx.to_number(raw_scores[i]) |> Float.round(2)
      }
    ]
  end)

Vl.new(width: 700, height: 350, title: "Attention Weight Distribution — Same Scores, Different Normalizers")
|> Vl.data_from_values(viz_data)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "Position", type: :nominal, sort: position_labels, title: "Sequence Position")
|> Vl.encode_field(:y, "Weight", type: :quantitative, title: "Attention Weight")
|> Vl.encode_field(:color, "Function", type: :nominal)
|> Vl.encode_field(:x_offset, "Function", type: :nominal)

Concentrated vs Diffuse Inputs

The differences become more dramatic with different input distributions. Let’s try a “concentrated” input (one position dominates) and a “diffuse” input (all positions are similar).

What to look for: With concentrated inputs, all three functions should agree — when one position clearly dominates, any normalizer will put most weight there. With diffuse inputs, the functions diverge: softmax will spread weight nearly uniformly, SSMax will try to maintain some sharpness, and Softpick will preserve the relative ordering more faithfully.

# Concentrated: one position has a very high score, rest are near zero.
# This simulates "one token is clearly the most important."
concentrated = Nx.tensor([0.1, 0.0, 0.2, 5.0, 0.1, -0.1, 0.0, 0.1, 0.0, 0.1])

# Diffuse: all positions have similar scores.
# This simulates "everything is equally relevant" — the hard case for softmax.
diffuse = Nx.tensor([1.0, 0.9, 1.1, 1.0, 0.8, 1.2, 0.9, 1.1, 1.0, 0.95])

# Compute all 6 combinations (2 inputs x 3 functions)
scenarios = [
  {"Concentrated", concentrated},
  {"Diffuse", diffuse}
]

charts =
  Enum.map(scenarios, fn {scenario_name, scores} ->
    sm = Axon.Activations.softmax(scores)
    ss = Edifice.Blocks.SSMax.compute(scores, 1.0, 10)
    sp = Edifice.Blocks.Softpick.compute(scores)

    data =
      Enum.flat_map(0..9, fn i ->
        [
          %{"Position" => "pos #{i}", "Weight" => Nx.to_number(sm[i]) |> Float.round(4),
            "Function" => "Softmax"},
          %{"Position" => "pos #{i}", "Weight" => Nx.to_number(ss[i]) |> Float.round(4),
            "Function" => "SSMax"},
          %{"Position" => "pos #{i}", "Weight" => Nx.to_number(sp[i]) |> Float.round(4),
            "Function" => "Softpick"}
        ]
      end)

    Vl.new(width: 350, height: 250, title: "#{scenario_name} Inputs")
    |> Vl.data_from_values(data)
    |> Vl.mark(:bar)
    |> Vl.encode_field(:x, "Position", type: :nominal, sort: position_labels)
    |> Vl.encode_field(:y, "Weight", type: :quantitative)
    |> Vl.encode_field(:color, "Function", type: :nominal)
    |> Vl.encode_field(:x_offset, "Function", type: :nominal)
  end)

Vl.new()
|> Vl.concat(charts, :horizontal)

Sequence Length Sensitivity

This is the key insight that motivated SSMax: standard softmax becomes increasingly diffuse (spreads attention more uniformly) as the sequence gets longer. Here’s why:

Imagine you have 10 positions and one has a slightly higher score. Softmax will give it a reasonable share of the total weight. Now imagine you have 1000 positions — that same score advantage gets diluted because there are so many more positions dividing up the probability mass.

SSMax was designed specifically to counteract this. Its formula scales the raw scores by s * log(n), where n is the sequence length. As the sequence gets longer, this factor grows, effectively “sharpening” the attention distribution to compensate for the dilution.
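The dilution is easy to see numerically. Here is a plain-Elixir sketch (independent of the library code): one score of 2.0 among zeros, softmax'd at n = 10 and n = 1000, with and without the s·log(n) scaling:

```elixir
# Dilution sketch: one "signal" score of 2.0 among (n - 1) zeros.
# We only track the weight the signal position receives.
softmax_max = fn scores ->
  exps = Enum.map(scores, &:math.exp/1)
  Enum.max(exps) / Enum.sum(exps)
end

for n <- [10, 1000] do
  scores = [2.0 | List.duplicate(0.0, n - 1)]

  # Standard softmax: exp(2) / (exp(2) + (n - 1)) — shrinks as n grows
  plain = softmax_max.(scores)

  # SSMax-style scaling with s = 1.0: multiply scores by log(n) first
  scaled = softmax_max.(Enum.map(scores, &(&1 * :math.log(n))))

  IO.puts("n=#{n}: softmax max weight = #{Float.round(plain, 4)}, " <>
          "with s*log(n) scaling = #{Float.round(scaled, 4)}")
end
```

With standard softmax the signal's weight collapses as n grows (the same exp(2) is divided among ever more positions), while the scaled version holds its focus.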

What to look for: In the chart below, watch what happens to the maximum attention weight as sequence length increases. For softmax, the max weight should steadily decrease (attention gets more uniform). For SSMax with s=1, the max weight should remain more stable. Softpick’s behavior depends on the score magnitudes relative to the total absolute sum.

# Test how the maximum attention weight changes as sequence length grows.
# We keep the same "signal" pattern — one high-scoring position among noise —
# and add more noise positions as the sequence gets longer.
seq_lengths = [8, 16, 32, 64, 128, 256]

length_data =
  Enum.flat_map(seq_lengths, fn n ->
    # Create scores: one "signal" position (score=3.0), rest are small random-ish noise.
    # We use a deterministic pattern so results are reproducible.
    noise = Nx.iota({n}) |> Nx.multiply(0.1) |> Nx.cos() |> Nx.multiply(0.5)
    # Put a strong signal at position 0
    signal = Nx.broadcast(3.0, {1})
    scores = Nx.concatenate([signal, noise[1..-1//1]])

    # Compute each normalizer
    sm = Axon.Activations.softmax(scores)
    ss = Edifice.Blocks.SSMax.compute(scores, 1.0, n)
    sp = Edifice.Blocks.Softpick.compute(scores)

    # Record the max weight — how much attention goes to the "signal" position
    [
      %{"Seq Length" => n,
        "Max Weight" => Nx.reduce_max(sm) |> Nx.to_number() |> Float.round(4),
        "Function" => "Softmax"},
      %{"Seq Length" => n,
        "Max Weight" => Nx.reduce_max(ss) |> Nx.to_number() |> Float.round(4),
        "Function" => "SSMax (s=1.0)"},
      %{"Seq Length" => n,
        "Max Weight" => Nx.reduce_max(sp) |> Nx.to_number() |> Float.round(4),
        "Function" => "Softpick"}
    ]
  end)

IO.puts("How well can each function focus on a signal position as sequence length grows?")
IO.puts("Higher max weight = better focus on the important position.\n")

# Print a quick table
IO.puts(
  String.pad_trailing("  Seq Len", 12) <>
  String.pad_trailing("Softmax", 12) <>
  String.pad_trailing("SSMax", 12) <>
  "Softpick"
)
IO.puts("  " <> String.duplicate("-", 44))

Enum.chunk_every(length_data, 3)
|> Enum.each(fn [sm, ss, sp] ->
  IO.puts(
    String.pad_trailing("  #{sm["Seq Length"]}", 12) <>
    String.pad_trailing("#{sm["Max Weight"]}", 12) <>
    String.pad_trailing("#{ss["Max Weight"]}", 12) <>
    "#{sp["Max Weight"]}"
  )
end)

Vl.new(width: 600, height: 350,
  title: "Max Attention Weight vs Sequence Length (higher = better focus)")
|> Vl.data_from_values(length_data)
|> Vl.mark(:line, point: true, stroke_width: 2)
|> Vl.encode_field(:x, "Seq Length", type: :quantitative, scale: %{type: "log"}, title: "Sequence Length (log scale)")
|> Vl.encode_field(:y, "Max Weight", type: :quantitative, title: "Max Attention Weight")
|> Vl.encode_field(:color, "Function", type: :nominal)
|> Vl.encode_field(:stroke_dash, "Function", type: :nominal)

Build the Contenders

Now we build 3 transformer models that are identical in every way except the normalization function used inside attention. Same hidden size, same number of heads, same number of layers, same dropout — the only variable is whether attention weights come from softmax, SSMax, or softpick.

This is what makes the comparison fair: any difference in performance must come from the normalization function, not from model size or architecture choices.

Each model takes input shaped [batch, seq_len, embed_dim] — a batch of sequences where each position is represented by an embed_dim-dimensional vector. The model outputs [batch, hidden_size] — one vector per sequence that summarizes the model’s understanding. We then add a classification head on top that maps this to class probabilities.

# Shared hyperparameters — identical for every model.
# We use small values so training is fast on CPU or GPU.
hidden_size = 64       # Internal dimension of the transformer
num_heads = 4          # Number of attention heads (each head attends independently)
num_layers = 2         # How many transformer blocks to stack
seq_len = 32           # Length of input sequences
num_classes = 4        # Number of classes for our synthetic task
dropout = 0.05         # Small amount of regularization

# These options are shared by all 3 models
shared_opts = [
  embed_dim: hidden_size,  # Input already matches hidden_size (no projection needed)
  hidden_size: hidden_size,
  num_heads: num_heads,
  num_layers: num_layers,
  window_size: seq_len,
  dropout: dropout
]

# Build all 3 models using Edifice.build — the same API used throughout Edifice.
# The only thing that changes is the atom: :attention, :ssmax, or :softpick.

# 1. Standard softmax attention — the baseline
softmax_model =
  Edifice.build(:attention, shared_opts)
  |> Axon.dense(num_classes, name: "classifier_softmax")

# 2. SSMax attention — sequence-length-aware softmax
ssmax_model =
  Edifice.build(:ssmax, shared_opts)
  |> Axon.dense(num_classes, name: "classifier_ssmax")

# 3. Softpick attention — non-saturating, sign-preserving normalization
softpick_model =
  Edifice.build(:softpick, shared_opts)
  |> Axon.dense(num_classes, name: "classifier_softpick")

IO.puts("Built 3 models with identical architecture:")
IO.puts("  - hidden_size: #{hidden_size}")
IO.puts("  - num_heads: #{num_heads}")
IO.puts("  - num_layers: #{num_layers}")
IO.puts("  - seq_len: #{seq_len}")
IO.puts("  - num_classes: #{num_classes}")
IO.puts("\nThe ONLY difference is the attention normalization function.")
IO.puts("Each model: input {batch, #{seq_len}, #{hidden_size}} -> output {batch, #{num_classes}}")

Generate Training Data

We need a task where attention quality actually matters. If the task were too easy (like “is the first element positive?”), any normalizer would work fine. We want a task that requires the model to find and attend to specific positions in the sequence.

Our synthetic task: Each sequence has a hidden “signal” — a specific pattern of values placed at a random position in the sequence. The rest of the sequence is noise. The signal pattern determines the class label (0, 1, 2, or 3). The model must:

  1. Scan the entire sequence to find where the signal is
  2. Focus attention on those positions (ignore the noise)
  3. Classify based on what the signal pattern contains

This directly tests the quality of attention: a model with better normalization should find and focus on the signal more effectively.

# Synthetic data generator for sequence classification.
# Creates sequences where a "signal" pattern (determining the class) is
# embedded at a random position within noise.

defmodule SignalData do
  @moduledoc """
  Generates synthetic classification data that requires attention to solve.

  Each sequence is mostly noise, with a small "signal" pattern hidden at
  a random position. The signal pattern determines the class label.
  The model must learn to find and attend to the signal to classify correctly.
  """

  @doc """
  Generate a dataset of sequences with hidden signals.

  ## Parameters
    - `n` - Number of sequences to generate
    - `seq_len` - Length of each sequence
    - `dim` - Dimensionality of each position's feature vector
    - `num_classes` - Number of distinct signal patterns (classes)
    - `signal_len` - Length of the signal pattern within the sequence

  ## Returns
    `{x, y}` where:
    - `x` has shape `{n, seq_len, dim}` — the input sequences
    - `y` has shape `{n}` — integer class labels
  """
  def generate(n, seq_len, dim, num_classes, signal_len \\ 3) do
    key = Nx.Random.key(42)

    # Step 1: Generate random noise for all sequences.
    # This is the "background" that the model must learn to ignore.
    {noise, key} = Nx.Random.normal(key, 0.0, 0.3, shape: {n, seq_len, dim})

    # Step 2: Create distinct signal patterns — one per class.
    # Each signal is a small block of specific values that the model
    # must learn to recognize. We make them distinct enough that a
    # model with good attention should be able to tell them apart.
    signal_patterns =
      Enum.map(0..(num_classes - 1), fn class_id ->
        # Each class gets a unique pattern: different dimensions are "activated"
        # This creates orthogonal-ish patterns that are clearly distinguishable
        pattern = Nx.broadcast(0.0, {signal_len, dim})

        # Activate different dimension ranges for each class
        start_dim = rem(class_id * div(dim, num_classes), dim)
        activation = 2.0 * (1 + class_id * 0.3)

        Enum.reduce(0..(signal_len - 1), pattern, fn t, acc ->
          # Each timestep of the signal has a slightly different activation
          dim_idx = rem(start_dim + t * 3, dim)
          indices = Nx.tensor([[t, dim_idx]])
          updates = Nx.tensor([activation])
          Nx.indexed_put(acc, indices, updates)
        end)
      end)

    # Step 3: For each sequence, pick a random class and a random position,
    # then insert that class's signal pattern at that position.
    {class_labels, key} = Nx.Random.randint(key, 0, num_classes, shape: {n}, type: :s32)
    max_start = seq_len - signal_len
    {positions, _key} = Nx.Random.randint(key, 0, max_start + 1, shape: {n}, type: :s32)

    # Step 4: Insert signals into the noise sequences.
    # This is the tricky part — we place each class's signal at a random position.
    x =
      Enum.reduce(0..(n - 1), noise, fn i, acc ->
        class_id = Nx.to_number(class_labels[i])
        pos = Nx.to_number(positions[i])
        signal = Enum.at(signal_patterns, class_id)

        # Insert signal at the chosen position
        Enum.reduce(0..(signal_len - 1), acc, fn t, inner_acc ->
          # For each timestep of the signal, add it to the noise at pos+t.
          # Nx.put_slice writes whole rows, so we work on the flattened
          # {n * seq_len, dim} view: read the row, add the signal, write it
          # back by flat index, then restore the original shape.
          signal_row = signal[t]
          current = inner_acc[i][pos + t]
          new_row = Nx.add(current, signal_row)
          flat_idx = i * seq_len + (pos + t)
          flat_acc = Nx.reshape(inner_acc, {:auto, dim})
          flat_acc = Nx.put_slice(flat_acc, [flat_idx, 0], Nx.reshape(new_row, {1, dim}))
          Nx.reshape(flat_acc, {n, seq_len, dim})
        end)
      end)

    {x, class_labels}
  end
end

# Generate training and test data
n_train = 800
n_test = 200
signal_len = 3  # Signal is only 3 positions long in a 32-position sequence

IO.puts("Generating synthetic sequence classification data...")
IO.puts("  Training samples: #{n_train}")
IO.puts("  Test samples: #{n_test}")
IO.puts("  Sequence length: #{seq_len}")
IO.puts("  Feature dimension: #{hidden_size}")
IO.puts("  Number of classes: #{num_classes}")
IO.puts("  Signal length: #{signal_len} (hidden in #{seq_len} positions of noise)")
IO.puts("")

{train_x, train_y} = SignalData.generate(n_train, seq_len, hidden_size, num_classes, signal_len)
{test_x, test_y} = SignalData.generate(n_test, seq_len, hidden_size, num_classes, signal_len)

IO.puts("Train X shape: #{inspect(Nx.shape(train_x))}  (samples, seq_len, features)")
IO.puts("Train Y shape: #{inspect(Nx.shape(train_y))}  (samples,) — class labels 0-#{num_classes - 1}")
IO.puts("Test X shape:  #{inspect(Nx.shape(test_x))}")
IO.puts("Test Y shape:  #{inspect(Nx.shape(test_y))}")

# Show class distribution to verify balance
IO.puts("\nClass distribution (training):")
Enum.each(0..(num_classes - 1), fn c ->
  count = Nx.equal(train_y, c) |> Nx.sum() |> Nx.to_number()
  IO.puts("  Class #{c}: #{count} samples")
end)

Now we prepare the data for training. We one-hot encode the class labels (turning class 2 into [0, 0, 1, 0]) and batch everything for efficient GPU processing.

batch_size = 32

# One-hot encode targets: class 2 -> [0, 0, 1, 0]
# Neural networks output a score for each class, so we need targets
# in the same format for the loss function to compare against.
train_y_onehot =
  train_y
  |> Nx.reshape({n_train, 1})
  |> Nx.equal(Nx.iota({1, num_classes}))
  |> Nx.as_type(:f32)

test_y_onehot =
  test_y
  |> Nx.reshape({n_test, 1})
  |> Nx.equal(Nx.iota({1, num_classes}))
  |> Nx.as_type(:f32)

# Batch training data — the model processes batch_size sequences at once.
# This is more efficient than processing one at a time because GPUs are
# designed for parallel computation.
train_data =
  Enum.zip(
    Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
    Nx.to_batched(train_y_onehot, batch_size) |> Enum.to_list()
  )

IO.puts("Prepared #{length(train_data)} training batches of size #{batch_size}")
IO.puts("Target encoding: one-hot vectors of size #{num_classes}")
IO.puts("  Example: class 2 -> #{inspect(Nx.to_list(test_y_onehot[0]))}")

Shared Training Infrastructure

All 3 models use the exact same training and evaluation code. The only thing that changes is which Axon model graph we pass in. This is essential for a fair comparison — we don’t want differences in optimizer settings, learning rates, or evaluation code to bias the results.

How training works (for newcomers):

  1. Forward pass: Feed a batch of sequences through the model, get predicted class probabilities
  2. Compute loss: Compare predictions to true labels using cross-entropy loss. Cross-entropy measures how “surprised” the model is by the correct answer — lower loss means the model’s predictions are closer to reality
  3. Backward pass: Compute gradients — how should each parameter change to reduce the loss?
  4. Update parameters: Use the Adam optimizer to adjust parameters in the direction that reduces loss
  5. Repeat: Do this for every batch, for every epoch
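Steps 1–4 can be shown on a toy one-parameter model in plain Elixir, with the gradient worked out by hand. This is only the shape of the idea — the real loop below uses Axon with the Adam optimizer, which adds momentum and per-parameter scaling on top of the plain update shown here:

```elixir
# Toy version of one training step: predict y = w * x, squared-error loss.
x = 3.0
y_true = 6.0          # the "label" — the true answer would be w = 2
w = 0.5               # initial parameter

# 1. Forward pass
y_pred = w * x

# 2. Loss: (y_pred - y_true)^2
loss = :math.pow(y_pred - y_true, 2)

# 3. Backward pass: d(loss)/dw = 2 * (y_pred - y_true) * x
grad = 2.0 * (y_pred - y_true) * x

# 4. Update: plain gradient descent, step against the gradient
lr = 0.01
w_new = w - lr * grad

IO.puts("loss: #{loss}, grad: #{grad}, w: #{w} -> #{w_new}")
```

One step moves w from 0.5 toward 2.0; repeating steps 1–4 over batches and epochs is all a training loop does.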
defmodule ShootoutTrainer do
  @moduledoc """
  Shared training and evaluation for the softmax shootout.

  Every model gets identical treatment: same loss function, same optimizer,
  same learning rate, same number of epochs. The only variable is the
  model architecture (specifically, the attention normalization function).
  """

  @doc """
  Train a model and collect per-epoch metrics.

  Returns a map with:
    - `:state` — the trained model parameters
    - `:losses` — list of mean training loss per epoch
    - `:time_s` — wall-clock training time in seconds
    - `:name` — display name of this model
  """
  def train(model, train_data, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 10)
    lr = Keyword.get(opts, :lr, 3.0e-4)
    name = Keyword.get(opts, :name, "model")

    # Cross-entropy loss: the standard loss for classification.
    # It measures how well the model's predicted probability distribution
    # matches the true label. targets (&1) first, predictions (&2) second.
    loss_fn = &Axon.Losses.categorical_cross_entropy(&1, &2,
      from_logits: true, reduction: :mean)

    # We track loss per epoch using the process dictionary.
    # This lets us build loss curves for visualization later.
    Process.put(:epoch_losses, [])

    start_time = System.monotonic_time(:millisecond)

    state =
      model
      |> Axon.Loop.trainer(loss_fn,
        Polaris.Optimizers.adam(learning_rate: lr), log: 0)
      |> Axon.Loop.handle_event(:epoch_completed, fn loop_state ->
        loss_val = Nx.to_number(loop_state.metrics["loss"])
        epoch_num = length(Process.get(:epoch_losses)) + 1
        Process.put(:epoch_losses, Process.get(:epoch_losses) ++ [loss_val])
        IO.puts("  #{name} epoch #{epoch_num}/#{epochs} — loss: #{Float.round(loss_val, 4)}")
        {:continue, loop_state}
      end)
      |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: epochs)

    elapsed_ms = System.monotonic_time(:millisecond) - start_time
    losses = Process.get(:epoch_losses)
    Process.delete(:epoch_losses)

    %{state: state, losses: losses, time_s: elapsed_ms / 1000, name: name}
  end

  @doc """
  Evaluate a trained model on test data.

  Returns `{accuracy, loss}` where:
    - `accuracy` is the fraction of correct predictions (0.0 to 1.0)
    - `loss` is the mean cross-entropy loss on the test set
  """
  def evaluate(model, state, test_x, test_y_ids, num_classes) do
    {_init_fn, predict_fn} = Axon.build(model)
    logits = predict_fn.(state, test_x)

    # Accuracy: how often does the model's top prediction match the true class?
    preds = Nx.argmax(logits, axis: -1)
    accuracy = Nx.equal(preds, test_y_ids) |> Nx.mean() |> Nx.to_number()

    # Cross-entropy loss on test set (manual computation)
    targets_onehot =
      test_y_ids
      |> Nx.reshape({:auto, 1})
      |> Nx.equal(Nx.iota({1, num_classes}))
      |> Nx.as_type(:f32)

    loss =
      logits
      |> Axon.Activations.softmax()
      |> Nx.max(1.0e-7)
      |> Nx.log()
      |> Nx.multiply(targets_onehot)
      |> Nx.sum(axes: [-1])
      |> Nx.negate()
      |> Nx.mean()
      |> Nx.to_number()

    {accuracy, loss}
  end
end

IO.puts("Training infrastructure ready.")
IO.puts("All models will use:")
IO.puts("  - Loss: categorical cross-entropy")
IO.puts("  - Optimizer: Adam")
IO.puts("  - Learning rate: 3e-4")

Train & Compare

Now we train all 3 models sequentially with identical settings. Watch the per-epoch loss values as they print — you can already see which normalizer learns faster.

What to look for:

  • Early epochs: Which model drops loss fastest? Fast initial learning suggests the normalization function’s inductive bias matches the task well.
  • Later epochs: Which model reaches the lowest final loss? This shows which function enables the best final representation.
  • Stability: Are the loss values smooth or jagged? Jagged values suggest training instability, which can happen if the normalization function creates very sharp or very flat gradients.

epochs = 15  # More epochs than default to see convergence behavior

models = [
  {"Softmax", softmax_model},
  {"SSMax", ssmax_model},
  {"Softpick", softpick_model}
]

IO.puts("Training #{length(models)} models, #{epochs} epochs each...")
IO.puts("This should take 1-5 minutes depending on your hardware.\n")

results =
  Enum.map(models, fn {label, model} ->
    IO.puts("\n#{String.duplicate("=", 50)}")
    IO.puts("Training #{label}")
    IO.puts(String.duplicate("-", 50))

    result = ShootoutTrainer.train(model, train_data,
      epochs: epochs,
      lr: 3.0e-4,
      name: label
    )

    IO.puts("  Finished in #{Float.round(result.time_s, 1)}s")
    {label, model, result}
  end)

IO.puts("\n#{String.duplicate("=", 50)}")
IO.puts("All #{length(results)} models trained!")

Results — Summary Table

Let’s evaluate all 3 models on the held-out test set and see how they compare.

Understanding the metrics:

  • Test Accuracy: The percentage of test sequences the model classifies correctly. Random guessing on 4 classes = 25%. Anything above that means the model learned something. Higher is better.
  • Test Loss: Cross-entropy loss on the test set. Lower is better. While accuracy is “did you get it right?”, loss is “how confident were you?” — a model that’s right but uncertain has higher loss than one that’s right and confident.
  • Train Time: Wall-clock seconds for all epochs. This reflects both the computational cost of the normalization function and any overhead from learnable parameters (SSMax has an extra s parameter to optimize).
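The accuracy-vs-loss distinction is concrete in a small plain-Elixir cross-entropy check. The probabilities below are made up for illustration, not model output — both predictions pick class 0, so both count as correct for accuracy, but only one is confident:

```elixir
# Cross-entropy of a single prediction against a one-hot target:
# -log(probability assigned to the true class).
cross_entropy = fn probs, true_class ->
  -:math.log(Enum.at(probs, true_class))
end

confident = [0.90, 0.04, 0.03, 0.03]   # right and sure
uncertain = [0.40, 0.20, 0.20, 0.20]   # right but hesitant

# Same accuracy (argmax = 0 for both), very different loss.
IO.puts("confident: #{Float.round(cross_entropy.(confident, 0), 4)}")
IO.puts("uncertain: #{Float.round(cross_entropy.(uncertain, 0), 4)}")
```

The hesitant prediction pays a loss roughly nine times higher even though the accuracy metric treats the two identically.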
metrics =
  Enum.map(results, fn {label, model, result} ->
    {accuracy, test_loss} =
      ShootoutTrainer.evaluate(model, result.state, test_x, test_y, num_classes)

    %{
      label: label,
      test_accuracy: accuracy,
      test_loss: test_loss,
      train_time_s: result.time_s,
      final_train_loss: List.last(result.losses) || 0.0,
      losses: result.losses
    }
  end)

# Print a formatted comparison table
IO.puts(String.duplicate("=", 68))
IO.puts(
  String.pad_trailing("  Normalizer", 16) <>
  String.pad_trailing("Test Acc", 12) <>
  String.pad_trailing("Test Loss", 12) <>
  String.pad_trailing("Train Loss", 14) <>
  "Train Time"
)
IO.puts("  " <> String.duplicate("-", 62))

Enum.each(metrics, fn m ->
  IO.puts(
    String.pad_trailing("  #{m.label}", 16) <>
    String.pad_trailing("#{Float.round(m.test_accuracy * 100, 1)}%", 12) <>
    String.pad_trailing("#{Float.round(m.test_loss, 4)}", 12) <>
    String.pad_trailing("#{Float.round(m.final_train_loss, 4)}", 14) <>
    "#{Float.round(m.train_time_s, 1)}s"
  )
end)

IO.puts(String.duplicate("=", 68))

# Highlight winners
best_acc = Enum.max_by(metrics, & &1.test_accuracy)
best_loss = Enum.min_by(metrics, & &1.test_loss)
fastest = Enum.min_by(metrics, & &1.train_time_s)

IO.puts("\nBest accuracy:  #{best_acc.label} (#{Float.round(best_acc.test_accuracy * 100, 1)}%)")
IO.puts("Lowest loss:    #{best_loss.label} (#{Float.round(best_loss.test_loss, 4)})")
IO.puts("Fastest:        #{fastest.label} (#{Float.round(fastest.train_time_s, 1)}s)")

Results — Accuracy & Loss Charts

Bar charts make the comparison immediately visual. The accuracy chart shows “who got the most answers right” and the loss chart shows “who was most confident in the right answers.”

acc_data =
  Enum.map(metrics, fn m ->
    %{"Normalizer" => m.label, "Accuracy (%)" => Float.round(m.test_accuracy * 100, 1)}
  end)

loss_data =
  Enum.map(metrics, fn m ->
    %{"Normalizer" => m.label, "Test Loss" => Float.round(m.test_loss, 4)}
  end)

time_data =
  Enum.map(metrics, fn m ->
    %{"Normalizer" => m.label, "Time (s)" => Float.round(m.train_time_s, 1)}
  end)

# Build three individual charts and concatenate them horizontally
metric_charts = [
  Vl.new(width: 220, height: 250, title: "Test Accuracy (higher is better)")
  |> Vl.data_from_values(acc_data)
  |> Vl.mark(:bar)
  |> Vl.encode_field(:x, "Normalizer", type: :nominal)
  |> Vl.encode_field(:y, "Accuracy (%)", type: :quantitative)
  |> Vl.encode_field(:color, "Normalizer", type: :nominal, legend: nil),

  Vl.new(width: 220, height: 250, title: "Test Loss (lower is better)")
  |> Vl.data_from_values(loss_data)
  |> Vl.mark(:bar)
  |> Vl.encode_field(:x, "Normalizer", type: :nominal)
  |> Vl.encode_field(:y, "Test Loss", type: :quantitative)
  |> Vl.encode_field(:color, "Normalizer", type: :nominal, legend: nil),

  Vl.new(width: 220, height: 250, title: "Training Time (seconds)")
  |> Vl.data_from_values(time_data)
  |> Vl.mark(:bar)
  |> Vl.encode_field(:x, "Normalizer", type: :nominal)
  |> Vl.encode_field(:y, "Time (s)", type: :quantitative)
  |> Vl.encode_field(:color, "Normalizer", type: :nominal, legend: nil)
]

Vl.new()
|> Vl.concat(metric_charts, :horizontal)

Loss Curve Comparison

This chart overlays the training loss for all 3 models across epochs. It’s the most informative single visualization because it shows not just the final result but the entire learning trajectory.

What to look for:

  • Steep initial drop: The model quickly learns basic patterns. Which normalizer figures out the task fastest?
  • Lower final loss: Better at the task overall. Which normalizer enables the model to find the signal most effectively?
  • Smooth vs jagged: Smooth curves mean stable training. Jagged curves suggest the gradient signal is noisy, which can happen if the normalization function creates very sharp or very flat attention distributions.
  • Crossing curves: If model A starts fast but model B catches up and overtakes it, that tells you about the long-run properties of the normalizer versus its short-run inductive bias.
loss_curve_data =
  Enum.flat_map(results, fn {label, _model, result} ->
    result.losses
    |> Enum.with_index(1)
    |> Enum.map(fn {loss, epoch} ->
      %{"Normalizer" => label, "Epoch" => epoch, "Loss" => loss}
    end)
  end)

Vl.new(width: 700, height: 400, title: "Training Loss Curves — Softmax vs SSMax vs Softpick")
|> Vl.data_from_values(loss_curve_data)
|> Vl.mark(:line, point: true, stroke_width: 2)
|> Vl.encode_field(:x, "Epoch", type: :quantitative, title: "Epoch")
|> Vl.encode_field(:y, "Loss", type: :quantitative, title: "Training Loss (Cross-Entropy)")
|> Vl.encode_field(:color, "Normalizer", type: :nominal)
|> Vl.encode_field(:stroke_dash, "Normalizer", type: :nominal)

Key Takeaways

How the functions differ in practice

Softmax is the tried-and-true default. It creates a proper probability distribution (positive weights summing to 1), which has nice theoretical properties. Its weakness is the dilution problem: as sequences get longer, attention becomes more uniform and the model has trouble focusing. For short sequences (like our 32-length examples), this weakness barely shows up.
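The dilution problem is easy to see in numbers. The sketch below (a standalone illustration, not part of the trained model) puts one score of 3.0 among n − 1 zeros: as n grows, softmax spreads the same distribution over more positions and the relevant one loses weight, dropping from roughly 0.74 at n = 8 to under 0.04 at n = 512.

```elixir
# Numerically stable softmax over a 1-D tensor
softmax = fn scores ->
  exps = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores)))
  Nx.divide(exps, Nx.sum(exps))
end

for n <- [8, 64, 512] do
  # One "relevant" score of 3.0, everything else zero
  scores = Nx.tensor([3.0 | List.duplicate(0.0, n - 1)])
  weight = scores |> softmax.() |> Nx.to_flat_list() |> hd()
  IO.puts("n = #{n}: weight on the relevant position = #{Float.round(weight, 4)}")
end
```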

SSMax (Scalable Softmax) is designed specifically to solve the dilution problem. It multiplies scores by s * log(n) before applying softmax, where s is a learnable parameter and n is the sequence length. (Subtracting a constant from every score would change nothing, since softmax is shift-invariant; the correction is multiplicative.) When s * log(n) > 1 — which holds for typical learned s once sequences get long — this sharpens attention, and the sharpening grows with sequence length. SSMax shines when sequence lengths vary a lot during training, or when you need to scale a model to longer contexts than it was trained on. At short fixed sequence lengths (like this notebook), it may not show a dramatic advantage over standard softmax.
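A hedged sketch of that scaling (s is fixed at 1.0 here; in the model it is learned): multiplying the same toy scores by s * log(n) before softmax keeps the relevant position's weight high even as n grows, where plain softmax dilutes away.

```elixir
# Numerically stable softmax over a 1-D tensor
softmax = fn scores ->
  exps = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores)))
  Nx.divide(exps, Nx.sum(exps))
end

top_weight = fn scores -> scores |> softmax.() |> Nx.to_flat_list() |> hd() end

# Fixed here for illustration; learned alongside the model in the real layer
s = 1.0

for n <- [8, 64, 512] do
  scores = Nx.tensor([3.0 | List.duplicate(0.0, n - 1)])
  # SSMax: scale scores by s * log(n) before the softmax
  ssmax_scores = Nx.multiply(scores, s * :math.log(n))

  IO.puts(
    "n = #{n}: softmax = #{Float.round(top_weight.(scores), 4)}, " <>
      "ssmax = #{Float.round(top_weight.(ssmax_scores), 4)}"
  )
end
```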

Softpick takes a fundamentally different approach: no exponentials at all, just x / (1 + sum(|x|)). This means:

  • Attention weights can be negative (anti-attention: “this position is actively irrelevant”)
  • Weights don’t sum to 1 — the model can express “nothing here is very relevant”
  • No saturation — gradients don’t vanish for large scores like they do with softmax’s exponential

The tradeoff is that Softpick doesn’t create a probability distribution, which means the model’s attention patterns are harder to interpret and may behave differently during training.
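Both properties fall straight out of the formula. A minimal sketch using the x / (1 + sum(|x|)) form given above, applied to a toy score vector: the negative score stays negative in the output, and the weights sum to well under 1.

```elixir
# Softpick as described above: no exponentials, just divide by 1 + sum(|x|)
softpick = fn x -> Nx.divide(x, Nx.add(1.0, Nx.sum(Nx.abs(x)))) end

scores = Nx.tensor([2.0, -1.0, 0.5])
weights = softpick.(scores)

# Denominator is 1 + 3.5 = 4.5, so weights are roughly [0.44, -0.22, 0.11]
IO.inspect(Nx.to_flat_list(weights), label: "weights")
# ...and they sum to about 0.33, not 1
IO.inspect(Nx.to_number(Nx.sum(weights)), label: "sum")
```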

When to use each

| Normalizer | Best For | Avoid When |
| --- | --- | --- |
| Softmax | General-purpose, well-understood, great tooling support | Very long sequences where dilution hurts |
| SSMax | Variable-length sequences, scaling to longer contexts, production transformers that need to generalize across lengths | Fixed short sequences (overhead without benefit) |
| Softpick | Experimental/research, tasks where anti-attention helps, gradient flow in very deep networks | You need interpretable attention weights or probability guarantees |

Important caveat

This is a small-scale comparison (32-length sequences, 64-dim models, 800 training samples). At production scale (thousands of tokens, larger models, millions of examples), the differences — especially SSMax’s advantage on long sequences — become much more pronounced. The sequence length sensitivity chart earlier in the notebook gives a hint at what happens at scale.

Try It Yourself

Here are experiments to deepen your understanding. Each one teaches you something different about how these normalizers behave:

  • Longer sequences: Change seq_len to 64 or 128. Does SSMax’s advantage become more visible? The dilution problem gets worse with longer sequences, so this is where SSMax should shine.

  • More epochs: Try 30 or 50 epochs. Some normalizers may be slow starters that eventually overtake faster ones. Do the loss curves cross?

  • Different hidden sizes: Try hidden_size: 32 (smaller, faster) or hidden_size: 128 (larger, more capacity). Does the relative ranking of normalizers change with model capacity?

  • Number of heads: Try num_heads: 1 (single head) vs num_heads: 8 (many heads). More heads means each head has a smaller dimension — does this affect which normalizer works best?

  • Signal difficulty: Make the signal shorter (signal_len: 2) or longer (signal_len: 5). A shorter signal is harder to find — does this change which normalizer performs best?

  • SSMax initialization: The s parameter in SSMax starts at 1.0 by default. You could modify the SSMax module to try different initial values (0.5, 2.0) to see how the initialization affects training.

  • More classes: Try num_classes: 8 or num_classes: 16. More classes require finer discrimination — does this favor any normalizer?