State-Space Models with eXMC
A Practical Guide: From Random Walk to Stochastic Volatility
# Install notebook dependencies.
# - exmc: the local probabilistic-programming library this guide demonstrates
# - exla: XLA backend for Nx (Nx itself arrives transitively) — presumably
#   used by exmc for compiled gradients; confirm against the library's deps
# - kino / kino_vega_lite: Livebook widgets and VegaLite chart rendering
# Removed {:jose, "~> 1.11"}: a JSON Object Signing/Encryption library that
# nothing in this notebook references.
Mix.install([
  {:exmc, path: Path.expand("../", __DIR__)},
  {:exla, "~> 0.10"},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"}
])
The Recipe
Every state-space model in eXMC follows four steps:
1. **Register data** — `Builder.data(ir, y_tensor)`
2. **Priors on parameters** — `Builder.rv("sigma", HalfNormal, ...)`
3. **Latent state process** — `Builder.rv("state", GaussianRandomWalk, ...)`
4. **Observation likelihood** — `Custom.rv(ir, "lik", ...)` + `Builder.obs(...)`
NUTS samples the entire state trajectory jointly with the parameters. No Kalman filter needed — the gradient does the work.
Helpers
# Shorthand for the eXMC model-building and sampling API.
alias Exmc.{Builder, Sampler}
# Distributions used below; Custom wraps a user-supplied log-density function.
alias Exmc.Dist.{Normal, HalfNormal, GaussianRandomWalk, Custom}
# Conventional short alias for building VegaLite plot specs.
alias VegaLite, as: Vl
defmodule SSM do
  @moduledoc """
  Small numeric helpers shared by every model in this notebook.
  """

  # Wrap a scalar or list as a 64-bit float tensor.
  def t(v), do: Nx.tensor(v, type: :f64)

  # Summed elementwise Gaussian log-density of `x` under N(mu, sigma^2).
  # The -0.5 * log(2*pi) constant is dropped: it is parameter-independent,
  # so MCMC sampling is unaffected.
  def normal_logpdf_vec(x, mu, sigma) do
    z = x |> Nx.subtract(mu) |> Nx.divide(sigma)

    z
    |> Nx.multiply(z)
    |> Nx.multiply(t(-0.5))
    |> Nx.subtract(Nx.log(sigma))
    |> Nx.sum()
  end
end
Synthetic Data
We generate data from a known process so we can verify the model recovers the truth. The process: a slowly wandering trend plus observation noise, with a volatility spike in the middle.
n = 150
# Fixed seed so the notebook is reproducible run-to-run.
:rand.seed(:exsss, {42, 43, 44})

# True trend: a random walk starting at 5.0 with step sd 0.3.
true_trend =
  Enum.scan(1..n, 5.0, fn _step, level -> level + 0.3 * :rand.normal() end)

# True observation volatility: 0.5 baseline, spiking to 2.0 mid-series (t in 51..79).
true_vol = Enum.map(1..n, fn i -> if i in 51..79, do: 2.0, else: 0.5 end)

# Observations: trend plus time-varying Gaussian noise.
y_list =
  true_trend
  |> Enum.zip(true_vol)
  |> Enum.map(fn {level, vol} -> level + vol * :rand.normal() end)

y = Nx.tensor(y_list, type: :f64)
IO.puts("#{n} observations generated. True sigma_trend=0.3, sigma_obs varies 0.5-2.0")
# Plot the observations against the true trend. The original version also
# built "True Volatility" rows and then immediately rejected them before
# plotting — dead data plus an extra pass. We build only what is drawn.
# (True volatility lives on a different scale; it gets its own panel after
# Model 2.)
data_points =
  Enum.zip([1..n, y_list, true_trend])
  |> Enum.flat_map(fn {i, obs, trend} ->
    [
      %{"t" => i, "value" => obs, "series" => "Observed"},
      %{"t" => i, "value" => trend, "series" => "True Trend"}
    ]
  end)

