Generative Models with Edifice
Setup
Choose one of the two cells below depending on how you started Livebook.
Standalone (default)
Use this if you started Livebook normally (livebook server).
Uncomment the EXLA lines for GPU acceleration.
# Prefer a local Edifice checkout at ~/edifice if one exists; otherwise use Hex
edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end
Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:kino, "~> 0.14"}
])
# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
Attached to project (recommended for Nix/CUDA)
Use this if you started Livebook via ./scripts/livebook.sh.
See the Architecture Zoo notebook for full setup instructions.
Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")
Introduction
Generative models learn to produce new data that resembles the training distribution. Unlike classifiers that map inputs to labels, generative models learn the underlying data distribution and can sample from it.
This notebook trains a Variational Autoencoder (VAE) on 2D point clouds. A VAE works by learning two things simultaneously:
- An encoder that compresses each data point into a compact “latent code”
- A decoder that reconstructs data from latent codes
The key insight of a VAE (versus a plain autoencoder) is that the encoder doesn’t output a single point — it outputs a probability distribution (parameterized by mu and log_var). A special loss term called KL divergence pushes these distributions toward a known prior (standard normal). This means we can generate new data by sampling random points from that prior and decoding them.
What you’ll learn:
- How VAEs encode data into a latent distribution (mu + log_var)
- Training with reconstruction loss + KL divergence
- How the beta parameter controls the reconstruction vs regularization trade-off
- Visualizing the learned latent space
- Generating new samples by decoding random latent points
Generate 2D Data
We create a dataset of 2D points arranged in two crescent moon shapes — one arching up, one arching down, offset so they interlock. This is a classic test for generative models because:
- The data has non-trivial structure (curved, multi-modal)
- It’s 2D, so we can visualize everything directly
- A good generative model should produce new points that follow the crescent shapes, not just fill in a blob
IO.puts("Generating crescent moon dataset...")
n_points = 1500
key = Nx.Random.key(42)
# Upper crescent: half-circle arc from 0 to pi, with Gaussian noise
{angles_upper, key} = Nx.Random.uniform(key, shape: {n_points})
angles_upper = Nx.multiply(angles_upper, :math.pi())
{noise_u, key} = Nx.Random.normal(key, shape: {n_points, 2})
upper_x = Nx.add(Nx.cos(angles_upper), Nx.multiply(noise_u[[.., 0]], 0.1))
upper_y = Nx.add(Nx.sin(angles_upper), Nx.multiply(noise_u[[.., 1]], 0.1))
upper = Nx.stack([upper_x, upper_y], axis: 1)
# Lower crescent: same arc mirrored as (1 - cos, 0.5 - sin), i.e. shifted right,
# flipped downward and raised so its tips reach into the upper arc (classic two moons)
{angles_lower, key} = Nx.Random.uniform(key, shape: {n_points})
angles_lower = Nx.multiply(angles_lower, :math.pi())
{noise_l, _key} = Nx.Random.normal(key, shape: {n_points, 2})
lower_x = Nx.add(Nx.subtract(1.0, Nx.cos(angles_lower)), Nx.multiply(noise_l[[.., 0]], 0.1))
lower_y = Nx.add(Nx.subtract(0.5, Nx.sin(angles_lower)), Nx.multiply(noise_l[[.., 1]], 0.1))
lower = Nx.stack([lower_x, lower_y], axis: 1)
data = Nx.concatenate([upper, lower])
labels = Nx.concatenate([Nx.broadcast(0, {n_points}), Nx.broadcast(1, {n_points})])
# Shuffle so the two classes are interleaved
n_total = 2 * n_points
{shuffle_noise, _key} = Nx.Random.uniform(Nx.Random.key(99), shape: {n_total})
shuffle_idx = Nx.argsort(shuffle_noise)
data = Nx.take(data, shuffle_idx)
labels = Nx.take(labels, shuffle_idx)
# Batch for training — each batch is {input, target} where input == target
# (the VAE learns to reconstruct its own input)
batch_size = 64
IO.puts(" Batching #{n_total} points into batches of #{batch_size}...")
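# n_total need not be a multiple of batch_size: Nx.to_batched's default
# (leftover: :repeat) fills the last batch by wrapping around to the start
train_batches =
  data
  |> Nx.to_batched(batch_size)
  |> Enum.map(fn batch -> {batch, batch} end)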
IO.puts("Ready: #{n_total} points, #{length(train_batches)} batches/epoch")
chart_data =
  Enum.zip_with(
    [Nx.to_flat_list(data[[.., 0]]), Nx.to_flat_list(data[[.., 1]]), Nx.to_flat_list(labels)],
    fn [x, y, l] -> %{"x" => x, "y" => y, "class" => if(trunc(l) == 0, do: "upper", else: "lower")} end
  )
Vl.new(width: 500, height: 350, title: "Training Data: Two Crescents")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:circle, size: 15, opacity: 0.5)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "class", type: :nominal)
Build the VAE
A VAE has three logical parts:
- Encoder: Takes input data and produces two vectors, mu (the mean) and log_var (the log-variance). Together these define a Gaussian distribution in latent space for each input point. Instead of compressing to a single point, the encoder says “this input probably came from somewhere around here”; that uncertainty is what enables generation.
- Latent sampling: In a full VAE, you’d sample z = mu + exp(0.5 * log_var) * epsilon, where epsilon ~ N(0, I). This “reparameterization trick” moves the randomness into epsilon, so z stays a differentiable function of mu and log_var. Here we feed mu directly to the decoder for simplicity. The KL loss is still computed, but because log_var never influences the reconstruction, in this setup KL acts mainly as a penalty pulling mu toward the origin, making the model closer to a regularized autoencoder than a full VAE.
- Decoder: Takes a latent code and reconstructs the original input. During generation, we skip the encoder entirely and feed random latent points directly to the decoder (for a well-trained VAE, samples from N(0, I)).
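The reparameterization step described in the list above can be sketched in plain Elixir on lists (no Nx, so it runs in any IEx session). This is a minimal illustration under that assumption, not how Edifice or Axon implement it; the module name `ReparamSketch` and its functions are hypothetical. Passing `epsilon` in explicitly shows that, once the noise is fixed, z is an ordinary deterministic function of mu and log_var.

```elixir
defmodule ReparamSketch do
  # Hypothetical helper (not part of Edifice or Axon): the reparameterization
  # trick z = mu + exp(0.5 * log_var) * epsilon on plain lists, one entry per
  # latent dimension.

  # Deterministic core: with epsilon fixed, z is an ordinary function of
  # mu and log_var, so an autodiff framework can differentiate through it.
  def reparameterize(mu, log_var, epsilon) do
    Enum.zip_with([mu, log_var, epsilon], fn [m, lv, e] ->
      m + :math.exp(0.5 * lv) * e
    end)
  end

  # Sampling wrapper: draw epsilon ~ N(0, 1) independently for each dimension.
  def sample(mu, log_var) do
    epsilon = Enum.map(mu, fn _ -> :rand.normal() end)
    reparameterize(mu, log_var, epsilon)
  end
end

# log_var = 0 means std = exp(0) = 1, so z = mu + epsilon
ReparamSketch.reparameterize([1.0, 2.0], [0.0, 0.0], [0.5, -1.0])
# => [1.5, 1.0]

# A random draw around mu = [1.0, 2.0] with variance exp(-2.0) per dimension
ReparamSketch.sample([1.0, 2.0], [-2.0, -2.0])
```

In an Nx/Axon version, the same expression would operate on tensors, with the noise drawn by something like `Nx.Random.normal` in the same shape as mu.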
We use a 2D latent space (latent_size=2) so we can plot it directly. With higher-dimensional latent spaces you’d need PCA or t-SNE to visualize.
The model uses Axon.container to output a map with all three components
(:reconstruction, :mu, :log_var) so the loss function can access each one.
IO.puts("Building VAE (2D latent space)...")
latent_size = 2
input_size = 2
# --- Encoder ---
# Two dense layers (64, then 32 units) map the 2D input to a 32-dim hidden
# representation; two separate linear layers then project it to mu and log_var.
input = Axon.input("input", shape: {nil, input_size})
enc =
  input
  |> Axon.dense(64, name: "enc_0")
  |> Axon.activation(:relu)
  |> Axon.dense(32, name: "enc_1")
  |> Axon.activation(:relu)
# mu: the "center" of where this input maps in latent space
mu = Axon.dense(enc, latent_size, name: "mu")
# log_var: how "spread out" the encoding is. Predicting the log of the variance
# lets the layer output any real number, while exp(log_var) is always positive.
log_var = Axon.dense(enc, latent_size, name: "log_var")
# --- Decoder ---
# Mirror of the encoder: takes a latent code (here mu itself, no sampling)
# and maps it back to 2D
recon =
  mu
  |> Axon.dense(32, name: "dec_0")
  |> Axon.activation(:relu)
  |> Axon.dense(64, name: "dec_1")
  |> Axon.activation(:relu)
  |> Axon.dense(input_size, name: "dec_out")
# --- Combined model ---
# Axon.container bundles multiple outputs into a map so the loss function
# can access reconstruction, mu, and log_var separately
vae_model = Axon.container(%{reconstruction: recon, mu: mu, log_var: log_var})
# --- Encoder-only model (for visualization later) ---
# This is a separate Axon graph, but it uses the SAME layer names as the
# encoder above. When we pass the trained weights to this model, Axon
# matches by name, so it automatically uses the trained encoder weights.
encoder_model =
  Axon.input("input", shape: {nil, input_size})
  |> Axon.dense(64, name: "enc_0")
  |> Axon.activation(:relu)
  |> Axon.dense(32, name: "enc_1")
  |> Axon.activation(:relu)
  |> Axon.dense(latent_size, name: "mu")
IO.puts(" Encoder: {batch, 2} -> mu {batch, 2} + log_var {batch, 2}")
IO.puts(" Decoder: mu {batch, 2} -> reconstruction {batch, 2}")
IO.puts(" Output container: :reconstruction, :mu, :log_var")
Train the VAE
The VAE loss has two parts that pull in opposite directions:
- Reconstruction loss (MSE): “Make the output match the input.” This encourages the encoder to preserve as much information as possible and the decoder to reconstruct it faithfully.
- KL divergence: “Keep the latent distribution close to N(0, I).” This encourages the encoder to place its codes near the origin with unit variance, rather than memorizing each input as a far-flung point.
Without KL, you’d just have a plain autoencoder: great reconstructions, but useless for generation because the latent space has no structure. At the other extreme, if the KL term dominates, the encoder maps every input to N(0, I), the latent code carries no information about the input, and the best the decoder can do is output the mean of the dataset.
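To make the KL term concrete, here is the closed-form KL divergence between a diagonal Gaussian N(mu, exp(log_var)) and the N(0, I) prior, written in plain Elixir on lists. `KLSketch` is a hypothetical helper, not an Edifice or Axon function, and it uses the algebraically equivalent form 0.5 * sum(mu^2 + exp(log_var) - 1 - log_var). The two examples show that KL is zero when the encoder's distribution already equals the prior, and that shifting the mean by one unit costs mu^2 / 2.

```elixir
defmodule KLSketch do
  # Hypothetical helper: KL(N(mu, exp(log_var)) || N(0, I)) for one data point,
  # with mu and log_var given as plain lists (one entry per latent dimension).
  # Equivalent to -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).
  def kl_to_standard_normal(mu, log_var) do
    terms = Enum.zip_with([mu, log_var], fn [m, lv] -> m * m + :math.exp(lv) - 1.0 - lv end)
    0.5 * Enum.sum(terms)
  end
end

# Encoder output equals the prior: KL is zero
KLSketch.kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# => 0.0

# Mean shifted by 1 in one dimension, unit variance: KL = 1^2 / 2
KLSketch.kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])
# => 0.5
```

Any variance other than 1 also costs something: both log_var = 2.0 and log_var = -2.0 give a positive KL, which is what nudges each encoding toward unit variance.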
The beta parameter controls this trade-off:
- beta = 1.0 is the standard VAE (ELBO objective)
- beta < 1.0 prioritizes reconstruction: the latent space spreads out more, giving the decoder room to learn fine details
- beta > 1.0 prioritizes regularization: smoother latent space but blurrier reconstructions (beta-VAE)
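The role of beta, and the averaging conventions used by the training cell's `vae_loss` (MSE averaged over every element, KL summed over latent dimensions and then averaged over the batch), can be checked by hand with a small plain-list mirror of the objective. The module name `VAELossSketch` and the toy numbers are hypothetical, chosen only so the arithmetic is easy to follow; the notebook itself computes the same quantities with Nx tensors.

```elixir
defmodule VAELossSketch do
  # Hypothetical plain-list mirror of a beta-weighted VAE loss.
  # Every argument is a batch: a list of rows, one row per data point.

  # Reconstruction term: mean squared error over every element of the batch.
  def recon_mse(targets, recons) do
    sq_errs =
      Enum.zip_with([List.flatten(targets), List.flatten(recons)], fn [t, r] ->
        (r - t) * (r - t)
      end)

    Enum.sum(sq_errs) / length(sq_errs)
  end

  # KL term: closed-form KL to N(0, I), summed over latent dims, averaged over the batch.
  def mean_kl(mus, log_vars) do
    per_sample =
      Enum.zip_with([mus, log_vars], fn [mu, lv] ->
        terms = Enum.zip_with([mu, lv], fn [m, l] -> m * m + :math.exp(l) - 1.0 - l end)
        0.5 * Enum.sum(terms)
      end)

    Enum.sum(per_sample) / length(per_sample)
  end

  # Total loss: reconstruction + beta * KL
  def loss(targets, recons, mus, log_vars, beta) do
    recon_mse(targets, recons) + beta * mean_kl(mus, log_vars)
  end
end

# Toy batch of two points: one reconstruction error of 1.0, and one latent
# code shifted by 1 from the prior (so recon = 0.25 and KL = 0.25)
targets = [[1.0, 0.0], [0.0, 1.0]]
recons = [[1.0, 0.0], [0.0, 0.0]]
mus = [[1.0, 0.0], [0.0, 0.0]]
log_vars = [[0.0, 0.0], [0.0, 0.0]]

for beta <- [0.5, 1.0] do
  {beta, VAELossSketch.loss(targets, recons, mus, log_vars, beta)}
end
# => [{0.5, 0.375}, {1.0, 0.5}]
```

Changing beta leaves the reconstruction term untouched and only scales how much the same KL value costs, which is why a large beta pushes the optimizer toward codes near the prior at the expense of reconstruction.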
We use beta = 0.5 — strong enough to regularize the latent space for
generation, but not so strong that the KL term dominates before the decoder
learns. If generated samples look too spread out, try more epochs or lower
beta. If the latent space collapses to a single point, lower beta.
# Beta controls the KL vs reconstruction trade-off
# Try 0.01 (almost autoencoder) to 1.0 (full VAE) and see how it changes!
beta = 0.5
epochs = 30
# Custom VAE loss: MSE reconstruction + beta * KL divergence
# IMPORTANT: Axon losses use (y_true, y_pred) argument order, targets first!
vae_loss = fn y_true, y_pred ->
  # Reconstruction: how close is the decoder output to the original input?
  recon_loss = Nx.mean(Nx.pow(Nx.subtract(y_pred.reconstruction, y_true), 2))
  # KL divergence: how far is q(z|x) from the prior p(z) = N(0, I)?
  # Formula: KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
  # When mu = 0 and log_var = 0 (i.e. variance = 1), KL = 0 (matches the prior).
  # Summed over the latent dimensions, then averaged over the batch.
  kl =
    Nx.subtract(
      Nx.add(1.0, y_pred.log_var),
      Nx.add(Nx.pow(y_pred.mu, 2), Nx.exp(y_pred.log_var))
    )
    |> Nx.sum(axes: [-1])
    |> Nx.mean()
    |> Nx.multiply(-0.5)
  Nx.add(recon_loss, Nx.multiply(beta, kl))
end
IO.puts("Training VAE (beta=#{beta}, #{epochs} epochs)...")
trained_state =
  vae_model
  |> Axon.Loop.trainer(vae_loss, Polaris.Optimizers.adam(learning_rate: 3.0e-3))
  |> Axon.Loop.run(train_batches, Axon.ModelState.empty(), epochs: epochs)
IO.puts("Training complete!")
Visualize the Latent Space
Now let’s see what the encoder learned. We pass every data point through the encoder and plot its mu vector (the “center” of its latent distribution).
What to look for:
- The two crescent classes should form distinct clusters: the encoder places upper and lower crescent points in different regions even though the class labels were never used in training
- The points should be spread across the latent space, not collapsed into a tiny dot (which would mean KL dominated too much)
- With beta = 0.5, you should see good separation with reasonable spread
IO.puts("Encoding all data points into latent space...")
# Use the encoder-only model with the trained weights.
# Even though we trained vae_model (which has encoder + decoder layers),
# encoder_model uses the same layer names ("enc_0", "enc_1", "mu"),
# so Axon automatically loads the matching weights.
{_enc_init, enc_predict} = Axon.build(encoder_model)
mu_out = enc_predict.(trained_state, data)
latent_data =
  Enum.zip_with(
    [
      Nx.to_flat_list(mu_out[[.., 0]]),
      Nx.to_flat_list(mu_out[[.., 1]]),
      Nx.to_flat_list(labels)
    ],
    fn [z1, z2, l] ->
      %{"z1" => z1, "z2" => z2, "class" => if(trunc(l) == 0, do: "upper", else: "lower")}
    end
  )
Vl.new(width: 500, height: 350, title: "Latent Space (encoder mu)")
|> Vl.data_from_values(latent_data)
|> Vl.mark(:circle, size: 15, opacity: 0.5)
|> Vl.encode_field(:x, "z1", type: :quantitative, title: "z1")
|> Vl.encode_field(:y, "z2", type: :quantitative, title: "z2")
|> Vl.encode_field(:color, "class", type: :nominal)
Generate New Samples
This is the payoff of training a VAE: generating new data.
The idea is simple: during training, KL divergence pushed the encoder’s latent distributions toward N(0, I). So the decoder has learned to turn points near the origin into crescent-shaped data. To generate, we:
- Sample random points from the latent space
- Feed them through the decoder
- Get new 2D points that should look like the training data
In a perfectly trained VAE (beta = 1.0, many epochs), you’d sample from pure N(0, I). In practice, especially with lower beta, fewer epochs, or (as here) a decoder trained on mu directly rather than on sampled codes, the encoder’s latent distribution doesn’t match N(0, I). So we sample from the encoder’s actual distribution: we fit a per-dimension mean and standard deviation to all encoded mu vectors and draw from that Gaussian. This gives better results and shows what the encoder actually learned.
We build a standalone decoder — a separate Axon graph that uses the same layer names (“dec_0”, “dec_1”, “dec_out”) as the decoder inside the VAE. This lets us reuse the trained weights without needing the encoder.
IO.puts("Building standalone decoder for generation...")
# Standalone decoder: same layer names as the decoder in vae_model
latent_input = Axon.input("latent", shape: {nil, latent_size})
standalone_decoder =
  latent_input
  |> Axon.dense(32, name: "dec_0")
  |> Axon.activation(:relu)
  |> Axon.dense(64, name: "dec_1")
  |> Axon.activation(:relu)
  |> Axon.dense(input_size, name: "dec_out")
{_dec_init, dec_predict_fn} = Axon.build(standalone_decoder)
# Sample from the encoder's actual latent distribution
# (for a fully-converged VAE with beta=1.0, this would be close to N(0, I))
IO.puts("Generating new samples from latent space...")
n_samples = 500
latent_mean = Nx.mean(mu_out, axes: [0])
latent_std = Nx.standard_deviation(mu_out, axes: [0])
{z_noise, _k} = Nx.Random.normal(Nx.Random.key(123), shape: {n_samples, latent_size})
z_samples = Nx.add(latent_mean, Nx.multiply(latent_std, z_noise))
IO.puts(" Latent mean: #{inspect(Nx.to_flat_list(latent_mean))}")
IO.puts(" Latent std: #{inspect(Nx.to_flat_list(latent_std))}")
generated = dec_predict_fn.(trained_state, %{"latent" => z_samples})
gen_data =
  Enum.zip_with(
    [Nx.to_flat_list(generated[[.., 0]]), Nx.to_flat_list(generated[[.., 1]])],
    fn [x, y] -> %{"x" => x, "y" => y, "source" => "generated"} end
  )
real_data =
  Enum.zip_with(
    [Nx.to_flat_list(data[[.., 0]]), Nx.to_flat_list(data[[.., 1]])],
    fn [x, y] -> %{"x" => x, "y" => y, "source" => "real"} end
  )
Vl.new(width: 500, height: 350, title: "Real vs Generated Samples")
|> Vl.data_from_values(real_data ++ gen_data)
|> Vl.mark(:circle, size: 15, opacity: 0.4)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "source", type: :nominal)
Decode a Grid of Latent Points
This visualization shows what the decoder has learned as a function. We create a uniform grid of points across the encoder’s actual latent range (not a fixed [-3, 3]) and decode each one, then plot where it lands in data space.
What to look for:
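- Most of the grid should land on or near the crescents, since we’re sampling from the region the encoder actually uses; grid points far from any encoded point (such as the corners) may decode to less crescent-like locations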
- The color gradient shows how the decoder maps latent space to data space — a smooth gradient means nearby latent points produce nearby data points (the decoder learned a smooth, continuous mapping)
- With beta < 1.0, the encoder may use a small region of latent space, so the grid is automatically scaled to that region using the encoder’s computed mean and std
IO.puts("Decoding latent space grid...")
# Use the encoder's actual range (3 std devs around the mean)
# instead of a fixed [-3, 3] which may be far outside the learned region
grid_res = 20
grid_half = 3.0 # number of std devs to cover
[lo_z1, lo_z2] = Nx.to_flat_list(Nx.subtract(latent_mean, Nx.multiply(latent_std, grid_half)))
[hi_z1, hi_z2] = Nx.to_flat_list(Nx.add(latent_mean, Nx.multiply(latent_std, grid_half)))
# grid_res evenly spaced values from lo to hi (both ends included)
linspace = fn lo, hi ->
  Enum.map(0..(grid_res - 1), fn i -> lo + (hi - lo) * i / (grid_res - 1) end)
end
grid_range_z1 = linspace.(lo_z1, hi_z1)
grid_range_z2 = linspace.(lo_z2, hi_z2)
# Cartesian product: grid_res * grid_res latent points
grid_z =
  for z1 <- grid_range_z1, z2 <- grid_range_z2 do
    [z1, z2]
  end
grid_tensor = Nx.tensor(grid_z)
decoded_grid = dec_predict_fn.(trained_state, %{"latent" => grid_tensor})
grid_data =
  Enum.zip_with(
    [
      Nx.to_flat_list(decoded_grid[[.., 0]]),
      Nx.to_flat_list(decoded_grid[[.., 1]]),
      Enum.map(grid_z, fn [z1, _] -> z1 end),
      Enum.map(grid_z, fn [_, z2] -> z2 end)
    ],
    fn [x, y, z1, z2] ->
      %{"x" => x, "y" => y, "z1" => Float.round(z1, 1), "z2" => Float.round(z2, 1)}
    end
  )
Vl.new(width: 500, height: 350, title: "Decoded Latent Grid (color = z1 position)")
|> Vl.data_from_values(grid_data)
|> Vl.mark(:circle, size: 30, opacity: 0.6)
|> Vl.encode_field(:x, "x", type: :quantitative, title: "Decoded x")
|> Vl.encode_field(:y, "y", type: :quantitative, title: "Decoded y")
|> Vl.encode_field(:color, "z1", type: :quantitative, scale: %{scheme: "viridis"}, title: "z1")
Key Takeaways
- Encode to a distribution, not a point: The encoder outputs mu and log_var, a diagonal Gaussian distribution per input. This is the core difference from a plain autoencoder and what enables generation.
- KL divergence is the regularizer: Without KL, the encoder memorizes each input as an isolated point in latent space (useless for generation). KL pushes the latent distributions toward N(0, I), ensuring the decoder learns to handle the region around the origin, which is exactly where a fully trained VAE samples from during generation.
- Beta controls the trade-off: With beta = 1.0 (standard VAE), KL can dominate early in training, causing “posterior collapse” where everything maps to the origin. Lower beta (like 0.1) lets the decoder learn first. Higher beta gives a smoother latent space at the cost of reconstruction quality (beta-VAE).
- 2D latent space = direct visualization: With latent_size = 2, we can literally plot where each data point lands in latent space and see the decoder’s mapping from latent to data space. Larger-scale VAEs use much bigger latent spaces (64, 128, 256+) and need PCA or t-SNE to visualize them.
What’s Next?
- Experiment with beta: Try beta = 0.01 (nearly an autoencoder: sharp reconstructions but poor generation) vs beta = 1.0 (full VAE: smooth latent space but may need more epochs) to see the trade-off firsthand.
- More epochs: Increase epochs beyond the default 30 (for example to 50) for cleaner crescent generation. With more training, even beta = 1.0 produces good results.
- Try higher-dimensional latent spaces: latent_size = 8 or 16 for more capacity; you’ll need dimensionality reduction (PCA/t-SNE) to visualize the latent space.
- Normalizing Flow: Edifice.build(:normalizing_flow, ...) gives the exact log-likelihood instead of the ELBO, which is only a lower bound on it.
- VQ-VAE: Discrete latent codes via Edifice.build(:vq_vae, ...), which often produce sharper results.
- Diffusion models: See the Diffusion notebook for step-by-step denoising with Edifice.build(:dit, ...).