Getting Started with eXMC
Section
# Pick the Nx compiler backend for this machine:
#   * EMLX (Metal GPU) on macOS running on Apple Silicon (arm64/aarch64)
#   * EXLA (CPU/CUDA) everywhere else, including Intel Macs
# :os.type/0 alone reports {:unix, :darwin} on Intel Macs too, so the CPU
# architecture reported by the running VM is checked as well. Depending on the
# config.guess used to build Erlang, Apple Silicon reports "aarch64-apple-darwin..."
# or "arm...-apple-darwin...", so both prefixes are accepted.
# NOTE(review): if Exmc does not select the Nx default backend itself, consider
# passing `config: [nx: [default_backend: ...]]` to Mix.install/2 — confirm.
arch = :erlang.system_info(:system_architecture) |> List.to_string()

backend_dep =
  case {:os.type(), arch} do
    {{:unix, :darwin}, "aarch64" <> _} -> {:emlx, "~> 0.2"}
    {{:unix, :darwin}, "arm" <> _} -> {:emlx, "~> 0.2"}
    _ -> {:exla, "~> 0.10"}
  end

Mix.install([
  # Local checkout of eXMC, one directory above this notebook.
  {:exmc, path: Path.expand("../", __DIR__)},
  backend_dep,
  {:kino_vega_lite, "~> 0.1"}
])
Building a Simple Model
eXMC defines probabilistic models through an IR (intermediate representation).
Start with Builder.new_ir(), add random variables with Builder.rv, and attach observed data with Builder.obs.
alias Exmc.{Builder, Sampler}

# Model: mu ~ Normal(0, 10), and y ~ Normal(mu, 1) with its scale fixed at 1.
# The string "mu" in y's parameters refers to the random variable declared above.
# "y_obs" conditions y on the five observed values.
ir =
  Builder.new_ir()
  |> Builder.rv("mu", Exmc.Dist.Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(10.0)})
  |> Builder.rv("y", Exmc.Dist.Normal, %{mu: "mu", sigma: Nx.tensor(1.0)})
  |> Builder.obs("y_obs", "y", Nx.tensor([2.1, 2.5, 1.8, 2.3, 2.7]))

ir
Running NUTS Sampling
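Sampler.sample/3 runs the No-U-Turn Sampler (NUTS). It takes the model IR, a map of initial values, and keyword options. It returns a {trace, stats} tuple, where trace maps each variable name to its samples: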
# 200 warmup iterations followed by 500 samples, seeded with 42.
{trace, stats} =
  Sampler.sample(ir, %{"mu" => 2.0},
    num_samples: 500,
    seed: 42,
    num_warmup: 200
  )

trace |> Map.keys() |> IO.inspect(label: "Variables")
trace["mu"] |> Nx.shape() |> IO.inspect(label: "Shape")
Diagnostics
Summarize the draws with Exmc.Diagnostics.summary/1. Here we print each variable's mean and standard deviation:
# Print each summarized variable's mean and standard deviation, rounded to 3 decimals.
summary = Exmc.Diagnostics.summary(trace)

for {name, var_stats} <- summary do
  mean = Float.round(var_stats.mean, 3)
  std = Float.round(var_stats.std, 3)
  IO.puts("#{name}: mean=#{mean} std=#{std}")
end
Visualization with VegaLite
Plot the sampled values of mu in iteration order (a trace plot), then a histogram of the same samples:
alias VegaLite, as: Vl

# Nx.to_flat_list/1 returns non-finite values as atoms (:nan, :infinity,
# :neg_infinity). Those cannot be plotted, so they are filtered out.
mu_draws = Nx.to_flat_list(trace["mu"])
mu_samples = Enum.filter(mu_draws, &is_number/1)

# Index before filtering so "iteration" stays each draw's true position in the
# chain. Numbering after filtering would shift every later draw left whenever a
# non-finite value was dropped.
data =
  for {val, i} <- Enum.with_index(mu_draws), is_number(val) do
    %{"iteration" => i, "mu" => val}
  end
# Trace plot: a line of mu values ordered by iteration index.
Vl.new(title: "Trace: mu", width: 600, height: 200)
|> Vl.data_from_values(data)
|> Vl.mark(:line)
|> Vl.encode_field(:y, "mu", type: :quantitative)
|> Vl.encode_field(:x, "iteration", type: :quantitative)
# Histogram of the mu samples: up to 30 bins on x, sample count per bin on y.
Vl.new(title: "Posterior: mu", width: 600, height: 200)
|> Vl.data_from_values(data)
|> Vl.mark(:bar)
|> Vl.encode(:y, aggregate: :count)
|> Vl.encode_field(:x, "mu", type: :quantitative, bin: [maxbins: 30])
A Model with Multiple Free Parameters
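Now the observation scale sigma is unknown as well. It gets a HalfNormal prior and is sampled with a log transform (transform: :log). With both mu and sigma free, NUTS explores their joint posterior: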
# Two free variables: mu ~ Normal(0, 10) and sigma ~ HalfNormal(2).
# sigma is declared with transform: :log. y's mean and scale both refer to
# those variables by name, and "y_obs" conditions y on the observed values.
ir2 =
  Builder.new_ir()
  |> Builder.rv("mu", Exmc.Dist.Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(10.0)})
  |> Builder.rv("sigma", Exmc.Dist.HalfNormal, %{sigma: Nx.tensor(2.0)}, transform: :log)
  |> Builder.rv("y", Exmc.Dist.Normal, %{mu: "mu", sigma: "sigma"})
  |> Builder.obs("y_obs", "y", Nx.tensor([2.1, 2.5, 1.8, 2.3, 2.7]))
# Initial values for both free variables. Then 300 warmup iterations and
# 500 samples, seeded with 123.
{trace2, stats2} =
  Sampler.sample(
    ir2,
    %{"mu" => 2.0, "sigma" => 1.0},
    num_samples: 500,
    seed: 123,
    num_warmup: 300
  )
# Print mean and standard deviation (3 decimals) for each summarized variable.
summary2 = Exmc.Diagnostics.summary(trace2)

for {name, var_summary} <- summary2 do
  rounded_mean = Float.round(var_summary.mean, 3)
  rounded_std = Float.round(var_summary.std, 3)
  IO.puts("#{name}: mean=#{rounded_mean} std=#{rounded_std}")
end
# Nx.to_flat_list/1 returns non-finite values as atoms (:nan, :infinity,
# :neg_infinity). mu2 and sigma2 keep only the finite samples of each variable.
mu2_draws = Nx.to_flat_list(trace2["mu"])
sigma2_draws = Nx.to_flat_list(trace2["sigma"])
mu2 = Enum.filter(mu2_draws, &is_number/1)
sigma2 = Enum.filter(sigma2_draws, &is_number/1)

# Pair the draws BEFORE filtering. Zipping the separately filtered lists would
# misalign them after the first dropped value, pairing mu and sigma from
# different draws. Instead, drop a pair only if either value is non-finite.
scatter_data =
  for {m, s} <- Enum.zip(mu2_draws, sigma2_draws), is_number(m) and is_number(s) do
    %{"mu" => m, "sigma" => s}
  end
# Scatter of paired (mu, sigma) draws. Small, translucent points let
# overlapping regions show up as darker areas.
Vl.new(title: "Joint posterior: mu vs sigma", width: 400, height: 400)
|> Vl.data_from_values(scatter_data)
|> Vl.mark(:point, size: 10, opacity: 0.3)
|> Vl.encode_field(:y, "sigma", type: :quantitative)
|> Vl.encode_field(:x, "mu", type: :quantitative)