Powered by AppSignal & Oban Pro

State-Space Models with eXMC

notebooks/17_state_space.livemd

State-Space Models with eXMC

A Practical Guide: From Random Walk to Stochastic Volatility

# Dependencies for this Livebook. eXMC is loaded from the local checkout one
# directory up; EXLA supplies the compiled Nx backend used by the sampler.
Mix.install([
  {:exmc, path: Path.expand("../", __DIR__)},
  {:exla, "~> 0.10"},
  # NOTE(review): :jose (JSON Object Signing library) is never referenced in
  # this notebook — confirm whether it can be dropped.
  {:jose, "~> 1.11"},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"}
])

The Recipe

Every state-space model in eXMC follows four steps:

  1. Register data — Builder.data(ir, y_tensor)
  2. Priors on parameters — Builder.rv("sigma", HalfNormal, ...)
  3. Latent state process — Builder.rv("state", GaussianRandomWalk, ...)
  4. Observation likelihood — Custom.rv(ir, "lik", ...) + Builder.obs(...)

NUTS samples the entire state trajectory jointly with the parameters. No Kalman filter needed — the gradient does the work.

Helpers

alias Exmc.{Builder, Sampler}
alias Exmc.Dist.{Normal, HalfNormal, GaussianRandomWalk, Custom}
alias VegaLite, as: Vl

defmodule SSM do
  @moduledoc false

  # Shorthand: build a 64-bit float tensor from any number/list.
  def t(v), do: Nx.tensor(v, type: :f64)

  # Summed Gaussian log-density of `x` under Normal(mu, sigma), elementwise
  # over the vector. The constant -0.5*log(2*pi) term is omitted — MCMC only
  # needs the log-density up to an additive constant.
  def normal_logpdf_vec(x, mu, sigma) do
    resid = Nx.subtract(x, mu)
    z = Nx.divide(resid, sigma)
    quad = Nx.multiply(t(-0.5), Nx.multiply(z, z))

    quad
    |> Nx.subtract(Nx.log(sigma))
    |> Nx.sum()
  end
end

Synthetic Data

We generate data from a known process so we can verify the model recovers the truth. The process: a slowly wandering trend plus observation noise, with a volatility spike in the middle.

# Number of time points in the synthetic series.
n = 150
:rand.seed(:exsss, {42, 43, 44})

# True trend: random walk started at 5.0 with innovation sd 0.3.
# (RNG draw order matters for reproducibility: n trend draws first,
# then n observation-noise draws below.)
true_trend =
  Enum.scan(1..n, 5.0, fn _step, prev -> prev + :rand.normal() * 0.3 end)

# True volatility: baseline 0.5, spiking to 2.0 for t in 51..79.
true_vol =
  Enum.map(1..n, fn i ->
    if i > 50 and i < 80, do: 2.0, else: 0.5
  end)

# Observations: trend plus heteroskedastic Gaussian noise.
y_list =
  true_trend
  |> Enum.zip(true_vol)
  |> Enum.map(fn {trend, vol} -> trend + :rand.normal() * vol end)

y = Nx.tensor(y_list, type: :f64)

IO.puts("#{n} observations generated. True sigma_trend=0.3, sigma_obs varies 0.5-2.0")
# Plot: data + true trend + true volatility
# Plot: data + true trend + true volatility.
# Long-format rows for VegaLite: one row per (time, series) pair.
data_points =
  Enum.zip([1..n, y_list, true_trend, true_vol])
  |> Enum.flat_map(fn {i, obs, trend, vol} ->
    [%{"t" => i, "value" => obs, "series" => "Observed"},
     %{"t" => i, "value" => trend, "series" => "True Trend"},
     %{"t" => i, "value" => vol, "series" => "True Volatility"}]
  end)

Vl.new(width: 700, height: 300, title: "Synthetic Data: Trend + Variable Noise")
# BUG FIX: the capture syntax here was HTML-entity-escaped ("&amp; &amp;1")
# in the original, which is not valid Elixir. Restored to a plain capture.
# The volatility series is excluded from this chart (plotted separately later).
|> Vl.data_from_values(Enum.reject(data_points, &(&1["series"] == "True Volatility")))
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)

Model 1: Local Level (Constant Volatility)

The simplest state-space model. Assumes observation noise is constant.

