Vulkan Backend Demo — eXMC NUTS without CUDA

Mix.install([
  {:exmc, path: Path.expand("../", __DIR__)},
  {:nx, "~> 0.10"},
  # Match exmc's mix.exs: same source, same branch.
  # For local nx_vulkan iteration, swap both this AND exmc/mix.exs to
  # path: deps pointing at your checkout.
  {:nx_vulkan,
   github: "borodark/nx_vulkan",
   branch: "main",
   optional: true},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"}
])

What this notebook shows

eXMC has historically depended on EXLA for fast NUTS sampling on the GPU, and EXLA's GPU path only runs on Linux with NVIDIA CUDA. Nx.Vulkan is a third tensor backend that runs Vulkan compute shaders on any GPU with a Vulkan driver: NVIDIA, AMD, or Intel on Linux, NVIDIA on FreeBSD, and macOS via MoltenVK.

This notebook does three things:

  1. Detects whether nx_vulkan is loaded and what GPU it sees.
  2. Samples a trivial Normal posterior under three backends (EXLA reference, Vulkan with the per-step IR walker, Vulkan with the fused-leapfrog chain shader) and compares both posterior recovery and wall-clock time.
  3. Visualises the posterior densities side by side.

The Vulkan path uses one fused shader that performs K=32 leapfrog steps in a single GPU dispatch, sidestepping the ~500 µs per-dispatch overhead that previously made EXMC_COMPILER=vulkan impractical for NUTS.
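
As a quick sanity check on those numbers, the arithmetic below uses only the rough figures quoted in this notebook (per-dispatch overhead, dispatches per step, K). It is an illustration, not a measurement:

# Rough amortization math from the overhead estimates quoted in this notebook.
dispatch_overhead_us = 500   # approximate fixed cost per Vulkan dispatch
dispatches_per_step = 12     # elementwise dispatches per leapfrog step (IR walker path)
k = 32                       # leapfrog steps fused into one chain-shader dispatch

# Unfused: every leapfrog step pays the dispatch tax on every elementwise op.
unfused_per_step_us = dispatches_per_step * dispatch_overhead_us  # ~6_000 µs

# Fused: one dispatch amortized over K steps. Actual shader execution time
# comes on top, which is what brings the observed figure to roughly 50 µs.
fused_overhead_per_step_us = dispatch_overhead_us / k  # ~16 µs

{unfused_per_step_us, fused_overhead_per_step_us}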

Vulkan detection

# nx_vulkan is an optional dep: treat Vulkan as available only when the
# module actually loaded and exposes init/0.
vulkan_available? =
  Code.ensure_loaded?(Nx.Vulkan) and
    function_exported?(Nx.Vulkan, :init, 0)

if vulkan_available? do
  Nx.Vulkan.init()

  device = Nx.Vulkan.Native.device_name()
  has_f64 = Nx.Vulkan.Native.has_f64()

  Kino.Markdown.new("""
  **Vulkan**: ✅ available
  - Device: `#{device}`
  - f64 supported: `#{has_f64}`
  """)
else
  Kino.Markdown.new("""
  **Vulkan**: ❌ not available

  This notebook needs `nx_vulkan` to compare backends. Either it
  isn't on the load path or the NIF didn't build. Without it, only
  the EXLA cell below is runnable. To enable: ensure `nx_vulkan`
  is buildable in the workspace (Linux/FreeBSD with the Vulkan
  loader; macOS via MoltenVK).
  """)
end

The model

The simplest model that still exercises the sampler: x ~ N(0, 1) with no observations, so the posterior is exactly the prior. Any correct sampler should recover mean ≈ 0 and variance ≈ 1.
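
For reference, this target has a simple closed form. The two anonymous functions below spell out the log-density (up to a constant) and the gradient that NUTS needs; they are purely illustrative, since eXMC derives both quantities from the IR automatically:

# x ~ N(0, 1): log p(x) = -x^2 / 2 + const, so d/dx log p(x) = -x.
# Illustrative only; eXMC derives both from the IR, not from hand-written fns.
logp = fn x -> Nx.multiply(-0.5, Nx.multiply(x, x)) end
grad_logp = fn x -> Nx.negate(x) end

{logp.(Nx.tensor(1.0)), grad_logp.(Nx.tensor(1.0))}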

alias Exmc.{Builder, Dist.Normal, NUTS.Sampler}

# Build the model IR: a single Normal RV named "x".
build_ir = fn ->
  Builder.new_ir()
  |> Builder.rv("x", Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(1.0)})
end

# Run NUTS and summarise: wall-clock time plus the empirical mean and
# variance of the draws for "x".
run = fn ir, label ->
  t0 = System.monotonic_time(:millisecond)
  {trace, _stats} = Sampler.sample(ir, %{}, num_warmup: 200, num_samples: 1000, seed: 42)
  t1 = System.monotonic_time(:millisecond)

  xs = trace["x"] |> Nx.to_flat_list()
  mean = Enum.sum(xs) / length(xs)
  var = Enum.sum(Enum.map(xs, fn x -> (x - mean) * (x - mean) end)) / length(xs)
  {label, t1 - t0, mean, var, xs}
end

EXLA reference (f64)

The fast path on Linux NVIDIA. We use it as the reference for posterior recovery — anything else should match within MCMC noise.

# Clear any compiler override so this cell uses eXMC's default (EXLA),
# even if the Vulkan cells below were executed in an earlier pass.
Application.delete_env(:exmc, :compiler)

exla_result = run.(build_ir.(), "EXLA")

{label, ms, mean, var, _xs} = exla_result

Kino.Markdown.new("""
**#{label}**: wall = #{ms} ms · posterior mean = `#{Float.round(mean, 4)}` ·
posterior variance = `#{Float.round(var, 4)}`
""")

Vulkan, per-step IR walker (the slow path)

Selects Nx.Vulkan as the JIT compiler, but doesn’t enable the fused-leapfrog chain shader. Each leapfrog step expands to ~12 elementwise dispatches via the IR walker, with ~500 µs of per-dispatch overhead. This will take many minutes for even a tiny model — that’s the structural problem the chain shader fixes.

> Skip this cell unless you want to confirm the cost first-hand.

vulkan_unfused_result =
  if vulkan_available? do
    # Vulkan compiler on, fused chain shader explicitly off.
    Application.put_env(:exmc, :compiler, :vulkan)
    Application.delete_env(:exmc, :fused_leapfrog_normal_meta)
    run.(build_ir.(), "Vulkan (unfused — slow)")
  else
    nil
  end

case vulkan_unfused_result do
  nil ->
    Kino.Markdown.new("Vulkan not available — skipped.")

  {label, ms, mean, var, _xs} ->
    Kino.Markdown.new("""
    **#{label}**: wall = #{ms} ms · mean = `#{Float.round(mean, 4)}` ·
    var = `#{Float.round(var, 4)}`
    """)
end

Vulkan, fused leapfrog chain (the win)

Same Vulkan compiler, but fused_leapfrog_normal_meta is set, which routes the eXMC speculative-precomputation path through Nx.Vulkan.leapfrog_chain_normal/7 — one shader dispatch per K=32 leapfrog steps. Per-step amortized cost drops from ~6000 µs to ~50 µs.
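
For intuition, here is a CPU sketch in plain Nx of what a single fused dispatch computes: K leapfrog steps against the Normal potential U(x) = (x - mu)^2 / (2 * sigma^2). The starting state, step size, and arguments below are made up for illustration; the real GPU entry point is Nx.Vulkan.leapfrog_chain_normal/7, whose exact argument list isn't reproduced here.

# Illustrative CPU reference for one fused dispatch: K leapfrog steps
# against a Normal target, where grad U(x) = (x - mu) / sigma^2.
leapfrog_chain_ref = fn x0, p0, {mu, sigma}, eps, k ->
  grad_u = fn x -> Nx.divide(Nx.subtract(x, mu), sigma * sigma) end

  Enum.reduce(1..k, {x0, p0}, fn _step, {x, p} ->
    p_half = Nx.subtract(p, Nx.multiply(eps / 2, grad_u.(x)))            # half-kick
    x_next = Nx.add(x, Nx.multiply(eps, p_half))                         # drift
    p_next = Nx.subtract(p_half, Nx.multiply(eps / 2, grad_u.(x_next)))  # half-kick
    {x_next, p_next}
  end)
end

# Made-up position, momentum, and step size; K = 32 as in the shader.
leapfrog_chain_ref.(Nx.tensor(0.5), Nx.tensor(-0.3), {0.0, 1.0}, 0.1, 32)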

vulkan_fused_result =
  if vulkan_available? do
    Application.put_env(:exmc, :compiler, :vulkan)
    # The (mu, sigma) tuple opts the speculative path into the chain shader.
    # Generalising this to per-IR detection (so any single Normal RV
    # auto-routes) is a follow-up task.
    Application.put_env(:exmc, :fused_leapfrog_normal_meta, {0.0, 1.0})
    res = run.(build_ir.(), "Vulkan (fused chain)")
    Application.delete_env(:exmc, :fused_leapfrog_normal_meta)
    res
  end

case vulkan_fused_result do
  nil ->
    Kino.Markdown.new("Vulkan not available — skipped.")

  {label, ms, mean, var, _xs} ->
    Kino.Markdown.new("""
    **#{label}**: wall = #{ms} ms · mean = `#{Float.round(mean, 4)}` ·
    var = `#{Float.round(var, 4)}`
    """)
end

Side-by-side comparison

results =
  [exla_result, vulkan_unfused_result, vulkan_fused_result]
  |> Enum.reject(&is_nil/1)

table =
  Enum.map_join(results, "\n", fn {label, ms, mean, var, _xs} ->
    "| #{label} | #{ms} | #{Float.round(mean, 4)} | #{Float.round(var, 4)} |"
  end)

Kino.Markdown.new("""
| backend | wall (ms) | mean | var |
|---------|-----------|------|-----|
#{table}

Reference: posterior of `x ~ N(0, 1)` should give `mean ≈ 0`,
`var ≈ 1`. All three backends should agree on the posterior
within MCMC noise; what differs is wall-clock.
""")

Posterior densities (visual check)

Overlay the empirical density of each backend’s draws.

alias VegaLite, as: Vl

points =
  for {label, _ms, _mean, _var, xs} <- results,
      x <- xs,
      do: %{"backend" => label, "x" => x}

Vl.new(width: 700, height: 360, title: "Posterior of x ~ N(0, 1)")
|> Vl.data_from_values(points)
|> Vl.mark(:line)
|> Vl.transform(density: "x", groupby: ["backend"], extent: [-4, 4])
|> Vl.encode_field(:x, "value", type: :quantitative, axis: [title: "x"])
|> Vl.encode_field(:y, "density",
  type: :quantitative,
  axis: [title: "density"]
)
|> Vl.encode_field(:color, "backend", type: :nominal)

What this notebook does NOT show

  • Phase 2 distributions: the chain pattern works for Normal today; Student-t, Cauchy, HalfNormal, Exponential, and an f64 Normal sibling all have shaders shipped in nx_vulkan. eXMC’s do_dispatch only auto-routes Normal as of this notebook’s date — generalising the dispatch detection per distribution is the natural follow-up.
  • Hierarchical models: fused_leapfrog_normal_meta assumes scalar (mu, sigma). Hierarchical models with multiple RVs need per-IR analysis to identify which RVs are eligible for which chain shader.
  • NUTS-heavy benchmark models: the Radon, Reliability, Insurance, etc. notebooks elsewhere in this directory will benefit from the fused chain once their distributions get the dispatch routing — but they currently fall through to the unfused path.

Cross-references

  • ~/projects/learn_erl/nx_vulkan/PLAN_FUSED_LEAPFROG.md — full plan and implementation arc
  • ~/projects/learn_erl/nx_vulkan/RESEARCH_FAST_KERNELS.md — research note on per-dispatch overhead and break-even
  • docs/VULKAN_KNOWN_ISSUES.md (in this repo) — tracked Vulkan-specific failures and their status
  • test/nuts/fused_chain_diag_test.exs (in this repo) — the ExUnit test that pins the variance-comparison baseline
  • The GPU That Doesn’t Need CUDA
  • A Walkable Path Under the Mountain