Small Language Model

notebooks/small_language_model.livemd

Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server). Uncomment the EXLA lines for GPU acceleration.

edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:kino, "~> 0.14"}
])

# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh. See the Architecture Zoo notebook for full setup instructions.

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")

Introduction

Have you ever wondered how ChatGPT, Claude, or other AI chatbots generate text? At their core, they all work the same way: predict the next word (or character) given everything that came before it. That’s it. The magic is in doing this prediction really, really well.

This notebook builds a tiny version of that system from scratch. We’ll create a character-level language model — a program that reads text one letter at a time and learns to predict what letter comes next. It’s like teaching a computer to play a never-ending game of “guess the next letter.”

We’ll use Edifice’s DecoderOnly transformer, which is a simplified version of the same architecture family that powers GPT, LLaMA, and Claude. Don’t worry if those names don’t mean anything yet — by the end of this notebook, you’ll understand what’s inside them.

What you’ll learn:

  • Tokenization — How computers turn text (which they can’t directly understand) into numbers (which they can). We’ll do this at the character level: each letter, space, and punctuation mark gets its own number.
  • The transformer architecture — The type of neural network that revolutionized AI. You’ll build one and see how it processes sequences of characters to make predictions.
  • Training a language model — How the model learns from examples by comparing its predictions to the real answers and gradually improving. We use a technique called cross-entropy loss to measure how wrong the model’s guesses are.
  • Text generation — How to use a trained model to write new text, character by character. You’ll experiment with temperature, a knob that controls whether the model plays it safe or takes creative risks.
  • Architecture swapping — How Edifice lets you swap the entire “brain” of the model (from a transformer to a completely different architecture called Mamba) in a single line of code.

We’ll train on excerpts from two classic novels — Emily Brontë’s Wuthering Heights (1847) and James Joyce’s Ulysses (1922) — both in the public domain. Brontë’s gothic prose and Joyce’s modernist stream-of-consciousness give the model two very different writing styles to learn from.

Text Corpus and Vocabulary

A corpus is just a fancy word for “the collection of text we’re going to learn from.” In real-world AI, the corpus might be billions of web pages. Here, we use a few kilobytes of text pasted directly into the notebook — tiny by AI standards, but enough to see the learning process in action.

The model will learn patterns like: after the letter “t”, the letter “h” is very likely (as in “the”, “that”, “this”). After a period, a space and a capital letter usually follow. These are character-level patterns — the building blocks of language.
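These bigram statistics are easy to compute directly. A minimal sketch in plain Elixir (the sample string here is made up for illustration; the real corpus is defined just below):

```elixir
sample = "the cat sat on the mat. the end."

# Overlapping two-character windows, counted.
bigrams =
  sample
  |> String.graphemes()
  |> Enum.chunk_every(2, 1, :discard)
  |> Enum.frequencies()

# Characters that follow "t", most common first.
after_t =
  bigrams
  |> Enum.filter(fn {[first, _next], _count} -> first == "t" end)
  |> Enum.sort_by(fn {_pair, count} -> count end, :desc)

IO.inspect(after_t)
# [{["t", "h"], 3}, {["t", " "], 2}, {["t", "."], 1}]
```

Even in this tiny sample, "h" is the most common successor of "t" — exactly the kind of pattern the model will learn, just with far more context than a single preceding character.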

# Excerpts from public domain novels

wuthering_heights = """
I have just returned from a visit to my landlord the solitary neighbour
that I shall be troubled with. This is certainly a beautiful country.
In all England I do not believe that I could have fixed on a situation
so completely removed from the stir of society. A perfect misanthropist
heaven and Mr Heathcliff and I are such a suitable pair to divide the
desolation between us. A capital fellow. He little imagined how my
heart warmed towards him when I beheld his black eyes withdraw so
suspiciously under their brows as I rode up and when his fingers
sheltered themselves with a jealous resolution still further in his
waistcoat as I announced my name.

Mr Heathcliff I said. A nod was the answer. Mr Lockwood your new
tenant sir. I do myself the honour of calling as soon as possible
after my arrival to express the hope that I have not inconvenienced
you by my perseverance in soliciting the occupation of Thrushcross
Grange. I heard yesterday you had some thoughts.

Thrushcross Grange is my own sir she interrupted wincing. I should
not allow anyone to inconvenience me if I could hinder it. Walk in.

The walk in was uttered with closed teeth and expressed the sentiment
Go to the deuce. Even the gate over which he leant manifested no
sympathising movement to the words and I think that circumstance
determined me to accept the invitation. I felt interested in a man
who seemed more exaggeratedly reserved than myself.

When he saw my horse breast the swollen beck that lay across our path
he spoke quickly Go round with your beast. There was no other entrance
than to obey. I dismounted and leading my horse made my way towards
the dwelling. It was a strange sight at that hour. The front of the
house showed a solid mass of shadow above the level of the ground.

The narrow windows are deeply set in the wall and the corners defended
with large jutting stones. Before passing the threshold I paused to
admire a quantity of grotesque carving lavished over the front and
especially about the principal door above which among a wilderness of
crumbling griffins and shameless little boys I detected the date 1500
and the name Hareton Earnshaw.

I would have made a few comments and requested a short history of the
place from the surly owner but his attitude at the door appeared to
demand my speedy entrance or complete departure and I had no desire
to aggravate his impatience previous to inspecting the penetralium.
"""

ulysses = """
Stately plump Buck Mulligan came from the stairhead bearing a bowl of
lather on which a mirror and a razor lay crossed. A yellow dressing
gown ungirdled was sustained gently behind him on the mild morning
air. He held the bowl aloft and intoned. He peered sideways up and
gave a long slow whistle of call then paused awhile in rapt attention
his even white teeth glistening here and there with gold points.

Two strong shrill whistles answered through the calm. He turned
abruptly his great searching eyes going up from the stairhead towards
the sea. Come up Kinch. Come up you fearful jesuit. Solemnly he came
forward and mounted the round gunrest. He faced about and blessed
gravely thrice the tower the surrounding land and the awaking
mountains. Then catching sight of the nickel shaving bowl he went
over to it and peered down at the water. He shaved evenly and with
care telling of his plans.

For this you may thank the one who calls herself my mother. She
wants me to wear black. I told her I would. Yes I will serve. A
wandering voice said from the stairfoot. He walked on waiting to be
spoken to following the path that turned along the wall towards the
tower but she was waiting down in the sitting room.

The sea was calm. A great sweet mother. The sea fresh and bright
and beautiful and great sweet mother. He turned away from the sea.
Woodshadows floated silently by through the morning peace from the
stairhead seaward where he gazed. Silently moving in the calm water
bearing its own green body above the grey stones gently rising past
everything softly at last the shadows fell across the water.

He could hear them in their room. He crossed to the bright side of
the room. His body seemed to grow taller to reach the height of the
doorway. A moment before he turned away walking down the path. He
looked back across the water towards the tower. The deep blue sky
stretched over the fields and the sea. Morning peace lay on everything.
"""

corpus = wuthering_heights <> "\n" <> ulysses

IO.puts("Corpus size: #{String.length(corpus)} characters")
IO.puts("Preview (first 200 chars):")
IO.puts(String.slice(corpus, 0, 200) <> "...")

Now let’s build the vocabulary. Neural networks can only work with numbers, not letters, so we need a way to convert between the two. We assign every unique character in our text a number (its ID). For example, "a" might become 0, "b" becomes 1, a space becomes 2, and so on. We also build the reverse mapping so we can turn the model’s numeric predictions back into readable text.

# Build character vocabulary from the corpus
chars = corpus |> String.graphemes() |> Enum.uniq() |> Enum.sort()
vocab_size = length(chars)

# Create bidirectional mappings
char_to_id = chars |> Enum.with_index() |> Map.new()
id_to_char = chars |> Enum.with_index() |> Map.new(fn {ch, i} -> {i, ch} end)

IO.puts("Vocabulary size: #{vocab_size} unique characters")
IO.puts("Characters: #{inspect(Enum.join(chars, ""))}")

Let’s visualize how often each character appears. This is called a frequency distribution — it shows which characters the model will see most often during training. Spaces and common letters like “e”, “t”, “a” will dominate. Rare characters are harder for the model to learn because it sees fewer examples.

# Count character frequencies
freq_data =
  corpus
  |> String.graphemes()
  |> Enum.frequencies()
  |> Enum.map(fn {char, count} ->
    label = case char do
      "\n" -> "\\n"
      " " -> "SPC"
      ch -> ch
    end
    %{"char" => label, "count" => count}
  end)
  |> Enum.sort_by(& &1["count"], :desc)

Vl.new(width: 700, height: 300, title: "Character Frequency Distribution")
|> Vl.data_from_values(freq_data)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "char", type: :nominal, sort: "-y", title: "Character")
|> Vl.encode_field(:y, "count", type: :quantitative, title: "Count")
|> Vl.encode_field(:color, "count", type: :quantitative, scale: %{scheme: "blues"}, legend: nil)

Data Preparation

Now we need to turn our raw text into training examples the model can learn from. The idea is simple: show the model a sequence of characters, and ask it to predict what comes next.

Imagine sliding a window across the text. If our window is 32 characters wide and we’re looking at "the stir of society. A perfect mi", the model’s job is to predict the next character: "s" (from “misanthropist”). We slide the window forward by one character and repeat — thousands of times across the whole text. Each slide creates one training example.
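Before the vectorized tensor version, here is the same windowing idea as a plain-list sketch (toy string and window size, chosen for illustration):

```elixir
text = "the stir of society"
win = 8

# Each overlapping chunk of win + 1 characters yields one training pair:
# the first win characters are the input, the last one is the target.
pairs =
  text
  |> String.graphemes()
  |> Enum.chunk_every(win + 1, 1, :discard)
  |> Enum.map(fn chunk ->
    {input, [target]} = Enum.split(chunk, win)
    {Enum.join(input), target}
  end)

Enum.take(pairs, 3)
# [{"the stir", " "}, {"he stir ", "o"}, {"e stir o", "f"}]
```

The real code later in this section builds the same pairs, but as tensors and all at once.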

One-hot encoding is how we represent each character as numbers the model can process. You might wonder: we already gave each character a numeric ID (like a=0, b=1, c=2) — why not just feed those numbers to the model?

The problem is that plain numbers imply a relationship that doesn’t exist. If a=0, b=1, and z=25, the model would think “a” and “b” are very similar (only 1 apart) while “a” and “z” are very different (25 apart). But in language, “a” and “z” aren’t any more or less related than “a” and “b” — they’re just different characters.

One-hot encoding fixes this. Instead of a single number, each character becomes a vector (a list of numbers) that’s all zeros except for a single 1 at the position corresponding to that character’s ID. Here’s what it looks like with a tiny 5-character vocabulary:

Vocabulary: [a, b, c, d, e]

"a" → [1, 0, 0, 0, 0]    ← the 1 is at position 0
"b" → [0, 1, 0, 0, 0]    ← the 1 is at position 1
"c" → [0, 0, 1, 0, 0]    ← the 1 is at position 2
"d" → [0, 0, 0, 1, 0]    ← the 1 is at position 3
"e" → [0, 0, 0, 0, 1]    ← the 1 is at position 4

Now every character is exactly the same “distance” from every other character — they’re all equally different. The model starts with no assumptions about which characters are related and learns those relationships during training (for example, it might learn that vowels behave similarly to each other).

The name “one-hot” comes from hardware terminology — exactly one element is “hot” (set to 1), everything else is “cold” (set to 0).

In our case, the vocabulary has ~50 characters, so each character becomes a vector of ~50 numbers. Yes, this is much larger than a single ID number, but it gives the model a much better starting point for learning.
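In plain Elixir the encoding is a few lines per character. A sketch using the toy 5-character vocabulary from above (separate `toy_*` names so it doesn't clobber the notebook's real `char_to_id`):

```elixir
toy_vocab = ["a", "b", "c", "d", "e"]
toy_ids = toy_vocab |> Enum.with_index() |> Map.new()

# A vector of zeros with a single 1 at the character's ID.
one_hot = fn char ->
  id = Map.fetch!(toy_ids, char)
  Enum.map(0..(length(toy_vocab) - 1), fn i -> if i == id, do: 1, else: 0 end)
end

one_hot.("c")
# [0, 0, 1, 0, 0]
```

The Nx code below does the same thing, but for every character of every window at once via a broadcasted `Nx.equal` against `Nx.iota`.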

The final input shape is {batch, seq_len, vocab_size} — that’s a 3D block of numbers where batch is how many examples we process at once, seq_len is the window width, and vocab_size is the one-hot vector length.

seq_len = 32

# Convert entire corpus to integer IDs
corpus_ids =
  corpus
  |> String.graphemes()
  |> Enum.map(&Map.fetch!(char_to_id, &1))

corpus_len = length(corpus_ids)
n_windows = corpus_len - seq_len

IO.puts("Sequence length: #{seq_len}")
IO.puts("Total windows: #{n_windows}")

# Build input/target pairs using vectorized indexing (no per-window loops)
IO.puts("Building sliding windows...")
corpus_tensor = Nx.tensor(corpus_ids, type: :s32)

# Create all window indices at once: each row is [i, i+1, ..., i+seq_len-1]
# This is a single Nx operation instead of ~2900 individual Nx.slice calls
window_offsets = Nx.iota({1, seq_len})
window_starts = Nx.iota({n_windows, 1})
all_indices = Nx.add(window_starts, window_offsets)

x_ids =
  corpus_tensor
  |> Nx.take(Nx.reshape(all_indices, {:auto}))
  |> Nx.reshape({n_windows, seq_len})

# Targets: the character right after each window
y_ids = Nx.slice(corpus_tensor, [seq_len], [n_windows])

IO.puts("x_ids shape: #{inspect(Nx.shape(x_ids))}  (windows x seq_len)")
IO.puts("y_ids shape: #{inspect(Nx.shape(y_ids))}  (windows,)")

# One-hot encode inputs: {n_windows, seq_len} -> {n_windows, seq_len, vocab_size}
IO.puts("\nOne-hot encoding inputs (vocab_size=#{vocab_size})...")

x_onehot =
  x_ids
  |> Nx.reshape({n_windows * seq_len, 1})
  |> Nx.equal(Nx.iota({1, vocab_size}))
  |> Nx.as_type(:f32)
  |> Nx.reshape({n_windows, seq_len, vocab_size})

IO.puts("x_onehot shape: #{inspect(Nx.shape(x_onehot))}  (batch, seq_len, vocab_size)")
IO.puts("y_ids shape: #{inspect(Nx.shape(y_ids))}  (batch,) — integer class labels")

Now we split the data into training and test sets. This is a fundamental practice in machine learning: we train the model on 90% of the data, then check how well it does on the remaining 10% it has never seen. If the model performs well on the test set, it has learned general patterns rather than just memorizing the training data.

We also group examples into batches of 64. Instead of showing the model one example at a time (slow), we show it 64 examples simultaneously (fast). Modern hardware like GPUs is designed to process many examples in parallel.

# 90/10 train/test split
n_train = round(n_windows * 0.9)
batch_size = 64

# One-hot encode targets for cross-entropy loss: {n_windows} -> {n_windows, vocab_size}
y_onehot =
  y_ids
  |> Nx.reshape({n_windows, 1})
  |> Nx.equal(Nx.iota({1, vocab_size}))
  |> Nx.as_type(:f32)

train_x = x_onehot[0..(n_train - 1)]
train_y = y_onehot[0..(n_train - 1)]
test_x = x_onehot[n_train..-1//1]
test_y = y_onehot[n_train..-1//1]
test_y_ids = y_ids[n_train..-1//1]

# Batch training data
train_data =
  Enum.zip(
    Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
    Nx.to_batched(train_y, batch_size) |> Enum.to_list()
  )

n_test = n_windows - n_train

IO.puts("Train: #{n_train} windows, #{length(train_data)} batches")
IO.puts("Test:  #{n_test} windows")
IO.puts("Batch size: #{batch_size}")

Build the Transformer

Time to build the actual neural network. We use a decoder-only transformer, which is the architecture behind GPT, LLaMA, and most modern language models. Let’s break down what that means:

A transformer is a type of neural network specifically designed to process sequences (like text). Its key innovation is attention — the ability to look at all positions in the input simultaneously and decide which parts are most relevant to the current prediction. Unlike older approaches that read text strictly left-to-right, attention lets the model connect distant parts of the input (e.g., linking a pronoun back to the noun it refers to).

“Decoder-only” means the model is designed for generation: given a sequence, predict what comes next. (The alternative, “encoder-decoder”, is used for tasks like translation where you have a full input and need to produce a full output.)
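Attention itself is a short computation: score the current position against every other position, turn the scores into weights with a softmax, and take the weighted average of the value vectors. A deliberately tiny sketch with plain lists (toy 2-dimensional vectors, one head, no learned Q/K/V projections and no 1/√d scaling — real attention adds all of those):

```elixir
defmodule ToyAttention do
  # Dot product of two vectors represented as lists.
  def dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()

  def softmax(xs) do
    exps = Enum.map(xs, &:math.exp/1)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # One query attends over a list of keys and their values.
  def attend(query, keys, values) do
    weights = keys |> Enum.map(&dot(query, &1)) |> softmax()

    # Weighted sum of the value vectors, component by component.
    values
    |> Enum.zip(weights)
    |> Enum.map(fn {v, w} -> Enum.map(v, &(&1 * w)) end)
    |> Enum.zip_with(fn comps -> Enum.sum(comps) end)
  end
end

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

attended = ToyAttention.attend(query, keys, values)
```

The query matches the first key best, so the output is dominated by the first value vector. In a real transformer the query, keys, and values are all learned projections of the input characters.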

Inside our transformer, several modern techniques work together:

  • Grouped Query Attention (GQA) — A more memory-efficient version of the standard attention mechanism. Instead of giving every “reader” (attention head) its own copy of the input summary, groups of readers share summaries.
  • Rotary Position Embeddings (RoPE) — Tells the model where each character is in the sequence. Without this, the model couldn’t tell “abc” from “cab”.
  • SwiGLU — The type of “thinking layer” between attention steps. It’s the part where the model actually transforms what it learned from attention into useful representations. (SwiGLU is a specific activation function that works better than simpler alternatives like ReLU.)
  • RMSNorm — Keeps the numbers flowing through the network in a reasonable range, preventing the model from becoming numerically unstable during training.

Our model is intentionally tiny so it trains in minutes, not hours:

  • hidden_size: 64 — The model thinks in 64-dimensional space (GPT-3 uses 12,288)
  • num_layers: 2 — Two transformer blocks stacked (GPT-3 has 96)
  • num_heads: 4 with num_kv_heads: 2 — 4 attention readers sharing 2 key/value pairs
  • embed_dim: vocab_size — The one-hot character vectors are projected down to hidden_size internally

transformer_model =
  Edifice.build(:decoder_only,
    embed_dim: vocab_size,
    hidden_size: 64,
    num_heads: 4,
    num_kv_heads: 2,
    num_layers: 2,
    window_size: seq_len,
    dropout: 0.05
  )
  |> Axon.dense(vocab_size, name: "lm_head")

IO.puts("Model built: DecoderOnly -> dense(#{vocab_size})")
IO.puts("Input:  {batch, #{seq_len}, #{vocab_size}}")
IO.puts("Output: {batch, #{vocab_size}} (logits over vocabulary)")

Training

Training is how the model learns. Here’s the cycle, repeated thousands of times:

  1. Forward pass: Feed a batch of character sequences into the model. It outputs a probability distribution over the vocabulary — essentially saying “I think there’s a 30% chance the next character is ‘e’, 15% chance it’s ‘t’, 2% chance it’s ‘z’…” for each example in the batch.

  2. Compute loss: Compare the model’s predictions to the actual answers using cross-entropy loss. This is a way of measuring “how surprised was the model by the correct answer?” If the model gave 90% probability to the right character, the loss is low. If it only gave 1%, the loss is high. Lower loss = better predictions.

  3. Backward pass: Calculate how each weight (number) in the model contributed to the error. This uses calculus (specifically, the chain rule) and is called backpropagation.

  4. Update weights: Nudge each weight slightly in the direction that would reduce the loss. The Adam optimizer is the algorithm that decides exactly how much to nudge — it’s smarter than a simple “move a fixed amount” because it adapts the step size based on each weight’s history.

One full pass through all the training data is called an epoch. We train for 10 epochs, meaning the model sees every example 10 times. With each epoch, it should get a little better at predicting the next character.
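Cross-entropy for a single prediction boils down to the negative log of the probability the model assigned to the correct answer. A worked example with made-up probabilities:

```elixir
# The model's predicted distribution over a toy 4-character vocabulary;
# suppose the correct next character is the one at index 2.
probs = [0.1, 0.2, 0.6, 0.1]
correct = 2

loss = -:math.log(Enum.at(probs, correct))
# -log(0.6) ≈ 0.51 (moderately confident, moderately low loss)

-:math.log(0.9)   # ≈ 0.11 (very confident and right: tiny loss)
-:math.log(0.01)  # ≈ 4.61 (nearly ruled out the right answer: big loss)
```

Averaging this quantity over a whole batch gives exactly the `categorical_cross_entropy` loss used in the trainer below.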

defmodule LMTrainer do
  @moduledoc "Helper for training and evaluating language models."

  def train(model, train_data, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 20)
    lr = Keyword.get(opts, :lr, 3.0e-4)

    loss_fn = &Axon.Losses.categorical_cross_entropy(&1, &2, from_logits: true, reduction: :mean)

    model
    |> Axon.Loop.trainer(
      loss_fn,
      Polaris.Optimizers.adam(learning_rate: lr),
      log: 1
    )
    |> Axon.Loop.metric(:accuracy)
    |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: epochs)
  end

  def evaluate(model, state, test_x, test_y, vocab_size) do
    {_init_fn, predict_fn} = Axon.build(model)
    logits = predict_fn.(state, test_x)

    # Accuracy
    preds = Nx.argmax(logits, axis: -1)
    correct = Nx.equal(preds, test_y) |> Nx.mean() |> Nx.to_number()

    # Cross-entropy loss
    targets_onehot =
      test_y
      |> Nx.reshape({:auto, 1})
      |> Nx.equal(Nx.iota({1, vocab_size}))
      |> Nx.as_type(:f32)

    loss =
      logits
      |> Axon.Activations.softmax()
      |> Nx.max(1.0e-7)
      |> Nx.log()
      |> Nx.multiply(targets_onehot)
      |> Nx.sum(axes: [-1])
      |> Nx.negate()
      |> Nx.mean()
      |> Nx.to_number()

    {loss, correct}
  end
end

IO.puts("Training transformer (10 epochs)...")
IO.puts("  ~2-3 min/epoch on CPU, ~20s/epoch on GPU. Grab a coffee if on CPU.\n")

transformer_state = LMTrainer.train(transformer_model, train_data, epochs: 10, lr: 3.0e-4)

{transformer_loss, transformer_acc} =
  LMTrainer.evaluate(transformer_model, transformer_state, test_x, test_y_ids, vocab_size)

IO.puts("\n--- Transformer Results ---")
IO.puts("Test loss:     #{Float.round(transformer_loss, 4)}")
IO.puts("Test accuracy: #{Float.round(transformer_acc * 100, 1)}%")

Text Generation

Now the fun part — making the model write! The process is called autoregressive generation, and it works like this:

  1. Start with a seed — a short piece of text like "the morning ".
  2. Feed the seed into the model. It outputs probabilities for every character in the vocabulary (e.g., “60% chance the next character is ‘p’, 10% ‘l’…”).
  3. Pick a character from those probabilities (more on how in a moment).
  4. Append that character to the text.
  5. Slide the window forward by one and repeat from step 2.

Each new character is influenced by all the characters before it (up to the window size). The model is literally writing one letter at a time, just like autocomplete on your phone but at the character level.

Temperature is the knob that controls how the model picks from its probability distribution. Think of it like a confidence dial:

  • Temperature = 0 (greedy): Always pick the single most likely character. This is the most predictable — the model always makes its “safest” guess. It often leads to repetitive loops because the model keeps making the same safe choices.
  • Temperature = 0.5 (conservative): Mostly pick likely characters, but occasionally take a small risk. More varied than greedy, still fairly coherent.
  • Temperature = 0.8 (balanced): A good middle ground. The model is willing to explore less obvious choices, leading to more interesting (but sometimes messier) output.
  • Temperature = 1.2 (creative): The model spreads its bets more evenly across characters. You get surprising combinations, invented word-like structures, and more chaos. Fun to read, but less “correct.”

Mathematically, temperature divides the model’s raw output scores (called logits) before converting them to probabilities. Dividing by a small number makes the highest-scoring option even more dominant; dividing by a large number flattens the differences out.
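We can see the divide-by-temperature effect directly with made-up logits and a plain-Elixir softmax:

```elixir
softmax = fn xs ->
  exps = Enum.map(xs, &:math.exp/1)
  total = Enum.sum(exps)
  Enum.map(exps, &Float.round(&1 / total, 3))
end

toy_logits = [2.0, 1.0, 0.0]

results =
  for temp <- [0.5, 1.0, 2.0] do
    {temp, softmax.(Enum.map(toy_logits, &(&1 / temp)))}
  end

IO.inspect(results)
# [
#   {0.5, [0.867, 0.117, 0.016]},  <- low T: top choice dominates
#   {1.0, [0.665, 0.245, 0.09]},   <- raw distribution
#   {2.0, [0.506, 0.307, 0.186]}   <- high T: probabilities flatten out
# ]
```

Same logits, very different sampling behavior: at T=0.5 the model almost always picks its favorite, at T=2.0 the runners-up get real chances.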

defmodule TextGenerator do
  @moduledoc "Autoregressive character-level text generation."

  @doc "Generate `n_chars` of text given a seed string."
  def generate(model, state, seed, n_chars, opts \\ []) do
    temp = Keyword.get(opts, :temperature, 1.0)
    char_to_id = Keyword.fetch!(opts, :char_to_id)
    id_to_char = Keyword.fetch!(opts, :id_to_char)
    vocab_size = Keyword.fetch!(opts, :vocab_size)
    seq_len = Keyword.fetch!(opts, :seq_len)
    key = Keyword.get(opts, :key, Nx.Random.key(42))

    {_init_fn, predict_fn} = Axon.build(model)

    # Convert seed to character IDs
    seed_ids =
      seed
      |> String.graphemes()
      |> Enum.map(&amp;Map.fetch!(char_to_id, &amp;1))

    # Pad or truncate seed to seq_len
    seed_ids =
      if length(seed_ids) >= seq_len do
        Enum.take(seed_ids, -seq_len)
      else
        # Pad with space character on the left
        space_id = Map.fetch!(char_to_id, " ")
        List.duplicate(space_id, seq_len - length(seed_ids)) ++ seed_ids
      end

    # Generate characters one at a time
    {generated_ids, _key} =
      Enum.reduce(1..n_chars, {seed_ids, key}, fn _i, {current_ids, rng} ->
        # Take last seq_len characters
        window = Enum.take(current_ids, -seq_len)

        # One-hot encode: {1, seq_len, vocab_size}
        input =
          window
          |> Nx.tensor(type: :s32)
          |> Nx.reshape({seq_len, 1})
          |> Nx.equal(Nx.iota({1, vocab_size}))
          |> Nx.as_type(:f32)
          |> Nx.reshape({1, seq_len, vocab_size})

        # Get logits and apply temperature
        logits = predict_fn.(state, input) |> Nx.reshape({vocab_size})

        {next_id, rng} =
          if temp <= 0.01 do
            # Greedy (argmax)
            {Nx.argmax(logits) |> Nx.to_number(), rng}
          else
            # Temperature sampling
            scaled = Nx.divide(logits, temp)
            probs = Axon.Activations.softmax(scaled)
            {sample, rng} = Nx.Random.choice(rng, Nx.iota({vocab_size}), probs, samples: 1)
            {Nx.to_number(sample[0]), rng}
          end

        {current_ids ++ [next_id], rng}
      end)

    # Convert back to text (skip the seed portion)
    generated_ids
    |> Enum.drop(length(seed_ids))
    |> Enum.map(&Map.get(id_to_char, &1, "?"))
    |> Enum.join()
  end
end

Let’s generate text at different temperatures and see the difference in practice. Each character requires a full forward pass through the entire model, so generating 100 characters at 4 different temperatures takes a minute or two on CPU. Watch how the output changes from repetitive (low temperature) to chaotic (high temperature).

seed = "the morning "
gen_opts = [
  char_to_id: char_to_id,
  id_to_char: id_to_char,
  vocab_size: vocab_size,
  seq_len: seq_len
]

IO.puts("Seed: \"#{seed}\"\n")
IO.puts(String.duplicate("=", 60))

for {temp, label} <- [{0.0, "Greedy (T=0)"}, {0.5, "Conservative (T=0.5)"}, {0.8, "Balanced (T=0.8)"}, {1.2, "Creative (T=1.2)"}] do
  text = TextGenerator.generate(
    transformer_model, transformer_state, seed, 100,
    [{:temperature, temp}, {:key, Nx.Random.key(round(temp * 100))} | gen_opts]
  )

  IO.puts("\n#{label}:")
  IO.puts("  #{seed}#{text}")
  IO.puts(String.duplicate("-", 60))
end

:ok

Architecture Swap: Mamba

So far we’ve used a transformer — but it’s not the only game in town. Mamba is a completely different type of neural network called a state-space model (SSM). While transformers use attention (looking at all positions at once), Mamba processes the sequence more like reading a book page by page — maintaining a compressed “memory” (called state) that it updates as it reads each new character.

The tradeoff: transformers are better at connecting distant parts of a sequence (because attention can look everywhere at once), but Mamba is more efficient on very long sequences (because its memory is a fixed size no matter how long the input gets).
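The fixed-size-memory idea can be sketched as a simple recurrence: at each step, the new state is a blend of the old state and the new input. This is only the conceptual skeleton (a scalar linear recurrence with made-up coefficients), not Mamba's actual selective state-space update:

```elixir
# h_t = a * h_{t-1} + b * x_t : the "memory" h stays a single number
# no matter how long the input sequence grows.
a = 0.9   # how much old memory to keep (made-up value)
b = 0.1   # how much of the new input to mix in (made-up value)

inputs = [1.0, 0.0, 0.0, 1.0, 1.0]

# The running state after each input step.
states = Enum.scan(inputs, 0.0, fn x, h -> a * h + b * x end)
```

In the real Mamba layer the coefficients are themselves functions of the input (that is the "selective" part) and the state is a vector per channel rather than a scalar, but the core shape of the computation is this recurrence.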

One of Edifice’s strengths is making architecture swaps trivial. We change exactly one thing — :decoder_only becomes :mamba — and keep the data, training loop, and evaluation identical. (Below we give Mamba fewer epochs only because its sequential scan runs slower; keep that in mind when comparing the final numbers.)

mamba_model =
  Edifice.build(:mamba,
    embed_dim: vocab_size,
    hidden_size: 64,
    state_size: 16,
    num_layers: 2,
    seq_len: seq_len,
    window_size: seq_len,
    dropout: 0.05
  )
  |> Axon.dense(vocab_size, name: "lm_head_mamba")

IO.puts("Mamba model built — same input/output shapes as transformer")
IO.puts("Training Mamba (5 epochs)...")
IO.puts("  Mamba's sequential scan is slower than attention on small GPUs.\n")

mamba_state = LMTrainer.train(mamba_model, train_data, epochs: 5, lr: 3.0e-4)

{mamba_loss, mamba_acc} =
  LMTrainer.evaluate(mamba_model, mamba_state, test_x, test_y_ids, vocab_size)

IO.puts("\n--- Mamba Results ---")
IO.puts("Test loss:     #{Float.round(mamba_loss, 4)}")
IO.puts("Test accuracy: #{Float.round(mamba_acc * 100, 1)}%")

# Compare the two architectures
IO.puts(String.duplicate("=", 50))
IO.puts("  Architecture    Test Loss   Accuracy")
IO.puts("  " <> String.duplicate("-", 40))
IO.puts("  Transformer     #{Float.round(transformer_loss, 4) |> to_string() |> String.pad_trailing(11)} #{Float.round(transformer_acc * 100, 1)}%")
IO.puts("  Mamba           #{Float.round(mamba_loss, 4) |> to_string() |> String.pad_trailing(11)} #{Float.round(mamba_acc * 100, 1)}%")
IO.puts(String.duplicate("=", 50))

Let’s see what Mamba generates and compare it side-by-side with the transformer. Do you notice any differences in style or quality? (Another minute or two on CPU.)

IO.puts("Seed: \"#{seed}\"\n")
IO.puts(String.duplicate("=", 60))

for {temp, label} <- [{0.0, "Greedy (T=0)"}, {0.8, "Balanced (T=0.8)"}] do
  IO.puts("\n#{label}:")

  transformer_text = TextGenerator.generate(
    transformer_model, transformer_state, seed, 100,
    [{:temperature, temp}, {:key, Nx.Random.key(7)} | gen_opts]
  )
  IO.puts("  Transformer: #{seed}#{transformer_text}")

  mamba_text = TextGenerator.generate(
    mamba_model, mamba_state, seed, 100,
    [{:temperature, temp}, {:key, Nx.Random.key(7)} | gen_opts]
  )
  IO.puts("  Mamba:       #{seed}#{mamba_text}")

  IO.puts(String.duplicate("-", 60))
end

:ok

What’s Next?

Why does the generated text look like gibberish?

Don’t be disappointed if the output looks like nonsense! This is completely expected, and understanding why is just as educational as the model itself.

Our model is learning character-level patterns — things like “after ‘t’, ‘h’ is common” and “spaces usually come after periods.” These are real statistical patterns in English, and if you look carefully at the output, you’ll likely see fragments that look almost like real words. That’s the model learning!

But it can’t learn grammar, meaning, or even complete words reliably. Here’s why, ranked by importance:

  1. Data size (biggest factor): Our corpus is ~4KB — roughly one page of text. That’s like trying to learn English from a single paragraph. The model can only learn patterns it sees enough examples of. Real language models train on gigabytes or terabytes of text. Even 1-2MB of text (a single novel) would produce noticeably more English-like output with this same architecture.

  2. Tokenization granularity: We’re predicting one character at a time. The word “the” requires three correct predictions in a row. Real models use subword tokenizers (like BPE or SentencePiece) that compress common words into single tokens — so “the” is one prediction, not three. This gives the model much more effective context per position.

  3. Context window: Our window is 32 characters — about 5-6 words. The model literally cannot “see” anything further back than that. It’s like trying to write a sentence when you can only remember the last five words. Production models use context windows of 512 to 128,000+ tokens.

  4. Model size: Our model has roughly 50,000 learnable parameters (the numbers it adjusts during training). GPT-2 has 117 million. GPT-4 is rumored to have over a trillion. But here’s an important insight: a bigger model without more data just memorizes faster — it doesn’t truly learn better. Data and model size need to grow together.

The point of this notebook

This notebook teaches the mechanics — how text becomes numbers, how a transformer processes sequences, how training works, and how generation happens. These are the exact same steps used in production AI systems, just at a much smaller scale. Everything here scales up: more data, bigger model, better tokenizer, longer context, and you get increasingly coherent text all the way up to ChatGPT-level fluency.

Ideas for going further

  • More data: Load a full book with File.read!/1 — even one novel (~500KB) makes a dramatic difference. Project Gutenberg has thousands of free books.
  • Better tokenization: Move from characters to subword tokens (BPE or SentencePiece). This is the single biggest quality improvement for the effort involved.
  • Longer context: Try seq_len: 128 or 256 to capture sentence-level structure and longer-range dependencies.
  • Bigger model: hidden_size: 128, num_layers: 4 — but only with more data to match. More parameters + same tiny corpus = memorization.
  • Try more architectures: Replace :mamba with :retnet, :rwkv, :hyena, :xlstm, or any of Edifice’s 30+ sequence models — the training loop stays identical.
  • LoRA fine-tuning: See the planned LoRA notebook for how to adapt a pretrained model to a new task by only training a tiny fraction of its weights — a technique that makes modern AI practical on consumer hardware.