trend_{t+1} = trend_t + η_t,    η ~ N(0, σ_trend²)
y_t = trend_t + ε_t,            ε ~ N(0, σ_obs²)
# Model 1: local level — latent trend random walk with *constant*
# observation noise sigma_obs.
ir_ll =
  Builder.new_ir()
  |> Builder.data(y)
  |> Builder.rv("sigma_trend", HalfNormal, %{sigma: SSM.t(2.0)})
  |> Builder.rv("sigma_obs", HalfNormal, %{sigma: SSM.t(2.0)})
  |> Builder.rv("trend", GaussianRandomWalk, %{sigma: "sigma_trend"}, shape: {n})

# Observation likelihood: y_t ~ Normal(trend_t, sigma_obs), summed over t.
ll_logpdf = fn _x, p ->
  SSM.normal_logpdf_vec(p.__obs_data, p.trend, p.sigma_obs)
end

dist_ll = Custom.new(ll_logpdf, support: :real)

ir_ll =
  ir_ll
  |> Custom.rv("lik", dist_ll, %{
    trend: "trend", sigma_obs: "sigma_obs", __obs_data: "__obs_data"
  })
  |> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))

# Initialize the latent trend at the data itself — a sensible starting point
# that shortens warmup.
{trace_ll, stats_ll} =
  Sampler.sample(
    ir_ll,
    %{"sigma_trend" => 0.5, "sigma_obs" => 1.0, "trend" => y},
    num_warmup: 500, num_samples: 500
  )

# Posterior means of the two scale parameters.
st = trace_ll["sigma_trend"] |> Nx.mean() |> Nx.to_number()
so = trace_ll["sigma_obs"] |> Nx.mean() |> Nx.to_number()

IO.puts("Local Level: sigma_trend=#{Float.round(st, 3)} (true: 0.3), sigma_obs=#{Float.round(so, 3)} (true: 0.5-2.0)")
IO.puts("Divergences: #{stats_ll.divergences}")

# Posterior mean of the latent trend at each time point.
trend_ll = Nx.to_flat_list(Nx.mean(trace_ll["trend"], axes: [0]))

# Long-format rows: data, estimated trend, and true trend per time point.
ll_plot =
  [1..n, y_list, trend_ll, true_trend]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, obs, est, truth} ->
    for {series, v} <- [{"Data", obs}, {"Local Level Trend", est}, {"True Trend", truth}] do
      %{"t" => i, "value" => v, "series" => series}
    end
  end)

Vl.new(width: 700, height: 300, title: "Model 1: Local Level (constant noise assumed)")
|> Vl.data_from_values(ll_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)

Observation: The local level trend is too smooth during the quiet periods (because sigma_obs is inflated by the noisy middle section) and not smooth enough during the noisy period (because it’s using the same sigma everywhere). The model doesn’t know that volatility changed.

Model 2: Stochastic Volatility

Now let the observation noise vary over time. Two latent states: trend AND log-volatility.

trend_{t+1} = trend_t + η_t,             η ~ N(0, σ_trend²)
h_{t+1} = h_t + γ_t,                     γ ~ N(0, σ_h²)
y_t = trend_t + exp(h_t / 2) · ε_t
# Model 2: two latent random walks — the trend AND the log-volatility h_t.
ir_sv =
  Builder.new_ir()
  |> Builder.data(y)
  |> Builder.rv("sigma_trend", HalfNormal, %{sigma: SSM.t(2.0)})
  |> Builder.rv("sigma_h", HalfNormal, %{sigma: SSM.t(1.0)})
  |> Builder.rv("trend", GaussianRandomWalk, %{sigma: "sigma_trend"}, shape: {n})
  |> Builder.rv("log_vol", GaussianRandomWalk, %{sigma: "sigma_h"}, shape: {n})

# Likelihood with time-varying sd: sigma_t = exp(h_t / 2), floored at 1e-6
# to keep the log-density finite if h_t drifts very negative.
sv_logpdf = fn _x, p ->
  raw_sigma = Nx.exp(Nx.divide(p.log_vol, SSM.t(2.0)))
  sigma_t = Nx.max(raw_sigma, SSM.t(1.0e-6))
  SSM.normal_logpdf_vec(p.__obs_data, p.trend, sigma_t)
end

