
Whisper ASR Demo

notebooks/whisper_asr_demo.livemd


Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server).

edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino, "~> 0.14"}
])

# Nx.global_default_backend(EXLA.Backend)

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh.

Nx.global_default_backend(EXLA.Backend)
IO.puts("Attached mode — using EXLA backend from project node")

Introduction

Whisper (Radford et al., 2022) is OpenAI’s speech recognition model. It’s an encoder-decoder transformer trained on 680k hours of multilingual audio. The encoder processes a mel spectrogram of the audio; the decoder auto-regressively generates text tokens.

This notebook walks through the full pipeline using Edifice’s Whisper implementation and pretrained weight loading from HuggingFace.

What you’ll learn

  • How encoder-decoder architectures work (encoder = understand audio, decoder = generate text)
  • How to load pretrained weights from HuggingFace in one line
  • How greedy auto-regressive decoding works (generate one token at a time)
  • The shapes at every stage of the pipeline

T400 note

We use whisper-tiny (~39M params, ~150MB in f32) which fits comfortably in 2GB VRAM. EXLA is strongly recommended — the decoder’s output projection to 51,865 vocab tokens is a large matmul that’s very slow on BinaryBackend but instant on GPU. Use the “Attached to project” setup cell.

# Whisper-tiny dimensions
config = %{
  name: "whisper-tiny",
  hidden_dim: 384,
  encoder_layers: 4,
  decoder_layers: 4,
  num_heads: 6,
  ffn_dim: 1536,
  n_mels: 80,
  vocab_size: 51_865,
  max_audio_len: 1500,   # 30s of 16kHz audio: 3000 mel frames → 1500 after the conv stem
  max_dec_len: 448
}

IO.puts("Model: #{config.name}")
IO.puts("  Encoder: #{config.encoder_layers} layers, #{config.hidden_dim} hidden, #{config.num_heads} heads")
IO.puts("  Decoder: #{config.decoder_layers} layers, #{config.hidden_dim} hidden, #{config.num_heads} heads")
IO.puts("  Vocab:   #{config.vocab_size} tokens (BPE + special tokens)")
IO.puts("  Input:   mel spectrogram [batch, #{config.n_mels}, time]")
IO.puts("  Max audio: #{config.max_audio_len} encoder positions = 30 seconds")
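The 30-second figure can be sanity-checked from Whisper's standard front-end parameters (16 kHz sampling, hop length 160, 2x downsampling in the conv stem):

```elixir
# Whisper's standard audio front-end (from the original paper/repo)
sample_rate = 16_000  # Hz
hop_length = 160      # samples between successive mel frames (10 ms)
chunk_seconds = 30

samples = sample_rate * chunk_seconds   # 480_000 samples per chunk
mel_frames = div(samples, hop_length)   # 3_000 mel frames
encoder_positions = div(mel_frames, 2)  # conv stem halves the time axis → 1_500

IO.puts("#{chunk_seconds}s → #{samples} samples → #{mel_frames} mel frames → #{encoder_positions} encoder positions")
```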

1. Architecture Tour

Whisper has two parts: an encoder that reads audio features, and a decoder that generates text tokens. Let’s build both and inspect their shapes.

What to look for

  • The encoder’s conv stem halves the time dimension (the second conv has stride 2): 200 frames in → 100 frames out
  • The decoder takes integer token IDs (not floats!) and the encoder’s output
  • The decoder outputs logits over the full vocabulary (51,865 tokens)
alias Edifice.Audio.Whisper

# Build encoder and decoder separately
{encoder, decoder} = Whisper.build(
  hidden_dim: config.hidden_dim,
  encoder_layers: config.encoder_layers,
  decoder_layers: config.decoder_layers,
  num_heads: config.num_heads,
  ffn_dim: config.ffn_dim,
  n_mels: config.n_mels,
  vocab_size: config.vocab_size,
  max_audio_len: config.max_audio_len,
  max_dec_len: config.max_dec_len,
  dropout: 0.0
)

IO.puts("Encoder and decoder built successfully.")
IO.puts("(These are Axon computation graphs — no weights allocated yet.)")
# === Encoder: mel spectrogram → audio features ===
#
# The encoder processes a mel spectrogram — a 2D representation of audio
# where one axis is frequency (80 mel bins) and the other is time.
# Two 1D convolutions downsample the time axis by 2x (only the second
# conv has stride 2), then transformer blocks process the sequence.

batch = 1
audio_frames = 200  # 2 seconds of audio at a 10ms hop (shorter for a fast demo)

{enc_init, enc_predict} = Axon.build(encoder, mode: :inference)
enc_params = enc_init.(
  %{"mel_spectrogram" => Nx.template({batch, config.n_mels, audio_frames}, :f32)},
  Axon.ModelState.empty()
)

# Count encoder parameters
enc_param_count =
  enc_params
  |> Axon.ModelState.trainable_parameters()
  |> then(fn params ->
    params
    |> Map.values()
    |> Enum.flat_map(fn
      %Nx.Tensor{} = t -> [Nx.size(t)]
      map when is_map(map) -> Enum.map(Map.values(map), &Nx.size/1)
    end)
    |> Enum.sum()
  end)

# Run encoder on synthetic mel spectrogram
mel = Nx.broadcast(0.0, {batch, config.n_mels, audio_frames})
encoder_output = enc_predict.(enc_params, %{"mel_spectrogram" => mel})

IO.puts("=== Encoder ===")
IO.puts("  Input:  mel_spectrogram #{inspect({batch, config.n_mels, audio_frames})}")
IO.puts("  Output: encoder_output  #{inspect(Nx.shape(encoder_output))}")
IO.puts("  Params: ~#{div(enc_param_count, 1000)}K")
IO.puts("")
IO.puts("  The time dimension halved: #{audio_frames} → #{elem(Nx.shape(encoder_output), 1)}")
IO.puts("  (Conv stem: the second conv has stride 2 → 2x downsampling)")
# === Decoder: encoder output + token IDs → next-token logits ===
#
# The decoder is a causal transformer that generates text one token at a time.
# At each step it takes:
#   - Previously generated token IDs (integer tensor)
#   - Encoder output (from the mel spectrogram)
# And produces logits over the vocabulary for the next token.

dec_len = 4  # short sequence for demo
enc_len = elem(Nx.shape(encoder_output), 1)

{dec_init, dec_predict} = Axon.build(decoder, mode: :inference)
dec_params = dec_init.(
  %{
    "token_ids" => Nx.template({batch, dec_len}, :s64),
    "encoder_output" => Nx.template({batch, enc_len, config.hidden_dim}, :f32)
  },
  Axon.ModelState.empty()
)

# Whisper special tokens:
#   50258 = <|startoftranscript|>
#   50259 = <|en|> (English)
#   50360 = <|transcribe|>
#   50364 = <|notimestamps|>
#   50257 = <|endoftext|>
prompt_tokens = Nx.tensor([[50258, 50259, 50360, 50364]], type: :s64)

logits = dec_predict.(dec_params, %{
  "token_ids" => prompt_tokens,
  "encoder_output" => encoder_output
})

IO.puts("=== Decoder ===")
IO.puts("  Inputs:")
IO.puts("    token_ids:      #{inspect(Nx.shape(prompt_tokens))} (s64 integers)")
IO.puts("    encoder_output: #{inspect(Nx.shape(encoder_output))}")
IO.puts("  Output:")
IO.puts("    logits:         #{inspect(Nx.shape(logits))} (one score per vocab token)")
IO.puts("")
IO.puts("  The last position's logits predict the NEXT token.")
IO.puts("  argmax of logits[:, -1, :] gives the most likely next token ID.")

# What token would the model predict next? (with random weights, this is meaningless)
next_token =
  logits
  |> Nx.slice_along_axis(dec_len - 1, 1, axis: 1)
  |> Nx.squeeze(axes: [1])
  |> Nx.argmax(axis: 1)
  |> Nx.squeeze()
  |> Nx.to_number()

IO.puts("  Predicted next token: #{next_token} (random weights — will be meaningful after loading pretrained)")

2. Loading Pretrained Weights

Now let’s load real weights from HuggingFace. Pretrained.from_hub/1 handles everything: downloads the checkpoint (cached locally), reads the config, builds the model, and maps weights into Axon parameters.

What to look for

  • First run downloads ~150MB (whisper-tiny). Subsequent runs use the cache.
  • The loaded model state replaces random weights with trained ones.
  • After loading, the encoder’s output on silent audio (all zeros) should be nearly constant — the model learned that silence = no speech.
# This downloads whisper-tiny from HuggingFace (~150MB, cached after first run).
# If you don't have internet access, skip this cell — the architecture tour
# above works fine with random weights.

IO.puts("Downloading whisper-tiny from HuggingFace (cached after first run)...")

{encoder, decoder, model_state} =
  Edifice.Pretrained.from_hub("openai/whisper-tiny",
    strict: false,  # skip unmapped keys (proj_out, learned encoder PE)
    dtype: :f32     # keep f32 for T400 compatibility
  )

IO.puts("Loaded! Model state has #{map_size(model_state.data)} top-level parameter groups.")

3. Encoder on Pretrained Weights

Let’s run the pretrained encoder on synthetic mel spectrograms and see how the trained model behaves differently from random weights.

What to look for

  • Silent audio (all zeros) should produce low-magnitude encoder features
  • Random noise mel should produce higher-magnitude features (the model detects “something”)
  • The encoder output is what the decoder cross-attends to — it’s the model’s understanding of the audio content
# Build encoder prediction function
{_init, enc_predict} = Axon.build(encoder, mode: :inference)

# Test 1: Silent audio (mel = 0)
silent_mel = Nx.broadcast(0.0, {1, config.n_mels, audio_frames})
silent_out = enc_predict.(model_state, %{"mel_spectrogram" => silent_mel})

# Test 2: Random noise mel
key = Nx.Random.key(42)
{noise_mel, _} = Nx.Random.normal(key, 0.0, 1.0,
  shape: {1, config.n_mels, audio_frames}, type: {:f, 32})
noise_out = enc_predict.(model_state, %{"mel_spectrogram" => noise_mel})

silent_norm = Nx.mean(Nx.abs(silent_out)) |> Nx.to_number() |> Float.round(4)
noise_norm = Nx.mean(Nx.abs(noise_out)) |> Nx.to_number() |> Float.round(4)

IO.puts("=== Pretrained Encoder Output ===")
IO.puts("  Silent mel → mean |output|: #{silent_norm}")
IO.puts("  Noise mel  → mean |output|: #{noise_norm}")
IO.puts("")
IO.puts("  The pretrained encoder produces different representations")
IO.puts("  for silence vs noise — it learned to distinguish them.")

# How different are the two outputs?
diff = Nx.subtract(silent_out, noise_out) |> Nx.abs() |> Nx.reduce_max() |> Nx.to_number() |> Float.round(4)
IO.puts("  Max difference: #{diff}")

4. Greedy Decoding Loop

Auto-regressive decoding generates text one token at a time. At each step:

  1. Feed all tokens generated so far to the decoder
  2. Take the argmax of the last position’s logits → next token
  3. Append that token and repeat
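Before wiring in the real decoder, the control flow of these three steps can be sketched in plain Elixir. The `fake_next_token` function below is a stand-in for "decoder forward pass + argmax", not Whisper itself:

```elixir
eos = 50257

# Stand-in for the decoder: returns a "next token" given the sequence so far.
# Here it emits two dummy tokens and then EOS — the real version calls
# dec_predict and takes argmax of the last position's logits.
fake_next_token = fn tokens ->
  if length(tokens) >= 6, do: eos, else: 1000 + length(tokens)
end

# Anonymous-function recursion: pass the function to itself.
decode = fn decode, tokens, max_new ->
  next = fake_next_token.(tokens)
  tokens = tokens ++ [next]

  if next == eos or max_new == 1 do
    tokens
  else
    decode.(decode, tokens, max_new - 1)
  end
end

prompt = [50258, 50259, 50360, 50364]
sequence = decode.(decode, prompt, 10)
IO.inspect(sequence)
# → [50258, 50259, 50360, 50364, 1004, 1005, 50257]
```

The real loop below has exactly this shape, with `Enum.reduce_while/3` instead of explicit recursion.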

We’ll run this with the pretrained decoder on our synthetic encoder output. Since the input audio is silence/noise (not real speech), the model will likely predict <|endoftext|> quickly or hallucinate — that’s expected.

What to look for

  • The decode loop starts with Whisper’s special prompt tokens
  • Each iteration adds one token
  • The loop stops when it hits <|endoftext|> (token 50257) or max length
  • With real audio + real mel extraction, this would produce actual transcriptions
{_init, dec_predict} = Axon.build(decoder, mode: :inference)

# Whisper's standard prompt for English transcription
initial_tokens = [50258, 50259, 50360, 50364]

# Special token names for display
token_names = %{
  50257 => "<|endoftext|>",
  50258 => "<|startoftranscript|>",
  50259 => "<|en|>",
  50360 => "<|transcribe|>",
  50364 => "<|notimestamps|>"
}

# Use the silent audio encoder output
encoder_out = enc_predict.(model_state, %{
  "mel_spectrogram" => Nx.broadcast(0.0, {1, config.n_mels, audio_frames})
})

max_new_tokens = 10
eos_token = 50257

IO.puts("=== Greedy Decoding (silent audio) ===\n")
IO.puts("Prompt: #{Enum.map_join(initial_tokens, " ", &(token_names[&1] || to_string(&1)))}\n")

# Decode loop
{final_tokens, _} =
  Enum.reduce_while(1..max_new_tokens, {initial_tokens, encoder_out}, fn step, {tokens, enc_out} ->
    # Build token tensor from all tokens so far
    token_tensor = Nx.tensor([tokens], type: :s64)

    # Forward pass through decoder
    logits = dec_predict.(model_state, %{
      "token_ids" => token_tensor,
      "encoder_output" => enc_out
    })

    # Argmax at last position → next token
    seq_len = length(tokens)
    next_id =
      logits
      |> Nx.slice_along_axis(seq_len - 1, 1, axis: 1)
      |> Nx.squeeze(axes: [1])
      |> Nx.argmax(axis: 1)
      |> Nx.squeeze()
      |> Nx.to_number()

    display = token_names[next_id] || "token_#{next_id}"
    IO.puts("  Step #{step}: #{display} (id=#{next_id})")

    new_tokens = tokens ++ [next_id]

    if next_id == eos_token do
      {:halt, {new_tokens, enc_out}}
    else
      {:cont, {new_tokens, enc_out}}
    end
  end)

IO.puts("\nGenerated #{length(final_tokens) - length(initial_tokens)} new tokens.")
IO.puts("Full sequence: #{inspect(final_tokens)}")
IO.puts("")
IO.puts("With real audio (not silence), the model would produce actual words.")
IO.puts("The token IDs would map to BPE subwords via Whisper's tokenizer.")
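Edifice doesn't bundle Whisper's tokenizer. One way to map IDs to text is the `tokenizers` Hex package — a sketch under the assumptions that you've added it to `Mix.install` and that the HuggingFace repo ships tokenizer files (the first call needs network access):

```elixir
# Assumed extra dependency: {:tokenizers, "~> 0.5"} in the Mix.install above.
# IDs below are illustrative decode-loop output; IDs >= 50257 are special
# control tokens, not text, so drop them before decoding.
generated = [50258, 50259, 50360, 50364, 1770, 13, 2264, 50257]
text_ids = Enum.reject(generated, &(&1 >= 50257))

# Hypothetical usage — downloads tokenizer files on first call (cached after).
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
{:ok, text} = Tokenizers.Tokenizer.decode(tokenizer, text_ids)
IO.puts(text)
```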

5. Understanding the Full Pipeline

IO.puts("""
=== Whisper ASR Pipeline ===

  Audio file (.wav, .mp3, etc.)
        │
        ▼
  ┌─────────────────────┐
  │ Mel Spectrogram      │  ← Not in Edifice (use Python/ffmpeg)
  │ 16kHz → STFT → mel  │     librosa.feature.melspectrogram()
  │ Output: [80, T]     │     or whisper.log_mel_spectrogram()
  └─────────────────────┘
        │
        ▼
  ┌─────────────────────┐
  │ Encoder              │  ← Edifice: Whisper.build_encoder/1
  │ 2x Conv1D (stride 2) │
  │ + sinusoidal PE      │
  │ + N transformer blocks│
  │ Output: [T/2, 384]  │
  └─────────────────────┘
        │
        ▼
  ┌─────────────────────┐
  │ Decoder              │  ← Edifice: Whisper.build_decoder/1
  │ Token embed + pos embed│
  │ + N transformer blocks│
  │   (causal self-attn  │
  │    + cross-attn)     │
  │ Output: [S, 51865]  │     logits over vocabulary
  └─────────────────────┘
        │
        ▼
  ┌─────────────────────┐
  │ Greedy / Beam Search │  ← argmax or beam search over logits
  │ Token IDs → Text     │     via Whisper BPE tokenizer
  └─────────────────────┘

To do real ASR with Edifice, you need:
  1. Audio → mel spectrogram (Python: whisper.log_mel_spectrogram)
  2. Save mel as .npy or .safetensors
  3. Load mel tensor in Elixir: Nx.load_numpy!(File.read!("mel.npy"))
  4. Run encoder + greedy decode (as shown above)
  5. Map token IDs to text (need Whisper's BPE vocab file)

The architecture and weight loading are complete — the missing piece is
audio preprocessing, which lives outside the neural network.
""")

Summary

IO.puts("""
=== What We Covered ===

1. ARCHITECTURE: Whisper is an encoder-decoder transformer
   - Encoder: Conv stem + transformer blocks (processes mel spectrograms)
   - Decoder: Token embed + cross-attention transformer (generates text)

2. PRETRAINED WEIGHTS: One-line loading from HuggingFace
   - Edifice.Pretrained.from_hub("openai/whisper-tiny")
   - Downloads, caches, maps keys, loads into Axon model state

3. ENCODER-DECODER FLOW:
   - Encoder turns audio features into a dense representation
   - Decoder cross-attends to encoder output while generating tokens

4. GREEDY DECODING: Auto-regressive token generation
   - Feed all previous tokens → get next token logits → argmax → repeat
   - Stop at <|endoftext|> token (50257)

5. T400 FRIENDLY: whisper-tiny fits in 2GB
   - 39M params, ~150MB in f32
   - 4 encoder + 4 decoder layers, 384 hidden dim

For real ASR: compute mel spectrogram externally (Python/ffmpeg),
load as Nx tensor, run the pipeline shown here.
""")