Vulkan Compute on the BEAM — Chain Shaders and the GPU Process
Mix.install([
  {:exmc, path: Path.expand("../", __DIR__)},
  {:nx, "~> 0.10"},
  {:nx_vulkan,
   github: "borodark/nx_vulkan",
   branch: "main",
   optional: true},
  {:kino, "~> 0.14"}
])
What this notebook is
The companion to 16_vulkan_demo.livemd. That one shows whether
Vulkan-on-BEAM works (it does — eXMC’s NUTS sampler runs identical
posteriors under EXLA and Vulkan). This one shows how the
machinery is built:
- Nx.Vulkan.Node is a long-lived GenServer. It owns the vkPipelineCache, the persistent buffer registry, and a watchdog. Every Vulkan op the BEAM dispatches goes through it.
- Chain-shader specs are rendered to GLSL at runtime, validated with glslangValidator, and content-addressed under the cache.
- A single dispatch executes K leapfrog steps inside one GPU command buffer, sidestepping the per-dispatch overhead that would otherwise dominate.
The order matters: the process exists so the pipeline cache has a home, and the cache is what makes shader synthesis pay off across calls. If you remove the GenServer, every dispatch re-creates the device’s pipeline state.
We learned this the expensive way: a paper-trading deploy ran for
nine hours with compiler=Nx.Vulkan and default_backend=Nx.Vulkan.Backend
but no Nx.Vulkan.Node in the supervision tree. The
Nx.Vulkan.Backend happily routed ops to a GenServer that didn’t
exist; NUTS calls hung; a 300-second watchdog inside eXMC reset
each instrument in turn. The trader looked alive — broker polls
kept working — and produced zero posterior updates. That’s the
load-bearing GenServer.
Part 1 — The GPU process
Nx.Vulkan.Node is a normal OTP GenServer. You start it like any
other:
unless Code.ensure_loaded?(Nx.Vulkan.Node) and function_exported?(Nx.Vulkan.Node, :start_link, 1) do
  raise "nx_vulkan not loaded — see the Vulkan detection cell in 16_vulkan_demo.livemd"
end

{:ok, vnode} =
  case Process.whereis(Nx.Vulkan.Node) do
    nil -> Nx.Vulkan.Node.start_link()
    pid -> {:ok, pid}
  end

%{
  pid: vnode,
  device: Nx.Vulkan.Node.with_node(fn -> Nx.Vulkan.Native.device_name() end),
  f64?: Nx.Vulkan.Node.with_node(fn -> Nx.Vulkan.Native.has_f64() end),
  alive?: Nx.Vulkan.Node.alive?()
}
For a real application, add it to the children in your
Application.start/2 rather than starting it ad hoc. Otherwise you depend on luck:
something is supposed to start it before you dispatch, and if
that something silently drops out of the supervision tree, you’re
in the nine-hour-zombie state above.
# In your own Application module (paste, do not run here):
#
#   def start(_type, _args) do
#     children =
#       gpu_children() ++ trading_children()
#
#     Supervisor.start_link(children, strategy: :one_for_one, name: MyApp.Sup)
#   end
#
#   defp gpu_children do
#     case Exmc.JIT.detect_compiler() do
#       Nx.Vulkan -> [Nx.Vulkan.Node]
#       _ -> []
#     end
#   end
The guard matters: on a host where the user has picked EXLA
(Linux + CUDA) or EMLX (macOS), Nx.Vulkan.Node shouldn’t be
running. The detection happens once, at boot.
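The guard can be exercised without a GPU. The sketch below is a stand-in, not nx_vulkan code: an Agent plays the role of Nx.Vulkan.Node, the :vulkan and :exla atoms are hypothetical substitutes for whatever Exmc.JIT.detect_compiler/0 returns, and the module and registered names are made up for illustration.

```elixir
defmodule BootGuardSketch do
  # Hypothetical stand-in for the GPU process: an Agent registered under a
  # fixed name so other code can find it with Process.whereis/1.
  @node_name :gpu_node_sketch

  # Only the Vulkan choice gets a GPU child; every other compiler gets none.
  def gpu_children(:vulkan) do
    [%{id: @node_name, start: {Agent, :start_link, [fn -> %{} end, [name: @node_name]]}}]
  end

  def gpu_children(_other), do: []

  # The decision is made once, when the supervisor is started.
  def start(compiler) do
    Supervisor.start_link(gpu_children(compiler), strategy: :one_for_one)
  end

  def node_name, do: @node_name
end

{:ok, sup} = BootGuardSketch.start(:vulkan)
IO.inspect(Supervisor.count_children(sup).active, label: "active children")
IO.inspect(is_pid(Process.whereis(BootGuardSketch.node_name())), label: "node registered?")
Supervisor.stop(sup)
```

Because the child sits under a supervisor, a crash restarts it instead of silently leaving callers without a process to talk to.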
Every Vulkan op goes through with_node/2. It’s a synchronous
GenServer.call — work is serialised through the GenServer’s
mailbox. Concurrency lives between dispatches (your sampler
chains, your instrument workers), not inside a single dispatch.
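To make the serialisation concrete, here is a minimal sketch of the pattern using only the standard library. SerialNodeSketch and its functions are hypothetical; this is not how Nx.Vulkan.Node is implemented, only an illustration of routing closures through one GenServer.call.

```elixir
defmodule SerialNodeSketch do
  use GenServer

  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, %{dispatches: 0}, opts)
  end

  # Runs `fun` inside the server process. The caller blocks until it returns,
  # and concurrent callers queue up in the server's mailbox.
  def with_node(server, fun, timeout \\ 5_000) when is_function(fun, 0) do
    GenServer.call(server, {:run, fun}, timeout)
  end

  def dispatch_count(server), do: GenServer.call(server, :count)

  @impl true
  def init(state), do: {:ok, state}

  @impl true
  def handle_call({:run, fun}, _from, state) do
    {:reply, fun.(), %{state | dispatches: state.dispatches + 1}}
  end

  def handle_call(:count, _from, state), do: {:reply, state.dispatches, state}
end

{:ok, server} = SerialNodeSketch.start_link()

# Four concurrent callers; each closure runs in the server process, one at a time.
results =
  1..4
  |> Enum.map(fn i ->
    Task.async(fn -> SerialNodeSketch.with_node(server, fn -> {self(), i * 2} end) end)
  end)
  |> Task.await_many()

IO.inspect(Enum.all?(results, fn {pid, _value} -> pid == server end), label: "ran inside the server?")
IO.inspect(SerialNodeSketch.dispatch_count(server), label: "dispatches")
GenServer.stop(server)
```

The callers run concurrently, but the closures execute one after another in the server, which is the property that lets a single process own device state safely.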
%{
  status: Nx.Vulkan.Node.status(),
  queue_len:
    Process.info(vnode, :message_queue_len)
    |> elem(1)
}
Part 2 — A chain shader from a spec
A FamilySpec is a small declarative record describing one
distribution family. The Beta family is in the standard catalog:
spec = Nx.Vulkan.ChainShaderSpecs.beta()
spec
Synthesis.compile/1 renders the spec to GLSL via a templated
leapfrog skeleton, runs it through glslangValidator, and writes
the resulting SPIR-V to a content-addressed cache file:
{cold_us, {:ok, spv_path}} =
  :timer.tc(fn -> Nx.Vulkan.Synthesis.compile(spec) end)

%{cold_us: cold_us, spv_path: spv_path, size_bytes: File.stat!(spv_path).size}
Compile the same spec again. The cache lookup is content-addressed (SHA-256 of the rendered GLSL), so identical specs hit the same cache entry without re-running glslang:
{warm_us, {:ok, ^spv_path}} =
  :timer.tc(fn -> Nx.Vulkan.Synthesis.compile(spec) end)

%{warm_us: warm_us, speedup: Float.round(cold_us / max(warm_us, 1), 1)}
The cold path is bounded by glslangValidator invocation time;
the warm path is bounded by a single file-existence check. The
difference is usually two orders of magnitude.
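The cold/warm split follows from the cache shape. A minimal sketch of a content-addressed cache, assuming a hypothetical ContentCacheSketch module and a caller-supplied compile function standing in for the glslangValidator run (the file extension and return tuples are also illustrative):

```elixir
defmodule ContentCacheSketch do
  # Returns {:hit, path} or {:miss, path}. `compile_fun` maps source text to
  # the bytes that should be stored; it is only invoked on a miss.
  def fetch(source, cache_dir, compile_fun) when is_binary(source) do
    # The cache key is the SHA-256 of the source, so identical sources
    # always resolve to the same file.
    digest = :crypto.hash(:sha256, source) |> Base.encode16(case: :lower)
    path = Path.join(cache_dir, digest <> ".bin")

    if File.exists?(path) do
      {:hit, path}
    else
      File.mkdir_p!(cache_dir)
      File.write!(path, compile_fun.(source))
      {:miss, path}
    end
  end
end

dir = Path.join(System.tmp_dir!(), "cache_sketch_#{System.unique_integer([:positive])}")
fake_compile = fn src -> String.upcase(src) end

IO.inspect(ContentCacheSketch.fetch("void main() {}", dir, fake_compile) |> elem(0), label: "first call")
IO.inspect(ContentCacheSketch.fetch("void main() {}", dir, fake_compile) |> elem(0), label: "second call")
File.rm_rf!(dir)
```

The warm path never calls the compile function; it pays for one hash and one file-existence check.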
compile_with_source/1 is compile/1 plus the rendered GLSL,
which is handy when you want to read what the template produced:
{:ok, _spv_path, glsl} = Nx.Vulkan.Synthesis.compile_with_source(spec)
Kino.Markdown.new("""
#{String.slice(glsl, 0, 1200)} …
""")
What’s in there:
- Push-constant block with n (chain count), K (leapfrog steps), eps (step size), and family-specific parameters (alpha, beta, …).
- Six read/write SSBOs for input/output positions, momenta, gradients, and per-step log-probabilities.
- One for k = 0..K-1 loop performing K leapfrog steps before returning to the host. This is the batched leapfrog: one dispatch, K updates.
The template parameterises the leapfrog body on grad_log_p(q)
and log_p(q) expressions specific to the family. Beta plugs in
α/(α+β·sigmoid(q)) - … etc.; Gamma swaps the gradient. The
mechanics outside the per-family body are identical across specs.
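The same structure can be written on the host in a few lines. This is a CPU sketch of the skeleton, not the shader: it assumes a scalar position, unit mass, and a standard-normal target (grad_log_p(q) = -q) swapped in because the Beta expressions are elided above. LeapfrogSketch is a hypothetical name.

```elixir
defmodule LeapfrogSketch do
  # Runs k leapfrog steps in one call and returns the final state plus the
  # per-step trajectory. `grad_log_p` is the only family-specific piece.
  def chain(q0, p0, eps, k, grad_log_p) when is_integer(k) and k >= 1 do
    {q, p, steps} =
      Enum.reduce(1..k, {q0, p0, []}, fn _step, {q, p, acc} ->
        p_half = p + 0.5 * eps * grad_log_p.(q)          # half step in momentum
        q_new = q + eps * p_half                          # full step in position
        p_new = p_half + 0.5 * eps * grad_log_p.(q_new)   # second half step in momentum
        {q_new, p_new, [{q_new, p_new} | acc]}
      end)

    {q, p, Enum.reverse(steps)}
  end

  # Hamiltonian for the standard-normal target with unit mass.
  def energy(q, p), do: 0.5 * q * q + 0.5 * p * p
end

grad = fn q -> -q end
{q, p, steps} = LeapfrogSketch.chain(1.0, 0.0, 0.05, 32, grad)

IO.inspect(length(steps), label: "steps recorded")
IO.inspect(LeapfrogSketch.energy(q, p) - LeapfrogSketch.energy(1.0, 0.0), label: "energy drift")
```

Swapping the family means swapping `grad`; the loop body stays the same, which mirrors the template split described above.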
Part 3 — Dispatching a chain
The push-constant helper computes the buffer-layout-correct byte sequence for a single Beta family dispatch:
alpha = 2.0
beta = 5.0
n = 1
k = 32
eps = 0.05
# `logp_const` is the family's log-normalising constant
# (logBeta(α,β) for Beta). Precomputed and passed as a uniform.
logp_const = :math.lgamma(alpha) + :math.lgamma(beta) - :math.lgamma(alpha + beta)
push = Nx.Vulkan.ChainShaderSpecs.beta_push(n, k, eps, alpha, beta, logp_const)
byte_size(push)
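Two details of the cell above can be explored without the library. First, if :math.lgamma/1 is unavailable on your OTP release (we are not aware of a stock :math module that exports it), the log-normaliser can be computed with a Lanczos approximation. Second, push constants are just packed bytes. The sketch below shows both under loud assumptions: BetaPushSketch is hypothetical, and the field order and 32-bit little-endian widths are a guess for illustration, not the layout beta_push/6 actually produces.

```elixir
defmodule BetaPushSketch do
  # Lanczos coefficients for g = 7, n = 9 (a commonly used parameter set).
  @g 7.0
  @coeffs [
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7
  ]

  # log Γ(x) for x >= 0.5 via the Lanczos series.
  def lgamma(x) when is_number(x) and x >= 0.5 do
    x = x - 1.0
    [c0 | rest] = @coeffs

    sum =
      rest
      |> Enum.with_index(1)
      |> Enum.reduce(c0, fn {c, i}, acc -> acc + c / (x + i) end)

    t = x + @g + 0.5
    0.5 * :math.log(2.0 * :math.pi()) + (x + 0.5) * :math.log(t) - t + :math.log(sum)
  end

  # Reflection formula covers 0 < x < 0.5.
  def lgamma(x) when is_number(x) and x > 0 do
    :math.log(:math.pi() / :math.sin(:math.pi() * x)) - lgamma(1.0 - x)
  end

  def log_beta(a, b), do: lgamma(a) + lgamma(b) - lgamma(a + b)

  # ASSUMED layout: two uint32 fields followed by four float32 fields,
  # all little-endian. The real shader's layout may differ.
  def pack(n, k, eps, alpha, beta, logp_const) do
    <<n::little-unsigned-integer-size(32), k::little-unsigned-integer-size(32),
      eps::little-float-size(32), alpha::little-float-size(32),
      beta::little-float-size(32), logp_const::little-float-size(32)>>
  end
end

logp = BetaPushSketch.log_beta(2.0, 5.0)
bytes = BetaPushSketch.pack(1, 32, 0.05, 2.0, 5.0, logp)
IO.inspect(byte_size(bytes), label: "push-constant bytes (assumed layout)")
```

Under this assumed layout the block is six 4-byte fields; compare that against what byte_size(push) reports for the real helper before relying on it.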
In a complete dispatch you’d:
- Upload q_init, p_init, inv_mass SSBOs.
- Call Nx.Vulkan.Native.leapfrog_chain_synth/6 inside with_node/1.
- Download q_chain, p_chain, grad_chain, logp_chain.
For brevity, the exact NIF call is omitted here; its shape is documented in the
module docs for Nx.Vulkan.Native.leapfrog_chain_synth/6.
What matters at the architectural level: one GenServer call
covers K leapfrog steps. The per-call cost is amortised across K,
which is why Vulkan-on-BEAM crosses break-even versus the
per-step IR walker.
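The amortisation argument is simple arithmetic: a fixed per-call overhead is divided across the K steps that share the call. A small sketch with a hypothetical helper; the 200 µs overhead and 5 µs step cost are illustrative placeholders, not measurements from nx_vulkan.

```elixir
defmodule AmortiseSketch do
  # Effective cost of one leapfrog step when a fixed per-call overhead is
  # shared across k steps executed inside a single call.
  def per_step_us(overhead_us, step_us, k) when is_integer(k) and k >= 1 do
    (overhead_us + k * step_us) / k
  end
end

# Illustrative numbers only.
for k <- [1, 8, 32] do
  IO.puts("K=#{k}: #{AmortiseSketch.per_step_us(200, 5, k)} µs per step")
end
```

As K grows the per-step figure approaches the raw step cost, which is the break-even effect described above.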
Part 4 — What happens when things break
Three failure modes worth understanding before deploying.
1. A broken spec. Pass a FamilySpec whose GLSL template
references a function that doesn’t exist:
broken_spec = %{spec | name: "broken"}
# Pretend we mangled the body somehow. In practice this happens
# when a new distribution's template hits an unimplemented op.
# Synthesis.compile/1 surfaces a structured error tuple.
case Nx.Vulkan.Synthesis.compile(broken_spec) do
  {:ok, _spv} -> :ok_unmangled
  {:error, reason} -> {:err, reason}
end
(The exact failure shape depends on which template line is
broken; glslangValidator’s stderr ends up in the error term.)
2. A runtime watchdog timeout. Nx.Vulkan.Node.with_node/2
has a configurable watchdog. If a single dispatch hangs (driver
bug, device hang, infinite loop in a shader), the watchdog kills
the GenServer call and you get {:error, :node_timeout}. eXMC’s
NUTS sampler catches this and falls back to its EXLA path. Your
own code should treat it the same way: catch, log, fall back or
retry.
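The catch-and-fall-back shape can be sketched with a Task and a deadline. This is a caller-side illustration only: WatchdogSketch is hypothetical, and the real watchdog lives inside with_node/2 rather than in a wrapper like this. The only detail borrowed from the text is the {:error, :node_timeout} return shape.

```elixir
defmodule WatchdogSketch do
  # Runs `fun` in a separate process and gives up after `timeout_ms`.
  def run(fun, timeout_ms) when is_function(fun, 0) do
    task = Task.async(fun)

    # Task.yield/2 returns nil on timeout; Task.shutdown/2 then kills the
    # task (or returns a reply that arrived in the meantime).
    case Task.yield(task, timeout_ms) || Task.shutdown(task, :brutal_kill) do
      {:ok, result} -> {:ok, result}
      nil -> {:error, :node_timeout}
      # Only reachable if the caller traps exits and the task crashed.
      {:exit, reason} -> {:error, {:exit, reason}}
    end
  end

  # Tries the primary path first and uses `fallback` on any error.
  def with_fallback(primary, fallback, timeout_ms) do
    case run(primary, timeout_ms) do
      {:ok, result} -> result
      {:error, _reason} -> fallback.()
    end
  end
end

slow_gpu_path = fn -> Process.sleep(2_000) end
cpu_path = fn -> :cpu_result end

IO.inspect(WatchdogSketch.run(slow_gpu_path, 50), label: "hung dispatch")
IO.inspect(WatchdogSketch.with_fallback(slow_gpu_path, cpu_path, 50), label: "with fallback")
```

In production code the error branch is also where the timeout should be logged, so a run of fallbacks is visible rather than silent.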
3. The Node didn’t start. What it looks like at runtime:
# Don't try this here — but conceptually:
# Process.exit(vnode, :kill)
# Nx.Vulkan.Backend.add(...) # → GenServer.call(nil, ...) — argument error
If your Application.start doesn’t include Nx.Vulkan.Node,
your Vulkan-backed ops eventually call into a nil PID. The
error is not “Vulkan unavailable”; it’s ** (ArgumentError) GenServer call with nil,
deep in the Nx call stack. The fix is
the supervision-tree integration shown in Part 1.
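Until the supervision tree is fixed, a call-site check turns the obscure failure into a readable one. A minimal sketch, assuming a hypothetical NodeGuardSketch helper and error tuple; nothing here is part of nx_vulkan.

```elixir
defmodule NodeGuardSketch do
  # Looks up a registered process and returns a descriptive error tuple
  # instead of letting a nil PID travel further down the call stack.
  def ensure_started(name) when is_atom(name) do
    case Process.whereis(name) do
      nil -> {:error, {:not_started, name}}
      pid when is_pid(pid) -> {:ok, pid}
    end
  end
end

# An Agent stands in for the GPU process in this demonstration.
{:ok, agent} = Agent.start_link(fn -> :ready end, name: :node_guard_demo)

IO.inspect(NodeGuardSketch.ensure_started(:node_guard_demo) == {:ok, agent}, label: "registered")
IO.inspect(NodeGuardSketch.ensure_started(:node_guard_missing), label: "missing")
Agent.stop(agent)
```

A check like this belongs at boot or at the edge of the sampler, where the error can name the missing process directly.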
Wrap-up
The GPU process is the API you actually program against. The
shader cache is the API you tune. The two are coupled: a
long-lived GenServer is what lets the pipeline cache amortise the
SPIR-V compile + pipeline build across thousands of dispatches.
A short-lived script that starts and stops Nx.Vulkan.Node per
call would build and discard the pipeline cache every time —
correctness is fine, throughput is wrecked.
If you take one thing away: put Nx.Vulkan.Node in your
supervision tree at boot, behind a compiler-detect guard.
%{vnode: Process.whereis(Nx.Vulkan.Node), spv: spv_path, glsl_bytes: byte_size(glsl)}