dist_sv = Custom.new(sv_logpdf, support: :real)

ir_sv =
  ir_sv
  |> Custom.rv("lik", dist_sv, %{
    trend: "trend", log_vol: "log_vol", __obs_data: "__obs_data"
  })
  |> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))

# Start the trend at the data and log-volatility flat at zero. Warmup is
# longer than Model 1's because the posterior has 2n + 2 dimensions.
init_sv = %{
  "sigma_trend" => 0.3,
  "sigma_h" => 0.2,
  "trend" => y,
  "log_vol" => Nx.broadcast(Nx.tensor(0.0, type: :f64), {n})
}

{trace_sv, stats_sv} = Sampler.sample(ir_sv, init_sv, num_warmup: 800, num_samples: 500)

# Posterior means of the two innovation scales, rounded for display.
sv_sigma_trend = trace_sv["sigma_trend"] |> Nx.mean() |> Nx.to_number() |> Float.round(3)
sv_sigma_h = trace_sv["sigma_h"] |> Nx.mean() |> Nx.to_number() |> Float.round(3)

IO.puts("Stochastic Volatility:")
IO.puts("  sigma_trend = #{sv_sigma_trend} (true: 0.3)")
IO.puts("  sigma_h = #{sv_sigma_h}")
IO.puts("  divergences = #{stats_sv.divergences}")

# Pointwise posterior means of both latent states; map h_t back to the
# sd scale via exp(h/2).
trend_sv = Nx.to_flat_list(Nx.mean(trace_sv["trend"], axes: [0]))
log_vol_mean = Nx.to_flat_list(Nx.mean(trace_sv["log_vol"], axes: [0]))
vol_sv = for h <- log_vol_mean, do: :math.exp(h / 2)

# Trend comparison
# Trend comparison
trend_plot =
  [1..n, y_list, trend_sv, true_trend]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, obs, est, truth} ->
    for {series, v} <- [{"Data", obs}, {"SV Trend", est}, {"True Trend", truth}] do
      %{"t" => i, "value" => v, "series" => series}
    end
  end)

Vl.new(width: 700, height: 300, title: "Model 2: Stochastic Volatility — Trend")
|> Vl.data_from_values(trend_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)
# Volatility recovery
# Volatility recovery: estimated sigma_t vs. the true piecewise-constant one.
vol_plot =
  [1..n, vol_sv, true_vol]
  |> Enum.zip()
  |> Enum.flat_map(fn {i, est, truth} ->
    [%{"t" => i, "value" => est, "series" => "Estimated σ(t)"},
     %{"t" => i, "value" => truth, "series" => "True σ(t)"}]
  end)

