Vulkan Backend Demo — eXMC NUTS without CUDA
Mix.install([
{:exmc, path: Path.expand("../", __DIR__)},
{:nx, "~> 0.10"},
# Match exmc's mix.exs: same source, same branch.
# For local nx_vulkan iteration, swap both this AND exmc/mix.exs to
# path: deps pointing at your checkout.
{:nx_vulkan,
github: "borodark/nx_vulkan",
branch: "main",
optional: true},
{:kino, "~> 0.14"},
{:kino_vega_lite, "~> 0.1"}
])
What this notebook shows
eXMC has historically depended on EXLA for fast NUTS sampling on the GPU — and EXLA only runs on Linux with NVIDIA CUDA. Nx.Vulkan is a third tensor backend that runs Vulkan compute shaders on any GPU that has a Vulkan driver: Linux NVIDIA, Linux AMD, Linux Intel, FreeBSD with NVIDIA, macOS via MoltenVK.
This notebook does three things:
- Detects whether `nx_vulkan` is loaded and what GPU it sees.
- Samples a trivial Normal-Normal posterior under three backends
  (EXLA reference, Vulkan with the per-step IR walker, Vulkan with
  the fused-leapfrog chain shader) and compares posteriors and
  wall-clock times.
- Visualises the posterior densities side by side.
The Vulkan path uses one fused shader that performs K=32
leapfrog steps in a single GPU dispatch, sidestepping the
~500 µs per-dispatch overhead that previously made
`EXMC_COMPILER=vulkan` impractical for NUTS.
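As a back-of-the-envelope check (a sketch using only the ~500 µs overhead figure and K quoted above, not a measurement), amortizing one dispatch across the fused chain looks like this:

```elixir
# Rough amortization arithmetic using the figures quoted above; these are
# order-of-magnitude estimates, not measurements.
dispatch_overhead_us = 500   # approximate fixed cost per Vulkan dispatch
k = 32                       # leapfrog steps fused into one dispatch

fused_overhead_per_step = dispatch_overhead_us / k

Kino.Markdown.new("""
One dispatch per #{k} leapfrog steps ≈ `#{Float.round(fused_overhead_per_step, 1)} µs`
of dispatch overhead per step, versus ~#{dispatch_overhead_us} µs for *every*
dispatch on the per-step path.
""")
```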
Vulkan detection
vulkan_available? =
Code.ensure_loaded?(Nx.Vulkan) and
function_exported?(Nx.Vulkan, :init, 0)
if vulkan_available? do
Nx.Vulkan.init()
device = Nx.Vulkan.Native.device_name()
has_f64 = Nx.Vulkan.Native.has_f64()
Kino.Markdown.new("""
**Vulkan**: ✅ available
- Device: `#{device}`
- f64 supported: `#{has_f64}`
""")
else
Kino.Markdown.new("""
**Vulkan**: ❌ not available
This notebook needs `nx_vulkan` to compare backends. Either it
isn't on the load path or the NIF didn't build. Without it, only
the EXLA cell below is runnable. To enable: ensure `nx_vulkan`
is buildable in the workspace (Linux/FreeBSD with the Vulkan
loader; macOS via MoltenVK).
""")
end
The model
The simplest non-trivial Bayesian model: x ~ N(0, 1) with no
observations. The posterior IS the prior. We expect mean ≈ 0 and
variance ≈ 1 from any correct sampler.
alias Exmc.{Builder, Dist.Normal, NUTS.Sampler}
build_ir = fn ->
Builder.new_ir()
|> Builder.rv("x", Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(1.0)})
end
run = fn ir, label ->
t0 = System.monotonic_time(:millisecond)
{trace, _stats} = Sampler.sample(ir, %{}, num_warmup: 200, num_samples: 1000, seed: 42)
t1 = System.monotonic_time(:millisecond)
xs = trace["x"] |> Nx.to_flat_list()
mean = Enum.sum(xs) / length(xs)
var = Enum.sum(Enum.map(xs, fn x -> (x - mean) * (x - mean) end)) / length(xs)
{label, t1 - t0, mean, var, xs}
end
EXLA reference (f64)
The fast path on Linux NVIDIA. We use it as the reference for posterior recovery — anything else should match within MCMC noise.
exla_result = run.(build_ir.(), "EXLA")
{label, ms, mean, var, _xs} = exla_result
Kino.Markdown.new("""
**#{label}**: wall = #{ms} ms · posterior mean = `#{Float.round(mean, 4)}` ·
posterior variance = `#{Float.round(var, 4)}`
""")
Vulkan, per-step IR walker (the slow path)
Selects Nx.Vulkan as the JIT compiler, but doesn’t enable the
fused-leapfrog chain shader. Each leapfrog step expands to ~12
elementwise dispatches via the IR walker, with ~500 µs of
per-dispatch overhead. This will take many minutes for even
a tiny model — that’s the structural problem the chain shader
fixes.
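To see why, here is a rough floor on the unfused cost, assuming the ~12 dispatches per step and ~500 µs per dispatch quoted above (and charitably counting only one leapfrog step per NUTS iteration):

```elixir
# A crude lower bound on the unfused path's wall-clock, from the figures above.
dispatches_per_step = 12     # elementwise dispatches per leapfrog step (IR walker)
dispatch_overhead_us = 500   # approximate fixed cost per dispatch
iterations = 200 + 1000      # warmup + samples used in this notebook

per_step_ms = dispatches_per_step * dispatch_overhead_us / 1000

Kino.Markdown.new("""
≈ `#{per_step_ms} ms` of dispatch overhead per leapfrog step. Even at a single
leapfrog step per iteration, #{iterations} iterations cost
≈ `#{Float.round(per_step_ms * iterations / 1000, 1)} s` of overhead alone;
NUTS typically takes many leapfrog steps per iteration, hence the minutes.
""")
```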
> Skip this cell unless you want to confirm the cost first-hand.
vulkan_unfused_result =
if vulkan_available? do
Application.put_env(:exmc, :compiler, :vulkan)
Application.delete_env(:exmc, :fused_leapfrog_normal_meta)
run.(build_ir.(), "Vulkan (unfused — slow)")
else
nil
end
vulkan_unfused_result
|> case do
nil ->
Kino.Markdown.new("Vulkan not available — skipped.")
{label, ms, mean, var, _xs} ->
Kino.Markdown.new("""
**#{label}**: wall = #{ms} ms · mean = `#{Float.round(mean, 4)}` ·
var = `#{Float.round(var, 4)}`
""")
end
Vulkan, fused leapfrog chain (the win)
Same Vulkan compiler, but `fused_leapfrog_normal_meta` is set,
which routes the eXMC speculative-precomputation path through
`Nx.Vulkan.leapfrog_chain_normal/7` — one shader dispatch
per K=32 leapfrog steps. Per-step amortized cost drops from
~6000 µs to ~50 µs.
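For reference, this is the recursion the chain shader replaces: K leapfrog steps on the standard-Normal potential, where logp(x) = -x²/2 + const, so ∇logp(x) = -x (for a general N(mu, sigma) it would be -(x - mu)/sigma²). The sketch below is plain Nx written for this notebook, not the shader's actual implementation.

```elixir
# Reference-only sketch of what one fused dispatch computes: k leapfrog steps
# for the standard-Normal log-density. Written for illustration in this
# notebook; it is NOT the nx_vulkan shader or eXMC's internal code.
leapfrog_chain_reference = fn x0, p0, eps, k ->
  grad_logp = fn x -> Nx.negate(x) end  # d/dx of -x²/2

  Enum.reduce(1..k, {x0, p0}, fn _step, {x, p} ->
    p_half = Nx.add(p, Nx.multiply(eps / 2, grad_logp.(x)))             # half momentum kick
    x_next = Nx.add(x, Nx.multiply(eps, p_half))                        # full position drift
    p_next = Nx.add(p_half, Nx.multiply(eps / 2, grad_logp.(x_next)))   # half momentum kick
    {x_next, p_next}
  end)
end

# One fused "dispatch" worth of work: 32 steps from (x = 0.5, p = 1.0).
leapfrog_chain_reference.(Nx.tensor(0.5), Nx.tensor(1.0), 0.1, 32)
```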
vulkan_fused_result =
if vulkan_available? do
Application.put_env(:exmc, :compiler, :vulkan)
# The (mu, sigma) tuple opts the speculative path into the chain shader.
# Generalising this to per-IR detection (so any single Normal RV
# auto-routes) is a follow-up task.
Application.put_env(:exmc, :fused_leapfrog_normal_meta, {0.0, 1.0})
res = run.(build_ir.(), "Vulkan (fused chain)")
Application.delete_env(:exmc, :fused_leapfrog_normal_meta)
res
end
vulkan_fused_result
|> case do
nil ->
Kino.Markdown.new("Vulkan not available — skipped.")
{label, ms, mean, var, _xs} ->
Kino.Markdown.new("""
**#{label}**: wall = #{ms} ms · mean = `#{Float.round(mean, 4)}` ·
var = `#{Float.round(var, 4)}`
""")
end
Side-by-side comparison
results =
[exla_result, vulkan_unfused_result, vulkan_fused_result]
|> Enum.reject(&is_nil/1)
table =
Enum.map_join(results, "\n", fn {label, ms, mean, var, _xs} ->
"| #{label} | #{ms} | #{Float.round(mean, 4)} | #{Float.round(var, 4)} |"
end)
Kino.Markdown.new("""
| backend | wall (ms) | mean | var |
|---------|-----------|------|-----|
#{table}
Reference: posterior of `x ~ N(0, 1)` should give `mean ≈ 0`,
`var ≈ 1`. All three backends should agree on the posterior
within MCMC noise; what differs is wall-clock.
""")
Posterior densities (visual check)
Overlay the empirical density of each backend’s draws.
alias VegaLite, as: Vl
points =
for {label, _ms, _mean, _var, xs} <- results,
x <- xs,
do: %{"backend" => label, "x" => x}
Vl.new(width: 700, height: 360, title: "Posterior of x ~ N(0, 1)")
|> Vl.data_from_values(points)
|> Vl.mark(:line)
|> Vl.transform(density: "x", groupby: ["backend"], extent: [-4, 4])
|> Vl.encode_field(:x, "value", type: :quantitative, axis: [title: "x"])
|> Vl.encode_field(:y, "density",
type: :quantitative,
axis: [title: "density"]
)
|> Vl.encode_field(:color, "backend", type: :nominal)
What this notebook does NOT show
- Phase 2 distributions: the chain pattern works for Normal today;
  Student-t, Cauchy, HalfNormal, Exponential, and an f64 Normal
  sibling all have shaders shipped in `nx_vulkan`. eXMC's
  `do_dispatch` only auto-routes Normal as of this notebook's date —
  generalising the dispatch detection per distribution is the natural
  follow-up.
- Hierarchical models: `fused_leapfrog_normal_meta` assumes scalar
  `(mu, sigma)`. Hierarchical models with multiple RVs need per-IR
  analysis to identify which RVs are eligible for which chain shader
  (a hypothetical sketch of that check follows this list).
- NUTS-heavy benchmark models: the Radon, Reliability, Insurance,
  etc. notebooks elsewhere in this directory will benefit from the
  fused chain once their distributions get the dispatch routing — but
  they currently fall through to the unfused path.
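Purely as an illustration of what that per-IR eligibility check could look like (everything below is hypothetical: the list-of-RVs shape is invented for this sketch and is not eXMC's actual IR), the idea is to enable the fused meta only when the model is a single scalar Normal:

```elixir
# Hypothetical sketch only. The `rvs` shape below is invented for illustration;
# a real version would walk eXMC's IR, whose structure this notebook does not assume.
maybe_enable_fused_normal = fn rvs ->
  case rvs do
    # Exactly one RV, Normal, with scalar mu/sigma: eligible for the chain shader.
    [{_name, Exmc.Dist.Normal, %{mu: mu, sigma: sigma}}] ->
      if Nx.rank(mu) == 0 and Nx.rank(sigma) == 0 do
        Application.put_env(
          :exmc,
          :fused_leapfrog_normal_meta,
          {Nx.to_number(mu), Nx.to_number(sigma)}
        )

        :fused
      else
        :unfused
      end

    # Anything else (multiple RVs, other distributions, vector params): fall through.
    _ ->
      :unfused
  end
end

maybe_enable_fused_normal.([
  {"x", Exmc.Dist.Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(1.0)}}
])
```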
Cross-references
- `~/projects/learn_erl/nx_vulkan/PLAN_FUSED_LEAPFROG.md` — full plan and implementation arc
- `~/projects/learn_erl/nx_vulkan/RESEARCH_FAST_KERNELS.md` — research note on per-dispatch overhead and break-even
- `docs/VULKAN_KNOWN_ISSUES.md` (in this repo) — tracked Vulkan-specific failures and their status
- `test/nuts/fused_chain_diag_test.exs` (in this repo) — the ExUnit test that pins the variance-comparison baseline
- The GPU That Doesn’t Need CUDA
- A Walkable Path Under the Mountain