Sequence Modeling with Edifice
Setup
Choose one of the two cells below depending on how you started Livebook.
Standalone (default)
Use this if you started Livebook normally (livebook server).
Uncomment the EXLA lines for GPU acceleration.
edifice_dep =
if File.dir?(Path.expand("~/edifice")) do
{:edifice, path: Path.expand("~/edifice")}
else
{:edifice, "~> 0.2.0"}
end
Mix.install([
edifice_dep,
# {:exla, "~> 0.10"},
{:kino_vega_lite, "~> 0.1"},
{:kino, "~> 0.14"}
])
# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
Attached to project (recommended for Nix/CUDA)
Use this if you started Livebook via ./scripts/livebook.sh.
See the Architecture Zoo notebook for full setup instructions.
Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")
Introduction
This notebook trains sequence models to predict a sine wave — a classic time-series task that tests whether a model can learn temporal patterns.
What you’ll learn:
- How to shape data for sequence models: {batch, seq_len, features} (see the shape sketch after this list)
- Training Mamba (SSM), GRU (recurrent), and GatedSSM on the same task
- Comparing architectures with identical training loops
- Visualizing predictions vs ground truth
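Before building the real dataset, here is a tiny standalone sketch of that tensor layout. The values are arbitrary; only the shape matters:

# Hypothetical mini-batch: 4 sequences of 10 timesteps with 1 feature each
demo = Nx.iota({4, 10, 1}, type: :f32)
Nx.shape(demo)
# => {4, 10, 1}, i.e. {batch, seq_len, features}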
Generate Sine Wave Data
We generate a sine wave and split it into overlapping windows: each input is seq_len consecutive timesteps, and the target is the single value that follows. With seq_len = 16, window 0 covers signal[0..15] and its target is signal[16].
# Parameters
n_points = 2000
seq_len = 16
freq = 0.05
# Generate sine wave with some complexity
key = Nx.Random.key(42)
ts = Nx.iota({n_points}) |> Nx.multiply(freq)
signal = Nx.add(Nx.sin(ts), Nx.multiply(Nx.sin(Nx.multiply(ts, 2.3)), 0.3))
# Create overlapping windows: [input_seq] -> [next_value]
windows = n_points - seq_len - 1
IO.puts("Creating #{windows} overlapping windows (seq_len=#{seq_len})...")
IO.puts(" Building input sequences...")
x_data =
for i <- 0..(windows - 1) do
Nx.slice(signal, [i], [seq_len]) |> Nx.reshape({seq_len, 1})
end
|> Nx.stack()
IO.puts(" Input shape: #{inspect(Nx.shape(x_data))}")
IO.puts(" Building targets...")
y_data =
for i <- 0..(windows - 1) do
Nx.slice(signal, [i + seq_len], [1])
end
|> Nx.stack()
IO.puts(" Target shape: #{inspect(Nx.shape(y_data))}")
# 80/20 split
n_train = round(windows * 0.8)
train_x = x_data[0..(n_train - 1)]
train_y = y_data[0..(n_train - 1)]
test_x = x_data[n_train..-1//1]
test_y = y_data[n_train..-1//1]
# Batch for training
batch_size = 32
IO.puts(" Batching training data...")
train_data =
Enum.zip(
Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
Nx.to_batched(train_y, batch_size) |> Enum.to_list()
)
IO.puts("Ready: #{n_train} train / #{windows - n_train} test windows, #{length(train_data)} batches/epoch")
Let’s visualize the signal and the train/test split:
signal_list = Nx.to_flat_list(signal)
chart_data =
signal_list
|> Enum.with_index()
|> Enum.map(fn {val, i} ->
split = if i < n_train + seq_len, do: "train", else: "test"
%{"t" => i, "value" => val, "split" => split}
end)
Vl.new(width: 700, height: 250, title: "Sine Wave Signal (train / test split)")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "t", type: :quantitative, title: "Timestep")
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "split", type: :nominal)
Helper: Train and Evaluate
We’ll reuse this function for each architecture.
defmodule SeqTrainer do
@doc "Train a model and return {trained_state, test_predictions, test_mse}"
def train_and_eval(model, train_data, test_x, test_y, opts \\ []) do
epochs = Keyword.get(opts, :epochs, 30)
lr = Keyword.get(opts, :lr, 1.0e-3)
trained_state =
model
|> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: lr))
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: epochs)
    # Build an inference function and score the held-out windows
    {_init_fn, predict_fn} = Axon.build(model)
    preds = predict_fn.(trained_state, test_x)

    # Mean squared error against the test targets
    mse =
      Nx.subtract(preds, test_y)
      |> Nx.pow(2)
      |> Nx.mean()
      |> Nx.to_number()
{trained_state, preds, mse}
end
end
Train Mamba (State-Space Model)
Mamba is a selective state-space model — it processes sequences through learned state transitions with input-dependent gating. Fast and parallelizable.
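To build intuition before training, here is a toy version of the linear state-space recurrence that SSMs generalize: h_t = A h_{t-1} + B x_t with readout y_t = C h_t. The matrices below are made up for illustration; Edifice's Mamba learns input-dependent versions of them and computes the sequence with an efficient scan rather than an Enum.reduce loop.

# Toy linear SSM recurrence (illustration only, not Edifice's implementation)
demo_state = 4
a = Nx.multiply(Nx.eye(demo_state), 0.9)   # state transition A (decaying memory)
b = Nx.broadcast(0.1, {demo_state, 1})     # input projection B
c = Nx.broadcast(1.0, {1, demo_state})     # readout C

xs = Nx.to_flat_list(Nx.sin(Nx.multiply(Nx.iota({8}), 0.5)))

{_h, ys} =
  Enum.reduce(xs, {Nx.broadcast(0.0, {demo_state, 1}), []}, fn x, {h, acc} ->
    h = Nx.add(Nx.dot(a, h), Nx.multiply(b, x))   # h_t = A h_{t-1} + B x_t
    y = Nx.to_number(Nx.squeeze(Nx.dot(c, h)))    # y_t = C h_t
    {h, [y | acc]}
  end)

Enum.reverse(ys)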
mamba_model =
Edifice.build(:mamba,
embed_dim: 32,
hidden_size: 16,
state_size: 8,
num_layers: 2,
seq_len: seq_len,
window_size: seq_len,
dropout: 0.0
)
|> Axon.dense(1, name: "mamba_output")
{mamba_state, mamba_preds, mamba_mse} =
SeqTrainer.train_and_eval(mamba_model, train_data, test_x, test_y, epochs: 3)
IO.puts("Mamba test MSE: #{Float.round(mamba_mse, 6)}")
Train GRU (Recurrent)
GRU is a gated recurrent unit — a simpler variant of LSTM that maintains a hidden state updated at each timestep. A well-established sequence baseline.
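For intuition, here is one GRU step with made-up scalar weights, a sketch of the standard gating equations rather than Axon's implementation: an update gate z blends the old hidden state with a candidate state, and a reset gate r controls how much history feeds into that candidate.

# One GRU step with toy weights (illustration of the gating math only)
h = Nx.tensor([0.0, 0.0])   # previous hidden state (2 units)
x = 0.5                     # current input feature

z = Nx.sigmoid(Nx.add(Nx.multiply(h, 0.1), x * 0.8))   # update gate
r = Nx.sigmoid(Nx.add(Nx.multiply(h, 0.2), x * 0.4))   # reset gate
h_tilde = Nx.tanh(Nx.add(Nx.multiply(Nx.multiply(r, h), 0.5), x * 1.0))

# New hidden state: h' = (1 - z) * h + z * h_tilde
h_new = Nx.add(Nx.multiply(Nx.subtract(1.0, z), h), Nx.multiply(z, h_tilde))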
gru_model =
Edifice.build(:gru,
embed_dim: 32,
hidden_size: 16,
num_layers: 2,
seq_len: seq_len,
window_size: seq_len,
dropout: 0.0
)
|> Axon.dense(1, name: "gru_output")
{gru_state, gru_preds, gru_mse} =
SeqTrainer.train_and_eval(gru_model, train_data, test_x, test_y, epochs: 3)
IO.puts("GRU test MSE: #{Float.round(gru_mse, 6)}")
Train GatedSSM (Lightweight SSM)
GatedSSM is a simpler gated state-space model — fewer parameters than Mamba but still captures sequential dynamics.
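The "gated" part can be pictured as an elementwise sigmoid gate scaling the SSM branch's output per feature. This is a generic sketch of the idea, not Edifice's exact layer:

# Generic elementwise gating (illustration only)
ssm_out = Nx.tensor([0.7, -0.2, 0.4])       # pretend output of the SSM branch
gate_logits = Nx.tensor([2.0, -1.0, 0.0])   # pretend learned gate pre-activations
Nx.multiply(Nx.sigmoid(gate_logits), ssm_out)
# a near-1 gate passes the feature through; a near-0 gate suppresses it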
gated_model =
Edifice.build(:gated_ssm,
embed_dim: 32,
hidden_size: 16,
state_size: 8,
num_layers: 2,
seq_len: seq_len,
window_size: seq_len,
dropout: 0.0
)
|> Axon.dense(1, name: "gated_output")
{gated_state, gated_preds, gated_mse} =
SeqTrainer.train_and_eval(gated_model, train_data, test_x, test_y, epochs: 3)
IO.puts("GatedSSM test MSE: #{Float.round(gated_mse, 6)}")
Compare Results
IO.puts("=" |> String.duplicate(50))
IO.puts(" Architecture Test MSE")
IO.puts(" " <> String.duplicate("-", 35))
IO.puts(" Mamba #{Float.round(mamba_mse, 6)}")
IO.puts(" GRU #{Float.round(gru_mse, 6)}")
IO.puts(" GatedSSM #{Float.round(gated_mse, 6)}")
IO.puts("=" |> String.duplicate(50))
best =
[{:mamba, mamba_mse}, {:gru, gru_mse}, {:gated_ssm, gated_mse}]
|> Enum.min_by(&elem(&1, 1))
IO.puts("\nBest: #{elem(best, 0)}")
Visualize Predictions
Let’s overlay all three models’ predictions against the ground truth on the test set.
test_actual = Nx.to_flat_list(Nx.reshape(test_y, {Nx.axis_size(test_y, 0)}))
mamba_pred_list = Nx.to_flat_list(Nx.reshape(mamba_preds, {Nx.axis_size(mamba_preds, 0)}))
gru_pred_list = Nx.to_flat_list(Nx.reshape(gru_preds, {Nx.axis_size(gru_preds, 0)}))
gated_pred_list = Nx.to_flat_list(Nx.reshape(gated_preds, {Nx.axis_size(gated_preds, 0)}))
n_test = length(test_actual)
chart_data =
Enum.flat_map(0..(n_test - 1), fn i ->
[
%{"t" => i, "value" => Enum.at(test_actual, i), "series" => "Ground Truth"},
%{"t" => i, "value" => Enum.at(mamba_pred_list, i), "series" => "Mamba"},
%{"t" => i, "value" => Enum.at(gru_pred_list, i), "series" => "GRU"},
%{"t" => i, "value" => Enum.at(gated_pred_list, i), "series" => "GatedSSM"}
]
end)
Vl.new(width: 700, height: 350, title: "Test Set: Predictions vs Ground Truth")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "t", type: :quantitative, title: "Test Timestep")
|> Vl.encode_field(:y, "value", type: :quantitative, title: "Signal Value")
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)
What’s Next?
- Try more architectures: Replace :mamba with :retnet, :s4, :xlstm, :hyena, :liquid, or any of 30+ sequence models, all with the same training loop (see the sketch after this list).
- Increase complexity: Try multi-feature sequences, longer horizons, or chaotic signals like Lorenz attractors.
- Add EXLA: Set Nx.global_default_backend(EXLA.Backend) for GPU acceleration.
- Tune hyperparameters: Adjust embed_dim, num_layers, state_size, learning rate, and epochs.
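Since SeqTrainer.train_and_eval only needs an Axon model, comparing more architectures is a small loop. A hedged sketch, assuming the option names from the cells above carry over; individual architectures may require extra options such as state_size, so check each one's docs:

# Sweep several architectures with the identical training loop (sketch)
for arch <- [:mamba, :gru, :gated_ssm] do
  model =
    Edifice.build(arch,
      embed_dim: 32,
      hidden_size: 16,
      num_layers: 2,
      seq_len: seq_len,
      window_size: seq_len,
      dropout: 0.0
    )
    |> Axon.dense(1, name: "#{arch}_output")

  {_state, _preds, mse} =
    SeqTrainer.train_and_eval(model, train_data, test_x, test_y, epochs: 3)

  IO.puts("#{arch}: test MSE #{Float.round(mse, 6)}")
end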