Vl.new(width: 700, height: 200, title: "Model 2: Recovered Volatility exp(h_t/2)")
|> Vl.data_from_values(vol_plot)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "t", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative, title: "σ(t)")
|> Vl.encode_field(:color, "series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "series", type: :nominal)

Key result: The model discovers the volatility spike (t=50-80) without being told when it occurred. The SV trend is smoother during quiet periods and tracks the data more closely during noisy periods — because it knows the noise level at each time point.

Model 3: AR(1) Process

Autoregressive: the current value depends on the previous value with coefficient phi.

y_t = phi · y_{t-1} + ε_t,    ε ~ N(0, σ²)
# Generate AR(1) data: y_t = phi * y_{t-1} + eps_t, started at 0.
:rand.seed(:exsss, {100, 101, 102})
true_phi = 0.8
true_sigma_ar = 1.0

ar_data =
  Enum.scan(1..200, 0.0, fn _step, prev ->
    true_phi * prev + :rand.normal() * true_sigma_ar
  end)

y_ar = Nx.tensor(ar_data, type: :f64)

IO.puts("AR(1) data: phi=#{true_phi}, sigma=#{true_sigma_ar}, n=#{length(ar_data)}")
# Model 3: AR(1) — only two free parameters, no latent state vector.
ir_ar =
  Builder.new_ir()
  |> Builder.data(y_ar)
  |> Builder.rv("phi", Normal, %{mu: SSM.t(0.0), sigma: SSM.t(1.0)})
  |> Builder.rv("sigma", HalfNormal, %{sigma: SSM.t(2.0)})

# Conditional likelihood: y_t | y_{t-1} ~ Normal(phi * y_{t-1}, sigma).
# The first observation is conditioned on (no stationary-distribution term).
ar1_logpdf = fn _x, p ->
  obs = p.__obs_data
  sigma = Nx.max(p.sigma, SSM.t(1.0e-6))
  len = Nx.axis_size(obs, 0)

  lagged = obs[0..(len - 2)]
  current = obs[1..(len - 1)]
  SSM.normal_logpdf_vec(current, Nx.multiply(p.phi, lagged), sigma)
end

dist_ar = Custom.new(ar1_logpdf, support: :real)

ir_ar =
  ir_ar
  |> Custom.rv("lik", dist_ar, %{
    phi: "phi", sigma: "sigma", __obs_data: "__obs_data"
  })
  |> Builder.obs("lik_obs", "lik", Nx.tensor(0.0, type: :f64))

# Only 2 parameters to sample, so this is fast.
{trace_ar, stats_ar} =
  Sampler.sample(ir_ar, %{"phi" => 0.5, "sigma" => 1.0}, num_warmup: 500, num_samples: 500)

# Posterior means for the AR coefficient and innovation scale.
phi_mean = trace_ar["phi"] |> Nx.mean() |> Nx.to_number()
sigma_mean = trace_ar["sigma"] |> Nx.mean() |> Nx.to_number()

IO.puts("AR(1) posterior:")
IO.puts("  phi = #{Float.round(phi_mean, 3)} (true: #{true_phi})")
IO.puts("  sigma = #{Float.round(sigma_mean, 3)} (true: #{true_sigma_ar})")
IO.puts("  divergences = #{stats_ar.divergences}")

Note: The AR(1) model has only 2 free parameters (phi, sigma) — no latent state vector. NUTS samples this in under a second. The posterior of phi tells you how persistent the process is.

Comparing the Models

# Side-by-side posterior summaries for the three models. Parameter counts
# include the latent state vectors sampled by NUTS.
IO.puts("=== Model Comparison ===")
IO.puts("")
IO.puts("Model 1 (Local Level):")
IO.puts("  Parameters: sigma_trend + sigma_obs + trend[1:#{n}] = #{n + 2}")
IO.puts("  sigma_trend = #{Nx.mean(trace_ll["sigma_trend"]) |> Nx.to_number() |> Float.round(3)}")
IO.puts("  sigma_obs   = #{Nx.mean(trace_ll["sigma_obs"]) |> Nx.to_number() |> Float.round(3)}")
IO.puts("  div = #{stats_ll.divergences}")
IO.puts("")
IO.puts("Model 2 (Stochastic Volatility):")
IO.puts("  Parameters: sigma_trend + sigma_h + trend[1:#{n}] + log_vol[1:#{n}] = #{2*n + 2}")
IO.puts("  sigma_trend = #{Nx.mean(trace_sv["sigma_trend"]) |> Nx.to_number() |> Float.round(3)}")
IO.puts("  sigma_h     = #{Nx.mean(trace_sv["sigma_h"]) |> Nx.to_number() |> Float.round(3)}")
IO.puts("  div = #{stats_sv.divergences}")
IO.puts("")
IO.puts("Model 3 (AR(1)):")
IO.puts("  Parameters: phi + sigma = 2")
IO.puts("  phi   = #{phi_mean |> Float.round(3)}")
IO.puts("  sigma = #{sigma_mean |> Float.round(3)}")
IO.puts("  div = #{stats_ar.divergences}")

Summary

Model State equation When to use
Local level x_{t+1} = x_t + η Smooth trend extraction
Stochastic vol x_t + h_t (two GRWs) Financial data, crisis detection
AR(1) y_t = φ·y_{t-1} + ε Persistent but stationary processes
Regime-switching Mixture of Normals Market regimes, structural breaks

The eXMC pattern is always the same: Builder.data for observations, GaussianRandomWalk for latent states, Custom for the likelihood, Builder.obs to mark it observed. NUTS handles the rest.

References

  • Harvey, A.C. (1989). Forecasting, Structural Time Series Models and the Kalman Filter. Cambridge University Press.
  • Kim, Shephard & Chib (1998). “Stochastic Volatility.” Review of Economic Studies.
  • Durbin & Koopman (2012). Time Series Analysis by State Space Methods. OUP.