Vl.new(width: 700, height: 300, title: "Synthetic Data: Trend + Variable Noise")
|> Vl.data_from_values(data_points)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)
Model 1: Local Level (Constant Volatility)
The simplest state-space model. Assumes observation noise is constant.
trend_{t+1} = trend_t + η_t, η ~ N(0, σ_trend²)
y_t = trend_t + ε_t, ε ~ N(0, σ_obs²)
# Model 1 graph: two scalar noise scales plus an n-vector latent trend.
ir_ll = Builder.new_ir()
|> Builder.data(y)
# Weakly-informative half-normal priors on both noise scales.
|> Builder.rv("sigma_trend", HalfNormal, %{sigma: SSM.t(2.0)})
|> Builder.rv("sigma_obs", HalfNormal, %{sigma: SSM.t(2.0)})
# Latent state: trend_t is a Gaussian random walk with step sd sigma_trend.
|> Builder.rv("trend", GaussianRandomWalk, %{sigma: "sigma_trend"}, shape: {n})
# Observation likelihood: y_t ~ N(trend_t, sigma_obs), written as a custom
# log-density over the registered data. NOTE(review): "__obs_data" appears to
# resolve to the tensor passed to Builder.data/2 — confirm against the eXMC API.
ll_logpdf = fn _x, params ->
  obs = params.__obs_data
  SSM.normal_logpdf_vec(obs, params.trend, params.sigma_obs)
end
dist_ll = Custom.new(ll_logpdf, support: :real)
# Wire the custom density into the graph and mark it observed. The 0.0 is
# presumably a placeholder observation value (the real data comes in via
# __obs_data) — verify against the library.
ir_ll = Custom.rv(ir_ll, "lik", dist_ll, %{
  trend: "trend", sigma_obs: "sigma_obs", __obs_data: "__obs_data"
})
|> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))
# Sample Model 1. Initial values: mid-prior scales, trend initialized at the data.
{trace_ll, stats_ll} =
  Sampler.sample(
    ir_ll,
    %{"sigma_trend" => 0.5, "sigma_obs" => 1.0, "trend" => y},
    num_warmup: 500,
    num_samples: 500
  )

# Posterior mean of a traced scalar parameter.
posterior_mean = fn name -> trace_ll[name] |> Nx.mean() |> Nx.to_number() end

st = posterior_mean.("sigma_trend")
so = posterior_mean.("sigma_obs")
IO.puts("Local Level: sigma_trend=#{Float.round(st, 3)} (true: 0.3), sigma_obs=#{Float.round(so, 3)} (true: 0.5-2.0)")
IO.puts("Divergences: #{stats_ll.divergences}")
# Posterior-mean trend from Model 1 (average over draws, axis 0).
trend_ll = trace_ll["trend"] |> Nx.mean(axes: [0]) |> Nx.to_flat_list()

# Long-format rows: data, estimated trend, and the ground truth.
ll_plot =
  [1..n, y_list, trend_ll, true_trend]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, obs, est, truth} ->
    [
      %{"t" => i, "value" => obs, "series" => "Data"},
      %{"t" => i, "value" => est, "series" => "Local Level Trend"},
      %{"t" => i, "value" => truth, "series" => "True Trend"}
    ]
  end)

Vl.new(width: 700, height: 300, title: "Model 1: Local Level (constant noise assumed)")
|> Vl.data_from_values(ll_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)
Observation: The local level trend is too smooth during the quiet periods
(because sigma_obs is inflated by the noisy middle section) and not smooth
enough during the noisy period (because it’s using the same sigma everywhere).
The model doesn’t know that volatility changed.
Model 2: Stochastic Volatility
Now let the observation noise vary over time. Two latent states: trend AND log-volatility.
trend_{t+1} = trend_t + η_t, η ~ N(0, σ_trend²)
h_{t+1} = h_t + γ_t, γ ~ N(0, σ_h²)
y_t = trend_t + exp(h_t / 2) · ε_t
# Model 2 graph: two latent n-vectors (trend and log-volatility), each a
# Gaussian random walk with its own scale parameter.
ir_sv = Builder.new_ir()
|> Builder.data(y)
|> Builder.rv("sigma_trend", HalfNormal, %{sigma: SSM.t(2.0)})
# Tighter prior on sigma_h: log-volatility typically moves slowly.
|> Builder.rv("sigma_h", HalfNormal, %{sigma: SSM.t(1.0)})
|> Builder.rv("trend", GaussianRandomWalk, %{sigma: "sigma_trend"}, shape: {n})
|> Builder.rv("log_vol", GaussianRandomWalk, %{sigma: "sigma_h"}, shape: {n})
# Observation likelihood: y_t ~ N(trend_t, exp(h_t / 2)) with h_t = log_vol_t.
sv_logpdf = fn _x, params ->
  obs = params.__obs_data
  sigma_t = Nx.exp(Nx.divide(params.log_vol, SSM.t(2.0)))
  # Floor the scale so the log-density stays finite if h_t underflows.
  sigma_t = Nx.max(sigma_t, SSM.t(1.0e-6))
  SSM.normal_logpdf_vec(obs, params.trend, sigma_t)
