Powered by AppSignal & Oban Pro

Sequence Modeling with Edifice

notebooks/sequence_modeling.livemd

Sequence Modeling with Edifice

Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server). Uncomment the EXLA lines for GPU acceleration.

# Prefer a local checkout of Edifice when one exists; fall back to the Hex release.
edifice_path = Path.expand("~/edifice")

edifice_dep =
  if File.dir?(edifice_path) do
    {:edifice, path: edifice_path}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:kino, "~> 0.14"}
])

# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh. See the Architecture Zoo notebook for full setup instructions.

# Attached mode: the project node already compiled EXLA, so route all Nx
# tensor operations through it.
Nx.global_default_backend(EXLA.Backend)
# Short alias used by the charting cells below.
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")

Introduction

This notebook trains sequence models to predict a sine wave — a classic time-series task that tests whether a model can learn temporal patterns.

What you’ll learn:

  • How to shape data for sequence models: {batch, seq_len, features}
  • Training Mamba (SSM) and GRU (recurrent) on the same task
  • Comparing architectures with identical training loops
  • Visualizing predictions vs ground truth

Generate Sine Wave Data

We generate a sine wave and split it into overlapping windows: each input is seq_len timesteps, and the target is the next value.

# Parameters
n_points = 2000
seq_len = 16
freq = 0.05

# Generate sine wave with some complexity: a base sine plus a faster,
# smaller-amplitude harmonic (0.3 * sin(2.3t)) so the signal is not a
# single trivially-periodic tone.
# (Removed an unused `Nx.Random.key(42)` binding — nothing in this
# notebook draws random numbers.)
ts = Nx.iota({n_points}) |> Nx.multiply(freq)
signal = Nx.add(Nx.sin(ts), Nx.multiply(Nx.sin(Nx.multiply(ts, 2.3)), 0.3))

# Slide a seq_len-wide window across the signal; each window's target is the
# single value immediately following it.
windows = n_points - seq_len - 1
IO.puts("Creating #{windows} overlapping windows (seq_len=#{seq_len})...")

IO.puts("  Building input sequences...")

x_data =
  0..(windows - 1)
  |> Enum.map(fn start ->
    signal
    |> Nx.slice([start], [seq_len])
    |> Nx.reshape({seq_len, 1})
  end)
  |> Nx.stack()

IO.puts("  Input shape: #{inspect(Nx.shape(x_data))}")

IO.puts("  Building targets...")

y_data =
  0..(windows - 1)
  |> Enum.map(fn start -> Nx.slice(signal, [start + seq_len], [1]) end)
  |> Nx.stack()

IO.puts("  Target shape: #{inspect(Nx.shape(y_data))}")

# Chronological 80/20 split — no shuffling, since this is a time series.
n_train = round(windows * 0.8)

train_range = 0..(n_train - 1)
test_range = n_train..-1//1

train_x = x_data[train_range]
train_y = y_data[train_range]
test_x = x_data[test_range]
test_y = y_data[test_range]

# Batch for training
batch_size = 32

IO.puts("  Batching training data...")

x_batches = train_x |> Nx.to_batched(batch_size) |> Enum.to_list()
y_batches = train_y |> Nx.to_batched(batch_size) |> Enum.to_list()
train_data = Enum.zip(x_batches, y_batches)

IO.puts("Ready: #{n_train} train / #{windows - n_train} test windows, #{length(train_data)} batches/epoch")

Let’s visualize the signal and the train/test split:

signal_list = Nx.to_flat_list(signal)

# Tag each raw point as train or test. The first n_train + seq_len points feed
# the training windows (the last training window ends at n_train + seq_len - 1).
chart_data =
  for {val, i} <- Enum.with_index(signal_list) do
    label = if i < n_train + seq_len, do: "train", else: "test"
    %{"t" => i, "value" => val, "split" => label}
  end

Vl.new(width: 700, height: 250, title: "Sine Wave Signal (train / test split)")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "t", type: :quantitative, title: "Timestep")
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "split", type: :nominal)
Helper: Train and Evaluate

We’ll reuse this function for each architecture.

defmodule SeqTrainer do
  @doc """
  Trains `model` on `train_data` (a list of `{x, y}` batch tuples) and
  evaluates it on the held-out test tensors.

  Options: `:epochs` (default 30) and `:lr` (default 1.0e-3).

  Returns `{trained_state, test_predictions, test_mse}`.
  """
  def train_and_eval(model, train_data, test_x, test_y, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 30)
    lr = Keyword.get(opts, :lr, 1.0e-3)

    optimizer = Polaris.Optimizers.adam(learning_rate: lr)

    fitted_state =
      model
      |> Axon.Loop.trainer(:mean_squared_error, optimizer)
      |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: epochs)

    # Rebuild the model as a pure predict function for inference.
    {_init_fn, predict_fn} = Axon.build(model)
    predictions = predict_fn.(fitted_state, test_x)

    # Mean squared error over the whole test set, as a plain number.
    mse =
      predictions
      |> Nx.subtract(test_y)
      |> Nx.pow(2)
      |> Nx.mean()
      |> Nx.to_number()

    {fitted_state, predictions, mse}
  end
end

Train Mamba (State-Space Model)

Mamba is a selective state-space model — it processes sequences through learned state transitions with input-dependent gating. Fast and parallelizable.

