Composing Custom Architectures from Blocks
Setup
Choose one of the two cells below depending on how you started Livebook.
Standalone (default)
Use this if you started Livebook normally (livebook server).
edifice_dep =
if File.dir?(Path.expand("~/edifice")) do
{:edifice, path: Path.expand("~/edifice")}
else
{:edifice, "~> 0.2.0"}
end
Mix.install([
edifice_dep,
# {:exla, "~> 0.10"},
{:kino, "~> 0.14"}
])
# Nx.global_default_backend(EXLA.Backend)
Attached to project (recommended for Nix/CUDA)
Use this if you started Livebook via ./scripts/livebook.sh.
Nx.global_default_backend(EXLA.Backend)
IO.puts("Attached mode — using EXLA backend from project node")
Introduction
Edifice has 200+ architectures you can use out of the box with Edifice.build/2.
But the library’s real power is composition — mixing blocks from different
families to create something new. The same TransformerBlock, SDPA, RoPE,
and FFN that power the built-in architectures are available for you to combine
however you want.
What you’ll learn
- How `TransformerBlock.layer/2` separates structure (norms, residuals) from computation (attention, FFN)
- How to swap attention mechanisms and FFN variants via callbacks
- How to add position encodings (RoPE vs sinusoidal) to any attention
- How to use `ModelBuilder` to get a complete model pipeline in one call
- How to interleave different block types across layers (Hymba-style SSM+attention)
All examples use tiny dimensions so they run instantly on any GPU or CPU.
# Shared aliases — we'll use these throughout
alias Edifice.Blocks.{TransformerBlock, FFN, SDPA, CausalMask, RMSNorm, SwiGLU}
alias Edifice.Blocks.{RoPE, SinusoidalPE, ModelBuilder, PatchEmbed}
alias Edifice.Attention.MultiHead
# Tiny dimensions for fast iteration
batch = 2
seq_len = 16
hidden = 64
num_heads = 4
head_dim = div(hidden, num_heads)
IO.puts("Config: batch=#{batch} seq=#{seq_len} hidden=#{hidden} heads=#{num_heads} head_dim=#{head_dim}")
# Helper to build, init, and run a model — we'll reuse this everywhere
defmodule Compose do
def run(name, model, inputs) do
templates = Map.new(inputs, fn {k, v} -> {k, Nx.template(Nx.shape(v), Nx.type(v))} end)
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(templates, Axon.ModelState.empty())
output = predict_fn.(params, inputs)
param_count = count_params(params)
IO.puts("#{name}")
IO.puts(" Output: #{format_shape(output)}")
IO.puts(" Params: #{fmt(param_count)}")
{output, params, predict_fn}
end
def count_params(%Axon.ModelState{} = state) do
state |> Axon.ModelState.trainable_parameters() |> count_nested(0)
end
defp count_nested(%Nx.Tensor{} = t, acc), do: acc + Nx.size(t)
defp count_nested(map, acc) when is_map(map) do
Enum.reduce(map, acc, fn {_k, v}, a -> count_nested(v, a) end)
end
defp count_nested(_other, acc), do: acc
def fmt(n) when n >= 1_000_000, do: "#{Float.round(n / 1_000_000, 1)}M"
def fmt(n) when n >= 1_000, do: "#{Float.round(n / 1_000, 1)}K"
def fmt(n), do: "#{n}"
defp format_shape({a, b}), do: "{#{inspect(Nx.shape(a))}, #{inspect(Nx.shape(b))}}"
defp format_shape(%Nx.Tensor{} = t), do: inspect(Nx.shape(t))
defp format_shape(map) when is_map(map) do
Enum.map_join(map, ", ", fn {k, v} -> "#{k}: #{format_shape(v)}" end)
end
defp format_shape(other), do: inspect(other)
end
1. TransformerBlock — Structure vs Computation
The central idea: TransformerBlock.layer/2 handles the boring-but-important
stuff (normalization, residual connections, dropout) while you provide the
interesting part (what kind of attention to use) via a callback function.
What to look for
- The block always produces the same output shape as its input
- The `attention_fn` callback receives pre-normalized input plus a name string
- Different attention mechanisms slot in with zero structural changes
# The simplest possible transformer block:
# standard multi-head attention + standard FFN
input = Axon.input("x", shape: {nil, seq_len, hidden})
# attention_fn receives (normalized_input, name_prefix) and returns an Axon node.
# This is the ONLY thing you need to provide — TransformerBlock handles:
# - Pre-norm (LayerNorm by default)
# - Residual connections
# - FFN sublayer (dense → activation → dense)
# - Dropout
simple_block = TransformerBlock.layer(input,
attention_fn: fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden,
num_heads: num_heads,
dropout: 0.0,
causal: true,
name: name
)
end,
hidden_size: hidden,
name: "simple_block"
)
test_input = Nx.broadcast(0.5, {batch, seq_len, hidden})
{out, _, _} = Compose.run("Simple TransformerBlock", simple_block, %{"x" => test_input})
# Output shape matches input — blocks are stackable!
IO.puts(" Input shape: #{inspect(Nx.shape(test_input))}")
IO.puts(" Output shape: #{inspect(Nx.shape(out))}")
IO.puts(" Shapes match: #{Nx.shape(test_input) == Nx.shape(out)}")
2. Swapping Attention — Same Structure, Different Math
Now the payoff: we can plug in any attention mechanism and the surrounding structure stays identical. Let’s compare three attention variants in the same TransformerBlock skeleton.
What to look for
- All three produce the same output shape
- Parameter counts differ because the attention internals differ
- The TransformerBlock options are identical; only `attention_fn` changes
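Before wiring these into blocks, the core numeric difference between the two weighting schemes can be seen on raw scores. This is a standalone Nx sketch, not part of Edifice:

```elixir
# Softmax normalizes each row to sum to 1; sigmoid squashes each score
# independently, so rows need not sum to 1 and there is no mass to distribute.
scores = Nx.tensor([[2.0, 1.0, 0.5]])

exp = Nx.exp(scores)
softmax_weights = Nx.divide(exp, Nx.sum(exp, axes: [-1], keep_axes: true))
sigmoid_weights = Nx.sigmoid(scores)

IO.inspect(Nx.sum(softmax_weights, axes: [-1]), label: "softmax row sum")
IO.inspect(Nx.sum(sigmoid_weights, axes: [-1]), label: "sigmoid row sum")
```

The softmax row sums to 1.0 by construction; the sigmoid row sums to roughly 2.23 for these scores, which is exactly why sigmoid attention avoids the attention-sink effect.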
# Build the same block with 3 different attention mechanisms
input = Axon.input("x", shape: {nil, seq_len, hidden})
# 1. Standard softmax attention (the classic)
standard = TransformerBlock.layer(input,
attention_fn: fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads, causal: true, name: name)
end,
hidden_size: hidden, name: "standard"
)
# 2. Hand-rolled sigmoid attention — replaces softmax with sigmoid.
# No softmax normalization means attention weights don't sum to 1.
# Why? Avoids the "attention sink" problem where early tokens get
# disproportionate weight just because softmax needs to distribute mass.
# (Apple's sigmoid attention paper, ICML 2025)
sigmoid = TransformerBlock.layer(input,
attention_fn: fn x, name ->
q = Axon.dense(x, hidden, name: "#{name}_q")
k = Axon.dense(x, hidden, name: "#{name}_k")
v = Axon.dense(x, hidden, name: "#{name}_v")
# Sigmoid instead of softmax — that's the only change!
Axon.layer(
fn q_val, k_val, v_val, _opts ->
{b, s, _} = Nx.shape(q_val)
scale = :math.sqrt(head_dim)
q_h = Nx.reshape(q_val, {b, s, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3])
k_h = Nx.reshape(k_val, {b, s, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3])
v_h = Nx.reshape(v_val, {b, s, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3])
scores = Nx.dot(q_h, [3], [0, 1], k_h, [3], [0, 1]) |> Nx.divide(scale)
weights = Nx.sigmoid(scores)
out = Nx.dot(weights, [3], [0, 1], v_h, [2], [0, 1])
Nx.transpose(out, axes: [0, 2, 1, 3]) |> Nx.reshape({b, s, hidden})
end,
[q, k, v],
name: "#{name}_sigmoid_sdpa"
)
end,
hidden_size: hidden, name: "sigmoid"
)
# 3. Roll your own — a minimal attention using SDPA directly.
# This shows what's happening under the hood: Q/K/V projections
# followed by scaled dot-product attention.
custom = TransformerBlock.layer(input,
attention_fn: fn x, name ->
q = Axon.dense(x, hidden, name: "#{name}_q")
k = Axon.dense(x, hidden, name: "#{name}_k")
v = Axon.dense(x, hidden, name: "#{name}_v")
Axon.layer(
fn q_val, k_val, v_val, _opts ->
# SDPA.compute handles reshaping to [batch, heads, seq, head_dim],
# scaling by 1/sqrt(head_dim), softmax, and final reshape back
SDPA.compute(q_val, k_val, v_val, num_heads, head_dim)
end,
[q, k, v],
name: "#{name}_sdpa"
)
end,
hidden_size: hidden, name: "custom"
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
for {label, model} <- [
{"Standard MHA", standard},
{"Sigmoid Attention", sigmoid},
{"Custom SDPA", custom}
] do
Compose.run(label, model, %{"x" => test_input})
end
IO.puts("\nSame structure, different attention math — that's the callback pattern.")
3. Swapping the FFN — Standard, SwiGLU, and Custom
The feed-forward network is the other major component you can swap. TransformerBlock supports three modes:
- `:standard`: Dense(4H) → GELU → Dense(H)
- `:gated`: SwiGLU, i.e. parallel gate and up projections, element-wise multiply, down projection
- `custom_ffn`: your own function
What to look for
- SwiGLU's inner dimension is ~33% smaller than the standard 4x expansion (a 2/3 factor compensates for the extra gate projection), keeping total parameter counts comparable
- Custom FFN lets you use anything — even a KAN or MLP-Mixer-style channel mixing
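The 2/3 scaling can be sanity-checked with plain arithmetic. This is a sketch of the LLaMA-style convention; Edifice's exact inner size may differ:

```elixir
hidden = 64

# Standard FFN: two matrices, hidden -> 4*hidden -> hidden
standard_inner = 4 * hidden
standard_params = 2 * hidden * standard_inner

# Gated FFN: three matrices (gate, up, down), inner size scaled by 2/3
gated_inner = div(2 * standard_inner, 3)
gated_params = 3 * hidden * gated_inner

IO.puts("standard: #{standard_params}  gated: #{gated_params}")
```

With `hidden = 64`, the standard FFN has 32,768 weights and the gated FFN 32,640: a smaller inner dimension, nearly identical totals.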
input = Axon.input("x", shape: {nil, seq_len, hidden})
# Same attention for all three — we're only changing the FFN
my_attn = fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads, causal: true, name: name)
end
# 1. Standard FFN: Dense(4*hidden) → GELU → Dense(hidden)
standard_ffn = TransformerBlock.layer(input,
attention_fn: my_attn,
hidden_size: hidden,
ffn_type: :standard,
ffn_expansion: 4,
name: "std_ffn"
)
# 2. SwiGLU FFN: parallel gate/up → SiLU(gate) * up → down
# Used by LLaMA, Mistral, and most modern LLMs
gated_ffn = TransformerBlock.layer(input,
attention_fn: my_attn,
hidden_size: hidden,
ffn_type: :gated,
name: "gated_ffn"
)
# 3. Custom FFN: explicitly use the FFN block with custom inner size
custom_ffn = TransformerBlock.layer(input,
attention_fn: my_attn,
hidden_size: hidden,
custom_ffn: fn x, name ->
# FFN.layer gives you full control over inner dimensions and activation
FFN.layer(x, hidden_size: hidden, inner_size: 128, activation: :relu, name: name)
end,
name: "custom_ffn"
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
for {label, model} <- [
{"Standard FFN (4x expansion)", standard_ffn},
{"SwiGLU FFN (gated)", gated_ffn},
{"Custom FFN (relu, 2x expansion)", custom_ffn}
] do
Compose.run(label, model, %{"x" => test_input})
end
4. Position Encodings — RoPE vs Sinusoidal
Position information is critical — without it, attention treats the sequence as a set (order-agnostic). Two main approaches:
- RoPE (rotary): applied to Q and K after projection. Encodes relative positions via rotation matrices. Extrapolates to longer sequences than seen during training. Used by LLaMA, Mistral, and most modern LLMs.
- Sinusoidal: added to the input before attention. Encodes absolute positions via sin/cos at different frequencies. Used by the original Transformer, Whisper, and DETR.
What to look for
- RoPE is applied inside the attention function (rotates Q and K)
- Sinusoidal PE is applied outside, before the transformer block
- Both are plain `Nx` tensor operations; they add no trainable parameters
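To make the "no trainable parameters" point concrete, here is a hand-rolled sinusoidal table. This sketch concatenates the sin and cos halves rather than interleaving them as the original Transformer does; `SinusoidalPE.layer/2` below handles the real layout:

```elixir
# Position p, frequency pair i gets angle p / 10000^(2i/dim).
# The whole table is computed once from iota; nothing is learned.
seq_len = 16
dim = 64

pos = Nx.iota({seq_len, 1})
i = Nx.iota({1, div(dim, 2)})
inv_freq = Nx.pow(10_000.0, Nx.divide(Nx.multiply(i, -2), dim))
angles = Nx.multiply(pos, inv_freq)
pe = Nx.concatenate([Nx.sin(angles), Nx.cos(angles)], axis: 1)

Nx.shape(pe)
```

The result is a fixed `{seq_len, dim}` tensor, one deterministic vector per position, which is simply broadcast-added to the input.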
input = Axon.input("x", shape: {nil, seq_len, hidden})
# Approach 1: RoPE — applied inside the attention callback
# Precompute cos/sin tables once, reuse across layers
{cos_table, sin_table} = RoPE.precompute_freqs(head_dim, seq_len)
rope_block = TransformerBlock.layer(input,
attention_fn: fn x, name ->
q = Axon.dense(x, hidden, name: "#{name}_q")
k = Axon.dense(x, hidden, name: "#{name}_k")
v = Axon.dense(x, hidden, name: "#{name}_v")
Axon.layer(
fn q_val, k_val, v_val, _opts ->
# Reshape to [batch, heads, seq, head_dim] for rotation
{b, s, _} = Nx.shape(q_val)
q_4d = Nx.reshape(q_val, {b, s, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3])
k_4d = Nx.reshape(k_val, {b, s, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3])
# Apply RoPE rotation — this is where relative position info enters
{q_rot, k_rot} = RoPE.apply_rotary_4d(q_4d, k_4d,
cos: cos_table, sin: sin_table)
# Flatten back to [batch, seq, hidden] and run SDPA
q_flat = q_rot |> Nx.transpose(axes: [0, 2, 1, 3]) |> Nx.reshape({b, s, hidden})
k_flat = k_rot |> Nx.transpose(axes: [0, 2, 1, 3]) |> Nx.reshape({b, s, hidden})
SDPA.compute(q_flat, k_flat, v_val, num_heads, head_dim)
end,
[q, k, v],
name: "#{name}_rope_sdpa"
)
end,
hidden_size: hidden, name: "rope_block"
)
# Approach 2: Sinusoidal PE — added to input before the block
pe_input = SinusoidalPE.layer(input, dim: hidden, max_len: seq_len, name: "sinusoidal_pe")
sinusoidal_block = TransformerBlock.layer(pe_input,
attention_fn: fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads, causal: true, name: name)
end,
hidden_size: hidden, name: "sinusoidal_block"
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
IO.puts("=== Position Encoding Comparison ===\n")
{rope_out, _, _} = Compose.run("RoPE (inside attention)", rope_block, %{"x" => test_input})
{sin_out, _, _} = Compose.run("Sinusoidal PE (added to input)", sinusoidal_block, %{"x" => test_input})
IO.puts("\nRoPE: relative positions, applied to Q/K. No trainable params for PE itself.")
IO.puts("Sinusoidal: absolute positions, added to input. Also no trainable params.")
5. Normalization — LayerNorm vs RMSNorm
TransformerBlock uses LayerNorm by default, but you can switch to RMSNorm with a single option. RMSNorm skips the centering step (mean subtraction), which makes it ~10% faster with negligible quality difference. Most modern LLMs (LLaMA, Mistral, Gemma) use RMSNorm.
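The difference is easiest to see on a raw tensor. A minimal standalone sketch, without the learned scale (and bias) parameters that the real norm layers add:

```elixir
x = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])

# LayerNorm: subtract the mean, then divide by the standard deviation
mean = Nx.mean(x, axes: [-1], keep_axes: true)
centered = Nx.subtract(x, mean)
std = centered |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true) |> Nx.sqrt()
layer_normed = Nx.divide(centered, std)

# RMSNorm: skip the centering, divide by the root mean square directly
rms = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true) |> Nx.sqrt()
rms_normed = Nx.divide(x, rms)
```

Dropping the mean subtraction is the entire trick: one fewer pass over the data, and in practice the centering turns out to matter little for transformer quality.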
input = Axon.input("x", shape: {nil, seq_len, hidden})
my_attn = fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads, causal: true, name: name)
end
# LayerNorm (default): center + scale
ln_block = TransformerBlock.layer(input,
attention_fn: my_attn,
hidden_size: hidden,
norm: :layer_norm,
name: "layernorm_block"
)
# RMSNorm: scale only (no centering), faster
rms_block = TransformerBlock.layer(input,
attention_fn: my_attn,
hidden_size: hidden,
norm: :rms_norm,
name: "rmsnorm_block"
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
Compose.run("LayerNorm block", ln_block, %{"x" => test_input})
Compose.run("RMSNorm block", rms_block, %{"x" => test_input})
6. Stacking — From One Block to N Layers
A single block is nice, but real models stack 4-48 of them.
TransformerBlock.stack/3 does this with auto-generated names.
What to look for
- Params scale linearly with layer count (each block has its own weights)
- The output shape stays the same regardless of depth — blocks are uniform
input = Axon.input("x", shape: {nil, seq_len, hidden})
my_attn = fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads, causal: true, name: name)
end
# Stack 1, 2, and 4 layers — same attention, different depths
for n_layers <- [1, 2, 4] do
model = TransformerBlock.stack(input, n_layers,
attention_fn: my_attn,
hidden_size: hidden,
name: "stack"
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
Compose.run("#{n_layers}-layer stack", model, %{"x" => test_input})
end
IO.puts("\nParams scale linearly: each block adds the same number of weights.")
7. ModelBuilder — Complete Model in One Call
So far we’ve been working with raw blocks. ModelBuilder.build_sequence_model/1
wraps everything into a complete pipeline:
- Input node (`"state_sequence"`)
- Linear projection (if `embed_dim != hidden_size`)
- N stacked blocks via your `block_builder` callback
- Final layer normalization
- Output extraction (last timestep, all timesteps, or mean pooling)
What to look for
- The `block_builder` callback receives `opts[:layer_idx]` so you can name blocks uniquely
- `output_mode: :last_timestep` extracts `[batch, hidden]` from the final position
- This is exactly how built-in architectures like GatedAttention, FoX, and LASER are implemented
# A complete sequence model using ModelBuilder
model = ModelBuilder.build_sequence_model(
embed_dim: 48, # Input dimension (e.g., from an embedding layer)
hidden_size: hidden, # Internal dimension (auto-projects from 48 → 64)
num_layers: 3,
seq_len: seq_len,
block_builder: fn input, opts ->
idx = opts[:layer_idx]
TransformerBlock.layer(input,
attention_fn: fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads,
causal: true, dropout: 0.0, name: name)
end,
hidden_size: hidden,
ffn_type: :gated, # SwiGLU FFN
norm: :rms_norm, # Modern normalization
name: "block_#{idx}"
)
end,
output_mode: :last_timestep # [batch, hidden] — good for classification/RL
)
# Note: embed_dim (48) != hidden_size (64), so ModelBuilder adds a projection layer
test_input = Nx.broadcast(0.3, {batch, seq_len, 48})
{out, _, _} = Compose.run("ModelBuilder sequence model", model, %{
"state_sequence" => test_input
})
IO.puts(" Input: [#{batch}, #{seq_len}, 48] (embed_dim)")
IO.puts(" Output: #{inspect(Nx.shape(out))} (last_timestep → [batch, hidden])")
IO.puts("\nThis is a complete LLaMA-style model: RMSNorm + SwiGLU + causal MHA.")
IO.puts("Add RoPE in the attention callback and you'd match the real architecture.")
8. Recipe: Mixing Block Types Across Layers
Some of the best architectures interleave different block types. Hymba
alternates Mamba SSM blocks with attention blocks. Griffin uses recurrent
layers for most of the stack and attention for a few. The block_builder
callback makes this trivial — just branch on layer_idx.
What to look for
- Odd layers use a simple “SSM-like” block (depthwise conv + gating, no attention)
- Even layers use standard transformer blocks (attention + FFN)
- The model seamlessly mixes both because they share the same shape contract
# Hybrid model: alternating attention and "SSM-like" blocks
# Real SSM blocks (Mamba.block) are available but we use a lightweight
# conv+gate substitute here to keep things simple and fast.
model = ModelBuilder.build_sequence_model(
embed_dim: hidden,
hidden_size: hidden,
num_layers: 4,
seq_len: seq_len,
block_builder: fn input, opts ->
idx = opts[:layer_idx]
if rem(idx, 2) == 0 do
# Even layers: standard attention block
TransformerBlock.layer(input,
attention_fn: fn x, name ->
MultiHead.self_attention(x,
hidden_size: hidden, num_heads: num_heads,
causal: true, name: name)
end,
hidden_size: hidden,
ffn_type: :gated,
name: "attn_#{idx}"
)
else
# Odd layers: conv + gate block (lightweight SSM substitute)
# This mimics what Mamba/Griffin do: local context via conv,
# gated output, no attention. In production you'd use Mamba.block/2.
normed = Axon.layer_norm(input, name: "ssm_norm_#{idx}")
gate = Axon.dense(normed, hidden, name: "ssm_gate_#{idx}", activation: :sigmoid)
up = Axon.dense(normed, hidden, name: "ssm_up_#{idx}", activation: :silu)
gated = Axon.layer(
fn g, u, _opts -> Nx.multiply(g, u) end,
[gate, up],
name: "ssm_gated_#{idx}"
)
down = Axon.dense(gated, hidden, name: "ssm_down_#{idx}")
# Residual connection
Axon.layer(
fn residual, projected, _opts -> Nx.add(residual, projected) end,
[input, down],
name: "ssm_residual_#{idx}"
)
end
end,
output_mode: :last_timestep
)
test_input = Nx.broadcast(0.3, {batch, seq_len, hidden})
{out, _, _} = Compose.run("Hybrid attention+SSM model (4 layers)", model, %{
"state_sequence" => test_input
})
IO.puts("\n Layer 0: Attention Layer 1: Conv+Gate")
IO.puts(" Layer 2: Attention Layer 3: Conv+Gate")
IO.puts("\nThis is the Hymba/Griffin pattern: attention for global context,")
IO.puts("recurrence for local context, interleaved for best of both worlds.")
9. Recipe: Encoder-Decoder with Cross-Attention
Encoder-decoder models use TransformerBlock.layer/3 (3-sublayer variant)
and CrossAttention.layer/3. The decoder attends to its own sequence
(causal self-attention) and to the encoder’s output (cross-attention).
This pattern powers Whisper, DETR, and any seq2seq architecture.
What to look for
- The encoder stack uses `stack/3` (two sublayers, no cross-attention)
- The decoder stack uses `stack/4`; the extra argument is the encoder output
- Cross-attention Q comes from the decoder; K/V come from the encoder
alias Edifice.Blocks.CrossAttention
enc_hidden = hidden
dec_hidden = hidden
n_layers = 2
# Encoder: standard transformer blocks (no causal mask — bidirectional)
encoder_input = Axon.input("encoder_input", shape: {nil, seq_len, enc_hidden})
encoded = TransformerBlock.stack(encoder_input, n_layers,
attention_fn: fn x, name ->
# Non-causal: encoder sees the full input
MultiHead.self_attention(x,
hidden_size: enc_hidden, num_heads: num_heads,
causal: false, name: name)
end,
hidden_size: enc_hidden,
name: "encoder"
)
# Decoder: self-attention (causal) + cross-attention to encoder output
decoder_input = Axon.input("decoder_input", shape: {nil, seq_len, dec_hidden})
decoded = TransformerBlock.stack(decoder_input, encoded, n_layers,
attention_fn: fn x, name ->
# Causal: decoder can only see previous positions
MultiHead.self_attention(x,
hidden_size: dec_hidden, num_heads: num_heads,
causal: true, name: name)
end,
cross_attention_fn: fn query, memory, name ->
# Q from decoder, K/V from encoder — this is the "cross" part
CrossAttention.layer(query, memory,
hidden_size: dec_hidden, num_heads: num_heads, name: name)
end,
hidden_size: dec_hidden,
name: "decoder"
)
# Output both for inspection
model = Axon.container(%{encoder: encoded, decoder: decoded})
enc_input = Nx.broadcast(0.5, {batch, seq_len, enc_hidden})
dec_input = Nx.broadcast(0.3, {batch, seq_len, dec_hidden})
inputs = %{"encoder_input" => enc_input, "decoder_input" => dec_input}
templates = Map.new(inputs, fn {k, v} -> {k, Nx.template(Nx.shape(v), Nx.type(v))} end)
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(templates, Axon.ModelState.empty())
%{encoder: enc_out, decoder: dec_out} = predict_fn.(params, inputs)
param_count = Compose.count_params(params)
IO.puts("Encoder-Decoder (#{n_layers} layers each)")
IO.puts(" Encoder output: #{inspect(Nx.shape(enc_out))}")
IO.puts(" Decoder output: #{inspect(Nx.shape(dec_out))}")
IO.puts(" Total params: #{Compose.fmt(param_count)}")
IO.puts("\nThe decoder's cross-attention layers attend to the encoder output.")
IO.puts("This is how Whisper transcribes audio → text.")
Summary
IO.puts("""
=== Composition Cheat Sheet ===
BUILDING BLOCKS:
TransformerBlock.layer/2 Norm → Attention → Residual → Norm → FFN → Residual
TransformerBlock.layer/3 Same + cross-attention sublayer (encoder-decoder)
TransformerBlock.stack/3 Repeat block N times (self-attention only)
TransformerBlock.stack/4 Repeat block N times (with cross-attention)
ModelBuilder.build_sequence_model Complete pipeline: input → project → blocks → norm → output
ModelBuilder.build_vision_model Complete pipeline: input → patch embed → blocks → pool → classify
CALLBACKS:
attention_fn: fn x, name -> ... end Plug in any attention mechanism
cross_attention_fn: fn q, memory, name -> end Plug in cross-attention (for enc-dec)
custom_ffn: fn x, name -> ... end Replace the FFN sublayer
block_builder: fn input, opts -> ... end Plug in any block type (for ModelBuilder)
OPTIONS:
ffn_type: :standard | :gated Standard or SwiGLU FFN
norm: :layer_norm | :rms_norm Normalization type
output_mode: :last_timestep | :all | :mean_pool How to extract the final output
SHARED PRIMITIVES:
SDPA.compute/5 Scaled dot-product attention (handles reshaping + scaling)
RoPE.apply_rotary_4d/2 Rotary position embeddings (relative positions)
SinusoidalPE.layer/2 Sinusoidal position encoding (absolute positions)
FFN.layer/2 Standard FFN with configurable inner size
SwiGLU.layer/2 Gated FFN (LLaMA-style)
RMSNorm.layer/2 Root-mean-square normalization
CrossAttention.layer/3 Cross-attention between two sequences
CausalMask.causal/1 Lower-triangular attention mask
The pattern: structure from TransformerBlock, computation from callbacks,
primitives from shared blocks. Learn it once, compose anything.
""")