end
dist_sv = Custom.new(sv_logpdf, support: :real)
# Register the custom density and mark it observed. NOTE(review): "__obs_data"
# appears to resolve to the tensor from Builder.data/2, and the 0.0 is
# presumably a placeholder observation — confirm against the eXMC API.
ir_sv = Custom.rv(ir_sv, "lik", dist_sv, %{
  trend: "trend", log_vol: "log_vol", __obs_data: "__obs_data"
})
|> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))
# Sample Model 2. More warmup than Model 1: the posterior has 2n + 2 dimensions.
{trace_sv, stats_sv} =
  Sampler.sample(
    ir_sv,
    %{
      "sigma_trend" => 0.3,
      "sigma_h" => 0.2,
      "trend" => y,
      "log_vol" => Nx.broadcast(Nx.tensor(0.0, type: :f64), {n})
    },
    num_warmup: 800,
    num_samples: 500
  )

# Posterior mean of a traced quantity, rounded to 3 places for display.
mean3 = fn name -> trace_sv[name] |> Nx.mean() |> Nx.to_number() |> Float.round(3) end

IO.puts("Stochastic Volatility:")
IO.puts(" sigma_trend = #{mean3.("sigma_trend")} (true: 0.3)")
IO.puts(" sigma_h = #{mean3.("sigma_h")}")
IO.puts(" divergences = #{stats_sv.divergences}")

# Posterior-mean trend and volatility path sigma_t = exp(h_t / 2).
trend_sv = trace_sv["trend"] |> Nx.mean(axes: [0]) |> Nx.to_flat_list()

vol_sv =
  trace_sv["log_vol"]
  |> Nx.mean(axes: [0])
  |> Nx.to_flat_list()
  |> Enum.map(&:math.exp(&1 / 2))
# Trend comparison: data, SV posterior-mean trend, and the ground truth.
trend_plot =
  [1..n, y_list, trend_sv, true_trend]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, obs, est, truth} ->
    [
      %{"t" => i, "value" => obs, "series" => "Data"},
      %{"t" => i, "value" => est, "series" => "SV Trend"},
      %{"t" => i, "value" => truth, "series" => "True Trend"}
    ]
  end)

Vl.new(width: 700, height: 300, title: "Model 2: Stochastic Volatility — Trend")
|> Vl.data_from_values(trend_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)

# Volatility recovery: estimated sigma(t) against the true step function.
vol_plot =
  [1..n, vol_sv, true_vol]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, est, truth} ->
    [
      %{"t" => i, "value" => est, "series" => "Estimated σ(t)"},
      %{"t" => i, "value" => truth, "series" => "True σ(t)"}
    ]
  end)

Vl.new(width: 700, height: 200, title: "Model 2: Recovered Volatility exp(h_t/2)")
|> Vl.data_from_values(vol_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative, title: "σ(t)")
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)
Key result: The model discovers the volatility spike (t = 51–79) without being told when it occurred. The SV trend is smoother during quiet periods and tracks the data more closely during noisy periods — because it knows the noise level at each time point.
Model 3: AR(1) Process
Autoregressive: the current value depends on the previous value with
coefficient phi.
y_t = phi · y_{t-1} + ε_t, ε ~ N(0, σ²)
# Simulate an AR(1) series: y_t = phi * y_{t-1} + eps_t, starting from 0.
:rand.seed(:exsss, {100, 101, 102})
true_phi = 0.8
true_sigma_ar = 1.0

ar_data =
  Enum.scan(1..200, 0.0, fn _t, prev ->
    true_phi * prev + true_sigma_ar * :rand.normal()
  end)

y_ar = Nx.tensor(ar_data, type: :f64)
IO.puts("AR(1) data: phi=#{true_phi}, sigma=#{true_sigma_ar}, n=#{length(ar_data)}")
# Model 3 graph: only two scalar parameters, no latent state vector.
ir_ar = Builder.new_ir()
|> Builder.data(y_ar)
# NOTE(review): a Normal(0, 1) prior on phi puts mass outside (-1, 1), so the
# prior alone does not enforce stationarity; here the data keep phi in range.
|> Builder.rv("phi", Normal, %{mu: SSM.t(0.0), sigma: SSM.t(1.0)})
|> Builder.rv("sigma", HalfNormal, %{sigma: SSM.t(2.0)})
# Conditional likelihood: y_t | y_{t-1} ~ N(phi * y_{t-1}, sigma) for t >= 2.
# The first observation is conditioned on, not modeled.
ar1_logpdf = fn _x, params ->
  obs = params.__obs_data
  phi = params.phi
  # Floor sigma to keep the log-density finite if it underflows toward 0.
  sigma = Nx.max(params.sigma, SSM.t(1.0e-6))
  n_obs = Nx.axis_size(obs, 0)
  # Lagged and current slices: obs[0..n-2] predicts obs[1..n-1].
  y_prev = obs[0..(n_obs - 2)]
  y_curr = obs[1..(n_obs - 1)]
  mu = Nx.multiply(phi, y_prev)
  SSM.normal_logpdf_vec(y_curr, mu, sigma)