# Mamba configuration — small dims keep training fast on CPU.
mamba_opts = [
  embed_dim: 32,
  hidden_size: 16,
  state_size: 8,
  num_layers: 2,
  seq_len: seq_len,
  window_size: seq_len,
  dropout: 0.0
]

mamba_model =
  :mamba
  |> Edifice.build(mamba_opts)
  |> Axon.dense(1, name: "mamba_output")

{mamba_state, mamba_preds, mamba_mse} =
  SeqTrainer.train_and_eval(mamba_model, train_data, test_x, test_y, epochs: 3)

IO.puts("Mamba test MSE: #{Float.round(mamba_mse, 6)}")

Train GRU (Recurrent)

GRU is a gated recurrent unit — a simpler variant of LSTM that maintains a hidden state updated at each timestep. A well-established sequence baseline.

# GRU configuration — mirrors the Mamba dims for a fair comparison.
gru_opts = [
  embed_dim: 32,
  hidden_size: 16,
  num_layers: 2,
  seq_len: seq_len,
  window_size: seq_len,
  dropout: 0.0
]

gru_model =
  :gru
  |> Edifice.build(gru_opts)
  |> Axon.dense(1, name: "gru_output")

{gru_state, gru_preds, gru_mse} =
  SeqTrainer.train_and_eval(gru_model, train_data, test_x, test_y, epochs: 3)

IO.puts("GRU test MSE: #{Float.round(gru_mse, 6)}")

Train GatedSSM (Lightweight SSM)

GatedSSM is a simpler gated state-space model — fewer parameters than Mamba but still captures sequential dynamics.

# GatedSSM configuration — same dims as the other two models.
gated_opts = [
  embed_dim: 32,
  hidden_size: 16,
  state_size: 8,
  num_layers: 2,
  seq_len: seq_len,
  window_size: seq_len,
  dropout: 0.0
]

gated_model =
  :gated_ssm
  |> Edifice.build(gated_opts)
  |> Axon.dense(1, name: "gated_output")

{gated_state, gated_preds, gated_mse} =
  SeqTrainer.train_and_eval(gated_model, train_data, test_x, test_y, epochs: 3)

IO.puts("GatedSSM test MSE: #{Float.round(gated_mse, 6)}")

Compare Results

# Summary table of test MSE for the three architectures.
IO.puts("=" |> String.duplicate(50))
IO.puts("  Architecture      Test MSE")
IO.puts("  " <> String.duplicate("-", 35))
IO.puts("  Mamba             #{Float.round(mamba_mse, 6)}")
IO.puts("  GRU               #{Float.round(gru_mse, 6)}")
IO.puts("  GatedSSM          #{Float.round(gated_mse, 6)}")
IO.puts("=" |> String.duplicate(50))

# Lowest MSE wins.
# Fix: the capture operator was HTML-escaped as `&amp;` — a syntax error
# in Elixir. Restored to `&elem(&1, 1)`.
best =
  [{:mamba, mamba_mse}, {:gru, gru_mse}, {:gated_ssm, gated_mse}]
  |> Enum.min_by(&elem(&1, 1))

IO.puts("\nBest: #{elem(best, 0)}")

Visualize Predictions

Let’s overlay all three models’ predictions against the ground truth on the test set.

# Flatten each {n_test, 1} tensor to a plain list of numbers.
test_actual = Nx.to_flat_list(Nx.reshape(test_y, {Nx.axis_size(test_y, 0)}))
mamba_pred_list = Nx.to_flat_list(Nx.reshape(mamba_preds, {Nx.axis_size(mamba_preds, 0)}))
gru_pred_list = Nx.to_flat_list(Nx.reshape(gru_preds, {Nx.axis_size(gru_preds, 0)}))
gated_pred_list = Nx.to_flat_list(Nx.reshape(gated_preds, {Nx.axis_size(gated_preds, 0)}))

n_test = length(test_actual)

# Build long-form rows for VegaLite. Zipping the four series walks each list
# once — the previous version called Enum.at/2 on linked lists inside the
# loop, which is O(n^2). Row order and content are unchanged.
chart_data =
  [test_actual, mamba_pred_list, gru_pred_list, gated_pred_list]
  |> Enum.zip()
  |> Enum.with_index()
  |> Enum.flat_map(fn {{actual, mamba, gru, gated}, i} ->
    [
      %{"t" => i, "value" => actual, "series" => "Ground Truth"},
      %{"t" => i, "value" => mamba, "series" => "Mamba"},
      %{"t" => i, "value" => gru, "series" => "GRU"},
      %{"t" => i, "value" => gated, "series" => "GatedSSM"}
    ]
  end)

Vl.new(width: 700, height: 350, title: "Test Set: Predictions vs Ground Truth")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "t", type: :quantitative, title: "Test Timestep")
|> Vl.encode_field(:y, "value", type: :quantitative, title: "Signal Value")
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)

What’s Next?

  • Try more architectures: Replace :mamba with :retnet, :s4, :xlstm, :hyena, :liquid, or any of 30+ sequence models — same training loop.
  • Increase complexity: Try multi-feature sequences, longer horizons, or chaotic signals like Lorenz attractors.
  • Add EXLA: Nx.global_default_backend(EXLA.Backend) for GPU acceleration.
  • Tune hyperparameters: Adjust embed_dim, num_layers, state_size, learning rate, and epochs.