# BDA Chapter 5 — Eight Schools, Hierarchical Models, and Why NUTS Exists
## Setup
```elixir
# CPU only — no GPU required
System.put_env("EXLA_CPU_ONLY", "true")
System.put_env("CUDA_VISIBLE_DEVICES", "")

Mix.install([
  {:exmc, path: Path.expand("../../", __DIR__)},
  {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"}
])

Application.put_env(:exla, :clients, host: [platform: :host])
Application.put_env(:exla, :default_client, :host)
Nx.default_backend(Nx.BinaryBackend)
Nx.Defn.default_options(compiler: EXLA, client: :host)

alias Exmc.{Builder, Sampler}
alias Exmc.Dist.{Normal, HalfNormal}
alias VegaLite, as: Vl
:ok
```
## Why This Matters
Eight high schools each ran the same SAT coaching program. Each school measured how many points the program added or removed, and reported a point estimate with a standard error:
| School | y (effect, SAT points) | σ (standard error) |
|---|---|---|
| A | 28 | 15 |
| B | 8 | 10 |
| C | -3 | 16 |
| D | 7 | 11 |
| E | -1 | 9 |
| F | 1 | 11 |
| G | 18 | 10 |
| H | 12 | 18 |
You are asked: what is the true effect of the coaching program?
You have three plainly wrong choices:
- Report each school separately. This says “School A’s program raises scores by 28 points.” But A’s standard error is 15 — the data don’t support 28 to two significant figures.
- Average everything. This says the program raises scores by ≈8 for everyone. But it ignores that the schools differ in measured effect.
- Pretend School A’s number is real. Coaching companies do this every day. They get sued.
The correct answer is a hierarchical model: each school has its own true
effect θ_j, but the eight true effects are themselves drawn from a common
distribution N(μ, τ²) whose parameters you also infer. The result is
partial pooling — School A’s posterior estimate gets pulled down toward
the group mean because no other school comes anywhere close to a 28-point
effect, but it isn’t pulled all the way to the group mean because A’s data
do show some signal.
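The strength of that pull can be previewed with the standard precision-weighted shrinkage formula: conditional on fixed μ and τ, the posterior mean of θ_j is a weighted average of y_j and μ, with weights equal to the precisions 1/σ_j² and 1/τ². A pure-Elixir sketch — the μ = 8 and τ = 7 values are illustrative placeholders, not posterior output:

```elixir
# Precision-weighted shrinkage, conditional on fixed hyperparameters.
# mu = 8 and tau = 7 are illustrative placeholders, not posterior output.
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
sigma = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
mu = 8.0
tau = 7.0

shrunk =
  Enum.zip(y, sigma)
  |> Enum.map(fn {yj, sj} ->
    # E[theta_j | mu, tau, y_j] is a precision-weighted average of the
    # school's own estimate y_j and the group mean mu.
    w_data = 1.0 / (sj * sj)
    w_prior = 1.0 / (tau * tau)
    Float.round((w_data * yj + w_prior * mu) / (w_data + w_prior), 1)
  end)

# School A's 28 is pulled most of the way back toward the group mean.
shrunk
```

The full hierarchical model below does this averaging with μ and τ themselves uncertain, which spreads the shrinkage over the whole posterior instead of a single (μ, τ) point.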
This is BDA3’s archetypal example for a reason. It is the simplest model that cannot be done with conjugate algebra alone, that cannot be done with straightforward Gibbs sampling without funnel pathologies, that requires Hamiltonian Monte Carlo (NUTS) to sample reliably, and that exposes every PPL implementation’s strengths and weaknesses. If your sampler can’t do 8-schools, it can’t do anything serious.
## The Model
$$
\begin{aligned}
\mu &\sim \text{Normal}(0, 10) \\
\tau &\sim \text{Half-Normal}(10) \\
\theta_j &\sim \text{Normal}(\mu, \tau) & j &= 1, \dots, 8 \\
y_j &\sim \text{Normal}(\theta_j, \sigma_j) & j &= 1, \dots, 8
\end{aligned}
$$
The hyperpriors on μ and τ are weakly informative — they say “the true
effects are probably within ±20 SAT points and the between-school spread is
probably less than 20 points,” nothing stronger. The σ_j are known
because each school reported them; the model treats them as fixed
constants, not parameters.
The pathology: when τ → 0, all θ_j collapse onto μ, and the posterior
geometry of (μ, τ, θ_1, ..., θ_8) becomes a long, narrow funnel. A
random-walk sampler (Metropolis, Gibbs) gets stuck at the funnel’s neck.
Hamiltonian Monte Carlo with a fixed step size also fails — it overshoots
the curvature. The No-U-Turn Sampler (NUTS) with an adaptively tuned step
size is the standard solution. eXMC implements NUTS by default.
The famous trick on top: non-centered parameterization (NCP), which
rewrites θ_j = μ + τ * z_j where z_j ~ N(0, 1). The funnel disappears
in (μ, τ, z) coordinates because the prior on z is independent of τ.
eXMC applies this rewrite automatically when you pass `ncp: true`.
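The rewrite itself is one line of arithmetic. A minimal sketch with illustrative fixed values (no eXMC API involved), showing that the map z ↦ μ + τ·z recovers θ exactly, which is why sampling (μ, τ, z) and transforming is equivalent to sampling θ directly:

```elixir
# The non-centered rewrite: theta_j = mu + tau * z_j with z_j ~ Normal(0, 1).
# Fixed illustrative values; in the sampler, mu, tau, and z are all draws.
mu = 8.0
tau = 7.0
z = [-1.0, 0.0, 1.0, 2.0]

theta = Enum.map(z, fn zj -> mu + tau * zj end)
# theta == [1.0, 8.0, 15.0, 22.0]
```

The key geometric fact: the prior on z is a fixed standard normal regardless of τ, so the funnel-shaped dependence between τ and θ never appears in the coordinates the sampler sees.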
```elixir
# Data — BDA3 p. 120
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
sigma = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
n_schools = length(y)
school_labels = ["A", "B", "C", "D", "E", "F", "G", "H"]

%{n_schools: n_schools, y_mean: Enum.sum(y) / n_schools}
```
```elixir
# Build the IR
ir = Builder.new_ir()

# Hyperpriors
ir = Builder.rv(ir, "mu", Normal, %{mu: Nx.tensor(0.0), sigma: Nx.tensor(10.0)})
ir = Builder.rv(ir, "tau", HalfNormal, %{sigma: Nx.tensor(10.0)}, transform: :log)

# School-level effects via string parameter references
ir =
  Enum.reduce(0..(n_schools - 1), ir, fn j, ir ->
    Builder.rv(ir, "theta_#{j}", Normal, %{mu: "mu", sigma: "tau"})
  end)

# Observations: y_j ~ Normal(theta_j, sigma_j)
ir =
  Enum.reduce(0..(n_schools - 1), ir, fn j, ir ->
    yj = Enum.at(y, j)
    sj = Enum.at(sigma, j)
    ir = Builder.rv(ir, "y_#{j}", Normal, %{mu: "theta_#{j}", sigma: Nx.tensor(sj)})
    Builder.obs(ir, "y_#{j}_obs", "y_#{j}", Nx.tensor(yj))
  end)

%{nodes: map_size(ir.nodes), free_rvs: 1 + 1 + n_schools}
```
We have 10 free random variables: μ, τ, and θ_0, ..., θ_7. The
posterior we are about to sample is a 10-dimensional density.
## Sampling with NUTS (Non-Centered)
We run 500 warmup iterations (the sampler tunes its step size and mass
matrix during warmup) and then draw 500 posterior samples. NCP is enabled,
so internally the sampler operates on z instead of θ and the funnel
disappears.
```elixir
{trace_ncp, stats_ncp} =
  Sampler.sample(ir, %{},
    num_samples: 500,
    num_warmup: 500,
    seed: 42,
    ncp: true
  )

%{
  divergences: stats_ncp.divergences,
  step_size: Float.round(stats_ncp.step_size, 4),
  message: "✓ NUTS sampling complete"
}
```
A handful of warmup divergences is normal for 8-schools and is a sign that
the sampler explored the funnel’s neck during adaptation. Sampling-phase
divergences (after warmup) above ~5% of total samples would be a warning
sign and an invitation to reparameterize or strengthen the prior on τ.
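That rule of thumb is easy to automate. A small helper sketch — the 5% threshold is the heuristic from the paragraph above, not an eXMC constant, and you would feed it `stats_ncp.divergences` plus the post-warmup draw count:

```elixir
# Divergence-rate check implementing the ~5% rule of thumb.
# divergences: sampling-phase divergence count; num_samples: post-warmup draws.
check_divergences = fn divergences, num_samples ->
  rate = divergences / num_samples

  cond do
    divergences == 0 ->
      {:ok, "no divergences"}

    rate < 0.05 ->
      {:ok, "#{divergences} divergences (#{Float.round(rate * 100, 1)}%) — likely fine"}

    true ->
      {:warn, "#{Float.round(rate * 100, 1)}% divergent — reparameterize or tighten the prior on τ"}
  end
end

check_divergences.(3, 500)
```

A `{:warn, _}` result is the cue to try `ncp: true` (if you haven't) or a tighter prior on τ, per the text above.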
## Reading the Posterior
```elixir
mu_samples = trace_ncp["mu"]
tau_samples = trace_ncp["tau"]

mu_mean = Nx.mean(mu_samples) |> Nx.to_number()
mu_sd = :math.sqrt(Nx.variance(mu_samples) |> Nx.to_number())
tau_mean = Nx.mean(tau_samples) |> Nx.to_number()
tau_sd = :math.sqrt(Nx.variance(tau_samples) |> Nx.to_number())

%{
  mu_mean: Float.round(mu_mean, 2),
  mu_sd: Float.round(mu_sd, 2),
  tau_mean: Float.round(tau_mean, 2),
  tau_sd: Float.round(tau_sd, 2),
  bda3_reference: "μ ≈ 7.7, τ ≈ 6.8 (BDA3 Fig 5.6, varies with prior on τ)"
}
```
The grand mean μ sits around 6–8 SAT points; the between-school spread
τ sits around 4–6 points. The exact values depend mildly on the prior on
τ. With a Half-Normal(10) prior the posterior is gently shrunk; with the
weaker Half-Cauchy that BDA3 uses you would see slightly larger τ.
## Per-School Effects (Partial Pooling in Action)
```elixir
school_summary =
  for j <- 0..(n_schools - 1) do
    theta_j = trace_ncp["theta_#{j}"]
    mean = Nx.mean(theta_j) |> Nx.to_number()
    sd = :math.sqrt(Nx.variance(theta_j) |> Nx.to_number())

    %{
      school: Enum.at(school_labels, j),
      raw_y: Enum.at(y, j),
      raw_sigma: Enum.at(sigma, j),
      posterior_mean: Float.round(mean, 2),
      posterior_sd: Float.round(sd, 2),
      # Relative pull toward the group mean. Ill-behaved when raw_y is near
      # zero (schools E and F), so read this column for the large-|y| schools.
      shrinkage: Float.round((Enum.at(y, j) - mean) / Enum.at(y, j), 3)
    }
  end

school_summary
```
The shrinkage column tells the partial pooling story.
- School A had a raw measurement of 28. After borrowing strength from the other schools, its posterior mean drops to ~9 — a shrinkage of roughly 70%. The other seven schools said “we don’t see effects that big,” and the model believed them.
- School C had a raw measurement of -3. Its posterior is positive, pulled up toward the group mean. The model says “your point estimate is noisy enough that it’s more likely the true effect is small and positive, not small and negative.”
- School G had a raw measurement of 18 with a relatively small standard error (10). Its posterior is the highest of the eight after shrinkage — the data on G are precise enough to resist full pooling.
This is the textbook visualization (BDA3 Fig 5.7) — raw data on the left, shrunk posterior on the right.
```elixir
shrinkage_data =
  Enum.flat_map(school_summary, fn s ->
    [
      %{school: s.school, type: "Raw estimate (y_j ± σ_j)", value: s.raw_y, error: s.raw_sigma},
      %{school: s.school, type: "Posterior mean ± sd", value: s.posterior_mean, error: s.posterior_sd}
    ]
  end)

Vl.new(width: 600, height: 320, title: "Partial pooling: raw vs posterior")
|> Vl.data_from_values(shrinkage_data)
|> Vl.mark(:point, size: 100, filled: true)
|> Vl.encode_field(:x, "school", type: :nominal, title: "School")
|> Vl.encode_field(:y, "value", type: :quantitative, title: "Effect (SAT points)")
|> Vl.encode_field(:color, "type", type: :nominal)
|> Vl.encode_field(:y_error, "error", type: :quantitative)
```
## Comparing Centered vs Non-Centered
Now the experiment that motivates the existence of NCP. Build the same model and sample it without the rewrite. The model is mathematically identical — but the geometry is hostile.
```elixir
{trace_centered, stats_centered} =
  Sampler.sample(ir, %{},
    num_samples: 500,
    num_warmup: 500,
    seed: 42,
    ncp: false
  )

%{
  divergences_centered: stats_centered.divergences,
  divergences_ncp: stats_ncp.divergences,
  step_size_centered: Float.round(stats_centered.step_size, 4),
  step_size_ncp: Float.round(stats_ncp.step_size, 4)
}
```
The centered run usually shows more divergences and a smaller step size.
The sampler is fighting funnel curvature. With more data per school
(narrower likelihoods), centered would actually be the better choice — but
8-schools is the canonical regime where the data are weak relative to the
prior on θ_j, and NCP wins.
This is the practical lesson:
- Strong data, weak prior: centered parameterization is fine.
- Weak data, informative prior: non-centered parameterization is required.
- When in doubt: try both and compare divergence counts and effective sample size.

`Sampler.sample(ir, init, ncp: false)` gives you a one-line A/B test.
## Trace Plots — Diagnosing Convergence
The most useful sampling diagnostic is the trace plot: the value of each sampled parameter over iterations. A healthy chain looks like fuzzy white noise around a stable mean (a “fat caterpillar”). A pathological chain either drifts (the sampler hasn’t converged) or sits stuck (the proposal is being rejected too often).
```elixir
mu_trace =
  Nx.to_list(mu_samples)
  |> Enum.with_index()
  |> Enum.map(fn {v, i} -> %{iteration: i, value: v, parameter: "μ"} end)

tau_trace =
  Nx.to_list(tau_samples)
  |> Enum.with_index()
  |> Enum.map(fn {v, i} -> %{iteration: i, value: v, parameter: "τ"} end)

trace_data = mu_trace ++ tau_trace

Vl.new(width: 600, height: 240, title: "NUTS trace plots (NCP)")
|> Vl.data_from_values(trace_data)
|> Vl.mark(:line, opacity: 0.7)
|> Vl.encode_field(:x, "iteration", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "parameter", type: :nominal)
```
Both μ and τ traces should show fast mixing — the line should fill its
band rapidly without long flat sections. If τ is sticky near zero,
that’s the funnel showing through; rerun with `ncp: true` (or with a more
informative prior on τ).
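“Fast mixing” can be quantified rather than eyeballed: lag-1 autocorrelation near 0 means consecutive draws are nearly independent, while values near 1 mean a sticky chain. A pure-Elixir sketch (the standard biased ACF estimator — lag-1 cross-products over the full squared sum) that operates on a plain list, e.g. `Nx.to_list(tau_samples)`:

```elixir
# Lag-1 autocorrelation: near 0 = fast mixing, near 1 = sticky chain.
lag1_autocorr = fn samples ->
  n = length(samples)
  mean = Enum.sum(samples) / n
  centered = Enum.map(samples, &(&1 - mean))

  # Denominator: full sum of squared deviations.
  var = Enum.reduce(centered, 0.0, fn c, acc -> acc + c * c end)

  # Numerator: sum of products of consecutive deviations.
  cov =
    Enum.zip(centered, tl(centered))
    |> Enum.reduce(0.0, fn {a, b}, acc -> acc + a * b end)

  cov / var
end

# A perfectly alternating toy chain is strongly anti-correlated:
lag1_autocorr.([1.0, -1.0, 1.0, -1.0, 1.0, -1.0])
```

A real NUTS chain on 8-schools should give a small value for μ; a value creeping toward 1 for τ is the numeric signature of the sticky-near-zero behavior described above.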
## Posterior Density of τ
The marginal posterior of the between-school standard deviation τ is the
single most informative plot in this analysis. Its mass tells you whether
the data support complete pooling (τ → 0, all schools the same),
no pooling (τ → ∞, schools independent), or partial pooling (some
finite, positive τ).
```elixir
tau_data = Enum.map(Nx.to_list(tau_samples), fn t -> %{tau: t} end)

Vl.new(width: 600, height: 280, title: "Marginal posterior of τ (between-school sd)")
|> Vl.data_from_values(tau_data)
|> Vl.mark(:bar, color: "#54a24b", opacity: 0.7)
|> Vl.encode_field(:x, "tau",
  type: :quantitative,
  bin: %{maxbins: 35},
  title: "τ"
)
|> Vl.encode_field(:y, "tau", type: :quantitative, aggregate: :count)
```
The posterior puts non-trivial mass both near τ = 0 and around τ ≈ 5. The data
do not rule out complete pooling. Eight schools is just not many — eight
points of evidence about a between-group spread cannot tell you decisively
whether the spread is zero or seven. This is the substantive finding that
BDA3 returns to repeatedly: hierarchical models honestly express what your
data can and cannot say.
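That claim can be made numeric: the fraction of posterior draws below a small threshold estimates how much mass p(τ | y) sits near complete pooling. A sketch — the 1-point threshold and the sample list are illustrative; in the notebook you would pass `Nx.to_list(tau_samples)`:

```elixir
# Fraction of draws below a threshold — posterior mass near complete pooling.
mass_below = fn samples, threshold ->
  Enum.count(samples, &(&1 < threshold)) / length(samples)
end

# Illustrative stand-in for Nx.to_list(tau_samples):
mass_below.([0.3, 0.8, 2.1, 5.0, 6.7, 0.5, 4.2, 9.9], 1.0)
# => 0.375
```

If this fraction is well above zero on the real τ draws, the posterior genuinely keeps complete pooling on the table — the quantitative version of “eight schools is just not many.”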
## What This Tells You
- Partial pooling is the default for grouped data. Reporting per-school estimates over-claims; reporting the average under-claims. The hierarchical model gives you both at once and tells you how to weight them.
- NUTS is not optional for hierarchical models with weak data. Random walk and Gibbs samplers stall in the funnel. eXMC’s NUTS, with NCP enabled, navigates it cleanly — but you should still check divergence counts.
- NCP and centered are different roads to the same posterior. They agree when the sampler converges. They disagree when the sampler is failing — which is exactly when you need to know.
- Eight data points cannot answer “is the true between-group variance exactly zero?” This is not a sampler limitation. It is an information limit. The honest answer is “the posterior puts mass on both possibilities.”
## Study Guide
- Vary the prior on τ. Replace `HalfNormal(10)` with `HalfNormal(2)` and then `HalfNormal(50)`. Re-sample. How does the marginal `p(τ | y)` shift? Which schools’ effects shrink more under the tight prior? This is the explicit version of “the prior matters when data are weak.”
- Compare divergence counts between `ncp: true` and `ncp: false` on the same seed. Run each three times with different seeds — use `seed: 1`, `2`, and `3` instead of `seed: 42`. Is the difference robust?
- Build a no-pooling model by replacing the school-level prior with a fixed weak prior — e.g., `Builder.rv(ir, "theta_#{j}", Normal, %{mu: 0, sigma: 50})` for each school. Sample and compare each `θ_j` posterior to the hierarchical version. Which schools change most?
- Build a complete-pooling model by replacing all `θ_j` with a single shared `theta`. Sample. Compare the posterior of `theta` to `μ` from the hierarchical model. Why are they similar but not identical?
- Compute `P(θ_A > θ_C | y)` — the posterior probability that School A’s true effect exceeds School C’s. Use the joint samples in `trace_ncp` (they are dependent: the draws at iteration `i` of `theta_0` and `theta_2` come from the same posterior sample). Does the answer surprise you given the raw 28 vs -3?
- (Hard.) Increase to `num_samples: 2000` and `num_warmup: 2000`. Compute the effective sample size (ESS) for `μ`, `τ`, and `θ_A` using the formula in BDA3 §11.5. eXMC’s `Exmc.Diagnostics` module has helpers. Which parameter has the worst ESS? Why?
## Literature
- Gelman, Carlin, Stern, Dunson, Vehtari, Rubin. Bayesian Data Analysis, 3rd ed., §5.5 (8-schools, p. 119–124). The original presentation with the centered parameterization, before NCP became standard.
- Neal, R. “Slice sampling” (2003), §8 — introduced the funnel example that motivates NCP.
- Betancourt, M. “A conceptual introduction to Hamiltonian Monte Carlo” (2017) — why NUTS works, with the funnel as a worked case study.
- Hoffman, M. & Gelman, A. “The No-U-Turn Sampler” (2014) — the original NUTS paper. eXMC’s `lib/exmc/nuts/` implementation cites this throughout.
- Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., Bürkner, P.-C. “Rank-normalization, folding, and localization: An improved R̂ for assessing convergence of MCMC” (2021) — the modern way to diagnose what you just ran.
- Original Python demos: `bda-ex-demos/demos_ch5/demo5_1.ipynb` (rats) and `demo5_2.ipynb` (8-schools).
## Where to Go Next
- `notebooks/bda/ch02_beta_binomial.livemd` — the conjugate baseline. NUTS and analytical conjugacy give the same answer when both apply.
- `notebooks/02_hierarchical_model.livemd` — eXMC’s existing minimal hierarchical example, the gentler on-ramp before this notebook.
- `notebooks/09_radon_bhm.livemd` — the same partial-pooling pattern at a larger scale (85 counties of indoor radon measurements).
- `notebooks/12_poker_bayesian.livemd` — hierarchical model with hidden per-player parameters, behavior data, and decision-theoretic output. The next step beyond 8-schools.