end
dist_ar = Custom.new(ar1_logpdf, support: :real)
# Same custom-likelihood wiring as Models 1 and 2; the 0.0 is presumably a
# placeholder observation value — verify against the eXMC API.
ir_ar = Custom.rv(ir_ar, "lik", dist_ar, %{
  phi: "phi", sigma: "sigma", __obs_data: "__obs_data"
})
|> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))
# Sample the AR(1) model — a 2-parameter posterior, so this is fast.
{trace_ar, stats_ar} =
  Sampler.sample(
    ir_ar,
    %{"phi" => 0.5, "sigma" => 1.0},
    num_warmup: 500,
    num_samples: 500
  )

# Posterior means (reused by the model-comparison cell below).
phi_mean = trace_ar["phi"] |> Nx.mean() |> Nx.to_number()
sigma_mean = trace_ar["sigma"] |> Nx.mean() |> Nx.to_number()

IO.puts("AR(1) posterior:")
IO.puts(" phi = #{Float.round(phi_mean, 3)} (true: #{true_phi})")
IO.puts(" sigma = #{Float.round(sigma_mean, 3)} (true: #{true_sigma_ar})")
IO.puts(" divergences = #{stats_ar.divergences}")
Note: The AR(1) model has only 2 free parameters (phi, sigma) — no latent state vector. NUTS samples this in under a second. The posterior of phi tells you how persistent the process is.
Comparing the Models
# Side-by-side posterior summaries for the three models.
# Shared helper: posterior mean of a traced parameter, rounded to 3 places.
pm3 = fn trace, name -> trace[name] |> Nx.mean() |> Nx.to_number() |> Float.round(3) end

IO.puts("=== Model Comparison ===")
IO.puts("")
IO.puts("Model 1 (Local Level):")
IO.puts(" Parameters: sigma_trend + sigma_obs + trend[1:#{n}] = #{n + 2}")
IO.puts(" sigma_trend = #{pm3.(trace_ll, "sigma_trend")}")
IO.puts(" sigma_obs = #{pm3.(trace_ll, "sigma_obs")}")
IO.puts(" div = #{stats_ll.divergences}")
IO.puts("")
IO.puts("Model 2 (Stochastic Volatility):")
IO.puts(" Parameters: sigma_trend + sigma_h + trend[1:#{n}] + log_vol[1:#{n}] = #{2*n + 2}")
IO.puts(" sigma_trend = #{pm3.(trace_sv, "sigma_trend")}")
IO.puts(" sigma_h = #{pm3.(trace_sv, "sigma_h")}")
IO.puts(" div = #{stats_sv.divergences}")
IO.puts("")
IO.puts("Model 3 (AR(1)):")
IO.puts(" Parameters: phi + sigma = 2")
IO.puts(" phi = #{Float.round(phi_mean, 3)}")
IO.puts(" sigma = #{Float.round(sigma_mean, 3)}")
IO.puts(" div = #{stats_ar.divergences}")
Summary
| Model | State equation | When to use |
|---|---|---|
| Local level | x_{t+1} = x_t + η | Smooth trend extraction |
| Stochastic vol | x_t and h_t (two GRWs) | Financial data, crisis detection |
| AR(1) | y_t = φ·y_{t-1} + ε | Persistent but stationary processes |
| Regime-switching | Mixture of Normals | Market regimes, structural breaks |
The eXMC pattern is always the same: Builder.data for observations,
GaussianRandomWalk for latent states, Custom for the likelihood,
Builder.obs to mark it observed. NUTS handles the rest.
References
- Harvey, A.C. (1989). Forecasting, Structural Time Series Models and the Kalman Filter. Cambridge University Press.
- Kim, Shephard & Chib (1998). “Stochastic Volatility.” Review of Economic Studies.
- Durbin & Koopman (2012). Time Series Analysis by State Space Methods. OUP.