
Vulkan Compute on the BEAM — Chain Shaders and the GPU Process

17_vulkan_chains_and_gpu_process.livemd

Mix.install([
  {:exmc, path: Path.expand("../", __DIR__)},
  {:nx, "~> 0.10"},
  {:nx_vulkan,
   github: "borodark/nx_vulkan",
   branch: "main",
   optional: true},
  {:kino, "~> 0.14"}
])

What this notebook is

The companion to 16_vulkan_demo.livemd. That one shows whether Vulkan-on-BEAM works (it does — eXMC’s NUTS sampler runs identical posteriors under EXLA and Vulkan). This one shows how the machinery is built:

  1. Nx.Vulkan.Node is a long-lived GenServer. It owns the vkPipelineCache, the persistent buffer registry, and a watchdog. Every Vulkan op the BEAM dispatches goes through it.
  2. Chain-shader specs are rendered to GLSL at runtime, validated with glslangValidator, and content-addressed under the cache.
  3. A single dispatch executes K leapfrog steps inside one GPU command buffer, sidestepping the per-dispatch overhead that would otherwise dominate.

The order matters: the process exists to give the pipeline cache a home, and that home is what lets shader synthesis pay off across calls. Remove the GenServer and every dispatch re-creates the device’s pipeline state.

We learned this the expensive way: a paper-trading deploy ran for nine hours with compiler=Nx.Vulkan and default_backend=Nx.Vulkan.Backend but no Nx.Vulkan.Node in the supervision tree. The Nx.Vulkan.Backend happily routed ops to a GenServer that didn’t exist; NUTS calls hung; a 300-second watchdog inside eXMC reset each instrument in turn. The trader looked alive — broker polls kept working — and produced zero posterior updates. That’s the load-bearing GenServer.
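To make the “home for the cache” argument concrete, here is a minimal sketch of a process that keeps a cache in its state. PipelineCacheSketch is a hypothetical module, not Nx.Vulkan.Node, and the “pipeline build” is simulated. The only point is that state survives between calls, so 1,000 requests for the same shader key trigger a single build:

```elixir
defmodule PipelineCacheSketch do
  # Hypothetical stand-in for a long-lived process that owns a pipeline cache.
  # Building a pipeline is simulated; only the cache bookkeeping is real.
  use GenServer

  def start_link(opts \\ []), do: GenServer.start_link(__MODULE__, :ok, opts)

  # Returns {:hit, pipeline} or {:miss, pipeline} for a shader key.
  def pipeline(server, shader_key), do: GenServer.call(server, {:pipeline, shader_key})

  # Number of (simulated) pipeline builds performed so far.
  def builds(server), do: GenServer.call(server, :builds)

  @impl true
  def init(:ok), do: {:ok, %{cache: %{}, builds: 0}}

  @impl true
  def handle_call({:pipeline, key}, _from, %{cache: cache} = state) do
    case Map.fetch(cache, key) do
      {:ok, pipeline} ->
        {:reply, {:hit, pipeline}, state}

      :error ->
        # The expensive step we want to pay for only once per key.
        pipeline = {:pipeline, key}
        state = %{state | cache: Map.put(cache, key, pipeline), builds: state.builds + 1}
        {:reply, {:miss, pipeline}, state}
    end
  end

  def handle_call(:builds, _from, state), do: {:reply, state.builds, state}
end

{:ok, cache} = PipelineCacheSketch.start_link()
results = for _ <- 1..1_000, do: PipelineCacheSketch.pipeline(cache, "beta.spv")
PipelineCacheSketch.builds(cache)
```

A script that started and stopped such a process around every call would see a miss, and a rebuild, every time.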

Part 1 — The GPU process

Nx.Vulkan.Node is a normal OTP GenServer. You start it like any other:

unless Code.ensure_loaded?(Nx.Vulkan.Node) and function_exported?(Nx.Vulkan.Node, :start_link, 1) do
  raise "nx_vulkan not loaded — see the Vulkan detection cell in 16_vulkan_demo.livemd"
end

{:ok, vnode} =
  case Process.whereis(Nx.Vulkan.Node) do
    nil -> Nx.Vulkan.Node.start_link([])
    pid -> {:ok, pid}
  end

%{
  pid: vnode,
  device: Nx.Vulkan.Node.with_node(fn -> Nx.Vulkan.Native.device_name() end),
  f64?: Nx.Vulkan.Node.with_node(fn -> Nx.Vulkan.Native.has_f64() end),
  alive?: Nx.Vulkan.Node.alive?()
}

For a real application, add it to your Application.start/2 children rather than starting it ad hoc. Otherwise you depend on luck: something is supposed to start it before you dispatch, and if that something silently drops out of the supervision tree, you end up in the nine-hour zombie state described above.

# In your own Application module (paste, do not run here):
#
#   def start(_type, _args) do
#     children =
#       gpu_children() ++ trading_children()
#
#     Supervisor.start_link(children, strategy: :one_for_one, name: MyApp.Sup)
#   end
#
#   defp gpu_children do
#     case Exmc.JIT.detect_compiler() do
#       Nx.Vulkan -> [Nx.Vulkan.Node]
#       _ -> []
#     end
#   end

The guard matters: on a host where the user has picked EXLA (Linux + CUDA) or EMLX (macOS), Nx.Vulkan.Node shouldn’t be running. The detection happens once, at boot.
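If you want the guard itself to be testable on a machine without a GPU, one option is to take the detected compiler as an argument instead of calling Exmc.JIT.detect_compiler() inside the helper. A minimal sketch; GpuChildrenSketch is a hypothetical name, and the compiler atoms are only compared, so EXLA and EMLX need not be installed:

```elixir
defmodule GpuChildrenSketch do
  # Hypothetical helper: maps the compiler detected at boot to the extra
  # children a supervision tree should start. Only a Vulkan host gets the
  # GPU process; EXLA or EMLX hosts get nothing extra.
  def gpu_children(Nx.Vulkan), do: [Nx.Vulkan.Node]
  def gpu_children(_other_compiler), do: []
end

GpuChildrenSketch.gpu_children(Nx.Vulkan)
# → [Nx.Vulkan.Node]

GpuChildrenSketch.gpu_children(EXLA)
# → []
```

In Application.start/2 you would then call GpuChildrenSketch.gpu_children(Exmc.JIT.detect_compiler()) once and prepend the result to your other children.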

Every Vulkan op goes through with_node/2. It’s a synchronous GenServer.call — work is serialised through the GenServer’s mailbox. Concurrency lives between dispatches (your sampler chains, your instrument workers), not inside a single dispatch.

%{
  status: Nx.Vulkan.Node.status(),
  queue_len:
    Process.info(vnode, :message_queue_len)
    |> elem(1)
}
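The serialisation through the mailbox is easy to see with a toy GenServer (SerialSketch below is hypothetical and has nothing to do with Vulkan). Four concurrent callers each ask for 20 ms of simulated “dispatch” work; because handle_call/3 handles one message at a time, the batch cannot finish in less than 80 ms:

```elixir
defmodule SerialSketch do
  # Toy GenServer: each call sleeps to stand in for one GPU dispatch.
  use GenServer

  def start_link(work_ms), do: GenServer.start_link(__MODULE__, work_ms)
  def dispatch(server), do: GenServer.call(server, :dispatch)

  @impl true
  def init(work_ms), do: {:ok, work_ms}

  @impl true
  def handle_call(:dispatch, _from, work_ms) do
    Process.sleep(work_ms)
    {:reply, :done, work_ms}
  end
end

{:ok, serial} = SerialSketch.start_link(20)

# Four concurrent callers...
{elapsed_us, replies} =
  :timer.tc(fn ->
    1..4
    |> Enum.map(fn _ -> Task.async(fn -> SerialSketch.dispatch(serial) end) end)
    |> Task.await_many()
  end)

# ...still take at least 4 × 20 ms, because the server handles them in turn.
%{replies: replies, elapsed_ms: div(elapsed_us, 1000)}
```

The concurrency is real on the caller side (four tasks), which is exactly where the notebook says it should live.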

Part 2 — A chain shader from a spec

A FamilySpec is a small declarative record describing one distribution family. The Beta family is in the standard catalog:

spec = Nx.Vulkan.ChainShaderSpecs.beta()
spec

Synthesis.compile/1 renders the spec to GLSL via a templated leapfrog skeleton, runs it through glslangValidator, and writes the resulting SPIR-V to a content-addressed cache file:

{cold_us, {:ok, spv_path}} =
  :timer.tc(fn -> Nx.Vulkan.Synthesis.compile(spec) end)

%{cold_us: cold_us, spv_path: spv_path, size_bytes: File.stat!(spv_path).size}

Compile the same spec again. The cache lookup is content-addressed (SHA-256 of the rendered GLSL), so identical specs hit the same cache entry without re-running glslang:

{warm_us, {:ok, ^spv_path}} =
  :timer.tc(fn -> Nx.Vulkan.Synthesis.compile(spec) end)

%{warm_us: warm_us, speedup: Float.round(cold_us / max(warm_us, 1), 1)}

The cold path is bounded by glslangValidator invocation time; the warm path is bounded by a single file-existence check. The difference is usually two orders of magnitude.
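The cold/warm split can be sketched in a few lines of plain Elixir. This is not Nx.Vulkan.Synthesis: ContentCacheSketch is hypothetical and compile_fun stands in for the glslangValidator run. It shows the content-addressing idea: the file name is the SHA-256 of the source, so the warm path is just File.exists?/1, and any change to the source, even whitespace, lands on a new entry:

```elixir
defmodule ContentCacheSketch do
  # Hypothetical content-addressed compile cache keyed by SHA-256 of the
  # source text. `compile_fun` stands in for the expensive compiler call.
  def compile(source, cache_dir, compile_fun) do
    hash = :crypto.hash(:sha256, source) |> Base.encode16(case: :lower)
    path = Path.join(cache_dir, hash <> ".spv")

    if File.exists?(path) do
      # Warm path: a single file-existence check.
      {:warm, path}
    else
      # Cold path: run the compiler and persist its output under the hash.
      File.mkdir_p!(cache_dir)
      File.write!(path, compile_fun.(source))
      {:cold, path}
    end
  end
end

dir = Path.join(System.tmp_dir!(), "content_cache_sketch_#{System.unique_integer([:positive])}")
fake_compiler = fn src -> "SPIRV:" <> src end

{:cold, p1} = ContentCacheSketch.compile("void main() {}", dir, fake_compiler)
{:warm, ^p1} = ContentCacheSketch.compile("void main() {}", dir, fake_compiler)
{:cold, p2} = ContentCacheSketch.compile("void main() { }", dir, fake_compiler)
p1 != p2
```

A production version would likely write to a temporary file and rename it into place, so a crash mid-write cannot leave a truncated entry that the warm path would then trust.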

compile_with_source/1 is compile/1 plus the rendered GLSL, which is handy when you want to read what the template produced:

{:ok, _spv_path, glsl} = Nx.Vulkan.Synthesis.compile_with_source(spec)

# Render as plain text so GLSL characters such as * are not read as Markdown.
Kino.Text.new(String.slice(glsl, 0, 1200) <> " …")

What’s in there:

  • Push-constant block with n (chain count), K (leapfrog steps), eps (step size), and family-specific parameters (alpha, beta, …).
  • Six read/write SSBOs for input/output positions, momenta, gradients, and per-step log-probabilities.
  • A single loop over k = 0..K-1 that performs K leapfrog steps before control returns to the host. This is the batched leapfrog: one dispatch, K updates.
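The shape of that loop can be mirrored on the CPU. The following is not the shader’s code but a plain-Elixir sketch of the same batched-leapfrog idea for one chain, with a standard-normal target standing in for a family (an assumption for illustration). One call performs K half-kick/drift/half-kick updates and returns every per-step output, the way one dispatch fills the output SSBOs:

```elixir
defmodule LeapfrogSketch do
  # CPU sketch of "K leapfrog steps per dispatch" for a single chain.
  # The family enters only through `log_p` and `grad_log_p`; `inv_mass`
  # is a scalar. Returns [{q, p, grad, logp}] for steps 1..k, mirroring
  # q_chain, p_chain, grad_chain and logp_chain.
  def chain(q0, p0, k, eps, inv_mass, log_p, grad_log_p) when k >= 1 do
    {_q, _p, _g, steps} =
      Enum.reduce(1..k, {q0, p0, grad_log_p.(q0), []}, fn _step, {q, p, g, acc} ->
        p_half = p + 0.5 * eps * g
        q_new = q + eps * inv_mass * p_half
        g_new = grad_log_p.(q_new)
        p_new = p_half + 0.5 * eps * g_new
        {q_new, p_new, g_new, [{q_new, p_new, g_new, log_p.(q_new)} | acc]}
      end)

    Enum.reverse(steps)
  end
end

# Standard normal target: log p(q) = -q^2/2 (up to a constant), gradient -q.
steps = LeapfrogSketch.chain(1.0, 0.0, 32, 0.05, 1.0, &(-&1 * &1 / 2), &(-&1))

# H = -log p(q) + p^2/2 should stay close to its starting value of 0.5.
{_q, p_last, _g, logp_last} = List.last(steps)
h_end = -logp_last + p_last * p_last / 2
```

Note that the gradient computed at the end of one step is reused at the start of the next, so each step costs one gradient evaluation.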

The template parameterises the leapfrog body on grad_log_p(q) and log_p(q) expressions specific to the family. Beta plugs in α/(α+β·sigmoid(q)) - … etc.; Gamma swaps the gradient. The mechanics outside the per-family body are identical across specs.
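The substitution mechanism can be sketched with EEx from Elixir’s standard library. The skeleton below is deliberately tiny and hypothetical (the real template is much larger and its syntax is not shown here), and the two families are illustrative choices, not entries from the catalog. The fixed part stays the same while the family-specific expressions change:

```elixir
# Hypothetical, minimal skeleton: only the family expressions vary.
skeleton = """
float log_p(float q) { return <%= log_p %>; }
float grad_log_p(float q) { return <%= grad_log_p %>; }
// ... shared leapfrog loop over k = 0..K-1 calls grad_log_p / log_p ...
"""

# Render the skeleton with a family's expressions bound as EEx variables.
render = fn family -> EEx.eval_string(skeleton, Map.to_list(family)) end

normal = %{log_p: "-0.5 * q * q", grad_log_p: "-q"}
laplace = %{log_p: "-abs(q)", grad_log_p: "-sign(q)"}

IO.puts(render.(normal))
IO.puts(render.(laplace))
```

Because the rendered text differs only where a family’s expressions differ, hashing the rendered GLSL gives each family its own cache entry.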

Part 3 — Dispatching a chain

The push-constant helper computes the buffer-layout-correct byte sequence for a single Beta family dispatch:

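alpha = 2.0
beta = 5.0
n = 1
k = 32
eps = 0.05

# Erlang's :math module has no lgamma, so approximate log Γ(x) with the
# Lanczos series (g = 7, 9 coefficients); accurate for x >= 0.5.
lgamma = fn x ->
  [c0 | cs] = [0.99999999999980993, 676.5203681218851, -1259.1392167224028,
               771.32342877765313, -176.61502916214059, 12.507343278686905,
               -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7]
  z = x - 1
  a = cs |> Enum.with_index(1) |> Enum.reduce(c0, fn {c, i}, acc -> acc + c / (z + i) end)
  t = z + 7.5
  0.5 * :math.log(2 * :math.pi()) + (z + 0.5) * :math.log(t) - t + :math.log(a)
end

# `logp_const` is the family's log-normalising constant, logBeta(α, β) for Beta.
# It is precomputed on the host and passed in the push-constant block.
logp_const = lgamma.(alpha) + lgamma.(beta) - lgamma.(alpha + beta)

push = Nx.Vulkan.ChainShaderSpecs.beta_push(n, k, eps, alpha, beta, logp_const)
byte_size(push)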
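What a push-constant helper does can be sketched with Elixir binaries. The layout here is an assumption for illustration only: u32 n, u32 K, then four f32 values (eps, alpha, beta, logp_const), little-endian as on x86-64 and ARM64 hosts. PushSketch is hypothetical; the real field order and size are whatever beta_push/6 and the GLSL push_constant block agree on, so byte_size(push) above may well differ:

```elixir
defmodule PushSketch do
  # Hypothetical layout: two u32s then four f32s, all little-endian. Every
  # field is 4 bytes, so no padding is needed between them.
  def pack(n, k, eps, alpha, beta, logp_const) do
    <<n::unsigned-little-32, k::unsigned-little-32, eps::float-little-32,
      alpha::float-little-32, beta::float-little-32, logp_const::float-little-32>>
  end

  # Inverse of pack/6, handy for checking what was actually encoded.
  def unpack(<<n::unsigned-little-32, k::unsigned-little-32, eps::float-little-32,
               alpha::float-little-32, beta::float-little-32, c::float-little-32>>) do
    %{n: n, k: k, eps: eps, alpha: alpha, beta: beta, logp_const: c}
  end
end

# For integer parameters, B(2, 5) = 1! * 4! / 6! = 24 / 720 = 1/30.
packed = PushSketch.pack(1, 32, 0.05, 2.0, 5.0, :math.log(1 / 30))
byte_size(packed)
# → 24
```

Values such as eps = 0.05 are not exactly representable in f32, so a round trip through unpack/1 returns close, not identical, floats.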

In a complete dispatch you’d:

  1. Upload q_init, p_init, inv_mass SSBOs.
  2. Call Nx.Vulkan.Native.leapfrog_chain_synth/6 inside with_node/1.
  3. Download q_chain, p_chain, grad_chain, logp_chain.

For brevity here, the exact NIF call shape lives in Nx.Vulkan.Native.leapfrog_chain_synth/6 — see the module docs. What matters at the architectural level: one GenServer call covers K leapfrog steps. The per-call cost is amortised across K, which is why Vulkan-on-BEAM crosses break-even versus the per-step IR walker.
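The amortisation argument is simple arithmetic. Under a hypothetical cost model (a fixed per-dispatch overhead plus a per-step GPU cost, versus a per-step CPU cost for the IR walker; the numbers below are placeholders, not measurements), the per-step cost of a dispatch is overhead / K + gpu_step, and the batched path wins once K exceeds overhead / (cpu_step - gpu_step):

```elixir
# Hypothetical cost model; all figures are placeholders in microseconds.
# Effective per-step cost of one dispatch that runs k leapfrog steps.
per_step_gpu = fn overhead_us, gpu_step_us, k -> overhead_us / k + gpu_step_us end

# Smallest k with overhead/k + gpu_step < cpu_step, i.e.
# k > overhead / (cpu_step - gpu_step). Assumes cpu_step > gpu_step.
break_even_k = fn overhead_us, gpu_step_us, cpu_step_us ->
  floor(overhead_us / (cpu_step_us - gpu_step_us)) + 1
end

break_even_k.(200, 1, 10)
# → 23

per_step_gpu.(200, 1, 32)
# → 7.25
```

The same model shows why K = 1 (one dispatch per step) is the worst case: the whole overhead lands on every step.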

Part 4 — What happens when things break

Three failure modes worth understanding before deploying.

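1. A broken spec. Imagine a FamilySpec whose GLSL template references a function that doesn’t exist. The cell below only renames the spec, so it will most likely compile cleanly; it shows the shape of the error handling rather than triggering the error: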

broken_spec = %{spec | name: "broken"}

# Pretend we mangled the body somehow.  In practice this happens
# when a new distribution's template hits an unimplemented op.
# Synthesis.compile/1 surfaces a structured error tuple.
case Nx.Vulkan.Synthesis.compile(broken_spec) do
  {:ok, _spv} -> :ok_unmangled
  {:error, reason} -> {:err, reason}
end

(The exact failure shape depends on which template line is broken; glslangValidator’s stderr ends up in the error term.)
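Turning a failed external tool run into a structured error tuple can be sketched with System.cmd/3. ValidatorSketch is hypothetical and the real Synthesis.compile/1 error shape may differ; here sh stands in for glslangValidator so the cell runs without the Vulkan SDK, and stderr_to_stdout: true captures diagnostics regardless of which stream the tool writes them to:

```elixir
defmodule ValidatorSketch do
  # Hypothetical wrapper: run an external validator and convert a non-zero
  # exit status into {:error, {:validator_failed, status, output}}.
  # System.cmd/3 raises if the executable itself cannot be found.
  def validate(cmd, args) do
    case System.cmd(cmd, args, stderr_to_stdout: true) do
      {_output, 0} -> :ok
      {output, status} -> {:error, {:validator_failed, status, String.trim(output)}}
    end
  end
end

# `sh` plays the part of a validator that prints a diagnostic and fails.
ValidatorSketch.validate("sh", ["-c", "echo 'ERROR: undeclared identifier' >&2; exit 2"])
# → {:error, {:validator_failed, 2, "ERROR: undeclared identifier"}}
```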

2. A runtime watchdog timeout. Nx.Vulkan.Node.with_node/2 has a configurable watchdog. If a single dispatch hangs (driver bug, device hang, infinite loop in a shader), the watchdog kills the GenServer call and you get {:error, :node_timeout}. eXMC’s NUTS sampler catches this and falls back to its EXLA path. Your own code should treat it the same way: catch, log, fall back or retry.
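The “catch, log, fall back” pattern can be sketched with Task.yield/2 and Task.shutdown/2. WatchdogSketch is hypothetical: it reproduces the {:error, :node_timeout} return described above but is not Nx.Vulkan.Node’s watchdog implementation:

```elixir
defmodule WatchdogSketch do
  # Hypothetical watchdog: run `fun` in a task, give up after `timeout_ms`,
  # and return {:error, :node_timeout} so the caller can fall back.
  require Logger

  def run(fun, timeout_ms) do
    task = Task.async(fun)

    case Task.yield(task, timeout_ms) || Task.shutdown(task, :brutal_kill) do
      {:ok, result} -> {:ok, result}
      # Only reachable if the caller traps exits; otherwise Task.async's
      # link propagates a crash inside `fun` to the caller.
      {:exit, reason} -> {:error, {:crashed, reason}}
      nil -> {:error, :node_timeout}
    end
  end

  # Try the GPU path; on failure, log and run the fallback instead.
  def with_fallback(gpu_fun, fallback_fun, timeout_ms) do
    case run(gpu_fun, timeout_ms) do
      {:ok, result} ->
        {:gpu, result}

      {:error, reason} ->
        Logger.warning("GPU dispatch failed (#{inspect(reason)}); using fallback")
        {:fallback, fallback_fun.()}
    end
  end
end

hung = fn -> Process.sleep(:infinity) end
WatchdogSketch.with_fallback(hung, fn -> :cpu_result end, 50)
# → {:fallback, :cpu_result}
```

Task.shutdown(task, :brutal_kill) matters here: a hung dispatch is killed rather than left running in the background.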

3. The Node didn’t start. What it looks like at runtime:

# Don't try this here — but conceptually:
# Process.exit(vnode, :kill)
# Nx.Vulkan.Backend.add(...)  # → GenServer.call(nil, ...) — argument error

If your Application.start doesn’t include Nx.Vulkan.Node, your Vulkan-backed ops eventually call into a nil PID. The error is not “Vulkan unavailable”. It’s an ArgumentError (printed as ** (ArgumentError)) from a GenServer call with nil, deep in the Nx call stack. The fix is the supervision-tree integration shown in Part 1.
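A cheap complement to the supervision-tree fix is to assert at boot that the process is actually registered, so a missing node fails immediately with a readable message instead of surfacing deep inside a numerical call. NodeCheckSketch is hypothetical; you might call NodeCheckSketch.ensure_running!(Nx.Vulkan.Node) right after Supervisor.start_link/2 on Vulkan hosts. The demo below registers a throwaway Agent so the check has something to find:

```elixir
defmodule NodeCheckSketch do
  # Hypothetical boot-time assertion: return the pid registered under
  # `name`, or raise a clear error if nothing is registered there.
  def ensure_running!(name) when is_atom(name) do
    case Process.whereis(name) do
      nil -> raise "#{inspect(name)} is not running; add it to your supervision tree"
      pid -> pid
    end
  end
end

# A throwaway registered process for the demo.
{:ok, agent} = Agent.start_link(fn -> :ok end, name: :node_check_sketch_demo)
NodeCheckSketch.ensure_running!(:node_check_sketch_demo) == agent
# → true
```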

Wrap-up

The GPU process is the API you actually program against. The shader cache is the API you tune. The two are coupled: a long-lived GenServer is what lets the pipeline cache amortise the SPIR-V compile + pipeline build across thousands of dispatches. A short-lived script that starts and stops Nx.Vulkan.Node per call would build and discard the pipeline cache every time: correctness is fine, but throughput is wrecked.

If you take one thing away: put Nx.Vulkan.Node in your supervision tree at boot, behind a compiler-detect guard.

%{vnode: Process.whereis(Nx.Vulkan.Node), spv: spv_path, glsl_bytes: byte_size(glsl)}