
Edifice: A Guided Tour

notebooks/demo_for_dad.livemd

Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server). Uncomment the EXLA lines for GPU acceleration.

edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:kino, "~> 0.14"}
])

# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh. See the Architecture Zoo notebook for full setup instructions.

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")

1. What Is Edifice?

Edifice is an Elixir library that provides 100+ neural network architectures through a single, consistent API. Think of it as a toolbox where each tool is a different way to learn patterns from data — and you can swap tools with one line of code.

If you’ve worked with spreadsheets, regression models, or statistical software, you already understand the core idea: you have data, you want to find patterns in it, and you want to use those patterns to make predictions.

Neural networks do exactly that. They’re fancy curve-fitting machines. A logistic regression finds a straight line (or plane) that separates groups. A neural network finds curved, flexible boundaries — and the more layers you add, the more complex shapes it can learn.

What you’ll see in this notebook:

  • Building and training models with just a few lines of code
  • Classification, tabular data analysis, time series forecasting
  • How different architectures suit different problems (like choosing the right statistical test for your data type)
  • Models that quantify their own uncertainty — because a prediction without a confidence interval isn’t very useful
  • Generating synthetic data — useful for stress testing and simulation
  • Visualizations at every step so you can see what’s happening

The bottom line: Edifice makes it easy to experiment with many different approaches to the same problem. Instead of writing different code for each architecture, you change one word and the rest stays the same.

2. The Simplest Example — Classification

Let’s start with something familiar: putting data points into groups.

Imagine you have a scatter plot of insurance applicants. The X axis is age, the Y axis is number of prior claims. You can see two clusters — low-risk and high-risk. A logistic regression draws a straight line between them. A neural network can draw a curved line if the boundary isn’t straight.

We’ll generate simple 2D data (two groups of points), train a small neural network, and watch it learn to classify.

# Generate two clusters of points — like plotting applicants on a risk chart
key = Nx.Random.key(42)
n_per_class = 400

IO.puts("Generating #{n_per_class * 2} data points in two groups...")

# Group A: centered at (-1.2, -0.8) — think "low-risk applicants"
{noise_a, key} = Nx.Random.normal(key, shape: {n_per_class, 2})
group_a = Nx.add(Nx.multiply(noise_a, 0.7), Nx.tensor([-1.2, -0.8]))

# Group B: centered at (1.2, 0.8) — think "high-risk applicants"
{noise_b, key} = Nx.Random.normal(key, shape: {n_per_class, 2})
group_b = Nx.add(Nx.multiply(noise_b, 0.7), Nx.tensor([1.2, 0.8]))

# Combine and shuffle
x_all = Nx.concatenate([group_a, group_b])
y_raw = Nx.concatenate([Nx.broadcast(0, {n_per_class}), Nx.broadcast(1, {n_per_class})])

{shuffle_noise, _key} = Nx.Random.uniform(Nx.Random.key(99), shape: {n_per_class * 2})
shuffle_idx = Nx.argsort(shuffle_noise)
x_all = Nx.take(x_all, shuffle_idx)
y_raw = Nx.take(y_raw, shuffle_idx)

# One-hot encode: [0] -> [1, 0], [1] -> [0, 1]
y_onehot = Nx.equal(Nx.new_axis(y_raw, 1), Nx.tensor([[0, 1]])) |> Nx.as_type(:f32)

# 80/20 train/test split
n_train = round(n_per_class * 2 * 0.8)

train_x = x_all[0..(n_train - 1)]
train_y = y_onehot[0..(n_train - 1)]
test_x = x_all[n_train..-1//1]
test_y = y_onehot[n_train..-1//1]
test_labels = y_raw[n_train..-1//1]

# Batch the training data (the model sees 32 examples at a time)
batch_size = 32

train_data =
  Enum.zip(
    Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
    Nx.to_batched(train_y, batch_size) |> Enum.to_list()
  )

IO.puts("Ready: #{n_train} training samples, #{Nx.axis_size(test_x, 0)} test samples")

Let’s see what the data looks like before we train anything:

scatter_data =
  Enum.zip_with(
    [Nx.to_flat_list(x_all[[.., 0]]), Nx.to_flat_list(x_all[[.., 1]]), Nx.to_flat_list(y_raw)],
    fn [x, y, label] ->
      %{"x" => x, "y" => y, "group" => if(trunc(label) == 0, do: "Low Risk", else: "High Risk")}
    end
  )

Vl.new(width: 500, height: 400, title: "Raw Data — Two Groups of Applicants")
|> Vl.data_from_values(scatter_data)
|> Vl.mark(:circle, size: 40, opacity: 0.6)
|> Vl.encode_field(:x, "x", type: :quantitative, title: "Feature 1 (e.g. Age)")
|> Vl.encode_field(:y, "y", type: :quantitative, title: "Feature 2 (e.g. Claims)")
|> Vl.encode_field(:color, "group", type: :nominal)

Now let’s build and train a model. Edifice creates the backbone (the pattern-learning layers), and we add a head (the final prediction layer):

# Build a small MLP (multi-layer perceptron) — the simplest neural network
# It's like stacking several regression steps on top of each other
backbone = Edifice.build(:mlp, input_size: 2, hidden_sizes: [32, 16], activation: :relu, dropout: 0.0)

# Add a 2-class classifier on top
model =
  backbone
  |> Axon.dense(2, name: "output")
  |> Axon.activation(:softmax)

IO.puts("Training a simple classifier (10 epochs)...")

# Train: the model adjusts its internal numbers to reduce prediction errors
trained_state =
  model
  |> Axon.Loop.trainer(
    :categorical_cross_entropy,
    Polaris.Optimizers.adam(learning_rate: 1.0e-2)
  )
  |> Axon.Loop.metric(:accuracy)
  # 10 epochs is enough to see convergence; increase to 30+ for tighter fit
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)

IO.puts("Training complete!")

Let’s check accuracy and visualize what the model learned — the decision boundary (the line it draws between the two groups):

What to look for: The colored background shows what the model would predict at every point in the space. The dots are real test data. If the model learned well, the background colors should match the dot colors. A few mismatches near the border are normal — that’s where the groups overlap.

# Evaluate on test data
{_init_fn, predict_fn} = Axon.build(model)
test_preds = predict_fn.(trained_state, test_x)
predicted_classes = Nx.argmax(test_preds, axis: 1)
true_classes = Nx.argmax(test_y, axis: 1)

accuracy =
  Nx.equal(predicted_classes, true_classes)
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("Test accuracy: #{Float.round(accuracy * 100, 1)}%")

# Create a grid to visualize the decision boundary
resolution = 60

grid_points =
  for gx <- 0..(resolution - 1), gy <- 0..(resolution - 1) do
    [-4.0 + 8.0 * gx / (resolution - 1), -4.0 + 8.0 * gy / (resolution - 1)]
  end

grid_tensor = Nx.tensor(grid_points)
grid_preds = predict_fn.(trained_state, grid_tensor)
grid_classes = Nx.argmax(grid_preds, axis: 1) |> Nx.to_flat_list()

grid_data =
  Enum.zip_with([grid_points, grid_classes], fn [[x, y], class] ->
    %{"x" => x, "y" => y, "group" => if(trunc(class) == 0, do: "Low Risk", else: "High Risk")}
  end)

test_chart_data =
  Enum.zip_with(
    [Nx.to_flat_list(test_x[[.., 0]]), Nx.to_flat_list(test_x[[.., 1]]), Nx.to_flat_list(test_labels)],
    fn [x, y, l] ->
      %{"x" => x, "y" => y, "group" => if(trunc(l) == 0, do: "Low Risk", else: "High Risk")}
    end
  )

Vl.new(width: 500, height: 400, title: "Decision Boundary (#{Float.round(accuracy * 100, 1)}% accuracy)")
|> Vl.layers([
  Vl.new()
  |> Vl.data_from_values(grid_data)
  |> Vl.mark(:square, size: 30, opacity: 0.25)
  |> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
  |> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
  |> Vl.encode_field(:color, "group", type: :nominal),
  Vl.new()
  |> Vl.data_from_values(test_chart_data)
  |> Vl.mark(:circle, size: 50, stroke: "black", stroke_width: 1)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Vl.encode_field(:color, "group", type: :nominal)
])

Key takeaway: This is logistic regression’s more flexible cousin. A logistic regression would draw a straight line; the neural network can draw a curved one. With two well-separated clusters like this, both approaches work fine — the neural network really shines when the boundary between groups is complex.
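To make the comparison concrete, here is a sketch of that logistic-regression cousin built in the same pipeline: a single dense layer plus softmax, no Edifice backbone. It reuses `train_data` and the trainer settings from the cells above; the `logreg` names are ours, purely for illustration.

```elixir
# A logistic-regression baseline: one linear map into two classes,
# then softmax. No hidden layers means it can only draw a straight
# decision boundary.
logreg =
  Axon.input("input", shape: {nil, 2})
  |> Axon.dense(2, name: "logreg_output")
  |> Axon.activation(:softmax)

logreg_state =
  logreg
  |> Axon.Loop.trainer(
    :categorical_cross_entropy,
    Polaris.Optimizers.adam(learning_rate: 1.0e-2)
  )
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)
```

Evaluating it the same way as the MLP (via `Axon.build(logreg)`) should give a similar accuracy here — on well-separated clusters, the straight line is enough.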

3. “Can It Handle a Spreadsheet?” — Tabular Data

Real-world data usually comes in tables: rows of records, columns of features. Think of an insurance database — each row is a policyholder, each column is something you know about them (age, years as customer, number of claims, coverage amount).

TabNet is an architecture designed specifically for this kind of data. Its key innovation: instead of treating every column equally (like a plain neural network), TabNet learns to focus on the columns that matter most for each prediction. One applicant’s risk might depend mainly on their claim history, while another’s depends on their coverage amount.

This is particularly valuable in regulated industries like insurance: if you can show which features drove each prediction, you can explain the model’s reasoning to regulators, auditors, or customers.

# Generate synthetic insurance-like data
# 5 features: age, years_as_customer, num_prior_claims, coverage_amount, region_code
key = Nx.Random.key(123)
n_samples = 1200

IO.puts("Generating synthetic insurance data (#{n_samples} policyholders)...")

# Feature 1: Age (normalized to 0-1 range, representing 18-80)
{age, key} = Nx.Random.uniform(key, shape: {n_samples, 1})

# Feature 2: Years as customer (correlated with age)
{noise, key} = Nx.Random.uniform(key, shape: {n_samples, 1})
tenure = Nx.min(Nx.multiply(age, 0.6) |> Nx.add(Nx.multiply(noise, 0.4)), 1.0)

# Feature 3: Prior claims count (0-1 scale)
{claims, key} = Nx.Random.uniform(key, shape: {n_samples, 1})

# Feature 4: Coverage amount (0-1 scale)
{coverage, key} = Nx.Random.uniform(key, shape: {n_samples, 1})

# Feature 5: Region code (categorical encoded as 0-1)
{region, key} = Nx.Random.uniform(key, shape: {n_samples, 1})

# Stack into a feature matrix: {n_samples, 5}
features = Nx.concatenate([age, tenure, claims, coverage, region], axis: 1)

# Risk classification rule (the "true" formula — the model must discover this):
# High risk if: many claims AND high coverage (the two most important features)
# Medium risk if: many claims OR high coverage
# Low risk otherwise
# This creates a non-trivial pattern where claims and coverage interact

claims_flat = Nx.squeeze(claims)
coverage_flat = Nx.squeeze(coverage)

is_high = Nx.logical_and(Nx.greater(claims_flat, 0.6), Nx.greater(coverage_flat, 0.6))
is_medium_claims = Nx.logical_and(Nx.greater(claims_flat, 0.5), Nx.logical_not(is_high))
is_medium_coverage = Nx.logical_and(Nx.greater(coverage_flat, 0.7), Nx.logical_not(is_high))
is_medium = Nx.logical_or(is_medium_claims, is_medium_coverage)

# Labels: 0=low, 1=medium, 2=high
risk_labels =
  Nx.select(is_high, 2, Nx.select(is_medium, 1, 0))

# Add some label noise (10%) — real data is never perfectly clean
{noise_mask, _key} = Nx.Random.uniform(Nx.Random.key(77), shape: {n_samples})
noisy_labels =
  Nx.select(
    Nx.less(noise_mask, 0.1),
    Nx.remainder(Nx.add(risk_labels, 1), 3),
    risk_labels
  )

y_onehot_tab =
  Nx.equal(Nx.new_axis(noisy_labels, 1), Nx.tensor([[0, 1, 2]]))
  |> Nx.as_type(:f32)

# Split and batch
n_train_tab = round(n_samples * 0.8)

train_x_tab = features[0..(n_train_tab - 1)]
train_y_tab = y_onehot_tab[0..(n_train_tab - 1)]
test_x_tab = features[n_train_tab..-1//1]
test_y_tab = y_onehot_tab[n_train_tab..-1//1]
test_labels_tab = noisy_labels[n_train_tab..-1//1]

train_data_tab =
  Enum.zip(
    Nx.to_batched(train_x_tab, 32) |> Enum.to_list(),
    Nx.to_batched(train_y_tab, 32) |> Enum.to_list()
  )

IO.puts("Features: age, tenure, claims, coverage, region")
IO.puts("Classes: Low Risk, Medium Risk, High Risk")
IO.puts("Ready: #{n_train_tab} train / #{n_samples - n_train_tab} test")

Now let’s train both a plain MLP and a TabNet on the same data to compare:

IO.puts("Training MLP on insurance data...")

mlp_tab =
  Edifice.build(:mlp, input_size: 5, hidden_sizes: [32, 16], activation: :relu, dropout: 0.0)
  |> Axon.dense(3, name: "mlp_tab_output")
  |> Axon.activation(:softmax)

mlp_tab_state =
  mlp_tab
  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 1.0e-2))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data_tab, Axon.ModelState.empty(), epochs: 10)

IO.puts("\nTraining TabNet on insurance data...")

tabnet_tab =
  Edifice.build(:tabnet, input_size: 5, hidden_size: 16, num_steps: 3, dropout: 0.0)
  |> Axon.dense(3, name: "tabnet_tab_output")
  |> Axon.activation(:softmax)

tabnet_tab_state =
  tabnet_tab
  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 1.0e-2))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data_tab, Axon.ModelState.empty(), epochs: 10)

# Evaluate both
{_, mlp_tab_pred_fn} = Axon.build(mlp_tab)
{_, tabnet_tab_pred_fn} = Axon.build(tabnet_tab)

mlp_tab_preds = mlp_tab_pred_fn.(mlp_tab_state, test_x_tab)
tabnet_tab_preds = tabnet_tab_pred_fn.(tabnet_tab_state, test_x_tab)

mlp_tab_acc =
  Nx.equal(Nx.argmax(mlp_tab_preds, axis: 1), Nx.argmax(test_y_tab, axis: 1))
  |> Nx.mean()
  |> Nx.to_number()

tabnet_tab_acc =
  Nx.equal(Nx.argmax(tabnet_tab_preds, axis: 1), Nx.argmax(test_y_tab, axis: 1))
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("\n" <> String.duplicate("=", 40))
IO.puts("  MLP accuracy:    #{Float.round(mlp_tab_acc * 100, 1)}%")
IO.puts("  TabNet accuracy: #{Float.round(tabnet_tab_acc * 100, 1)}%")
IO.puts(String.duplicate("=", 40))
IO.puts("\nBoth see the same data and train the same way — the only")
IO.puts("difference is the backbone architecture.")

Let’s visualize how the two models weight the features. We measure how much each feature affects the model’s predictions: we replace each feature with its column mean, one at a time, and see how much the predictions change. A feature whose replacement causes a big shift is an important one.

What to look for: The “true” rule depends mainly on claims and coverage. A good model should assign high importance to those two features and lower importance to the others (age, tenure, region are distractors).

IO.puts("Computing feature importance via mean-replacement ablation...")

feature_names = ["Age", "Tenure", "Claims", "Coverage", "Region"]

# Compute importance: for each feature, replace it with its mean value
# and measure how much the predictions change
compute_importance = fn predict_fn, state, test_data ->
  base_preds = predict_fn.(state, test_data)

  Enum.map(0..4, fn feat_idx ->
    # Replace this feature with its column mean
    col_mean = Nx.mean(test_data[[.., feat_idx]]) |> Nx.to_number()
    modified = Nx.put_slice(test_data, [0, feat_idx],
      Nx.broadcast(col_mean, {Nx.axis_size(test_data, 0), 1}))
    modified_preds = predict_fn.(state, modified)

    # How much did predictions change? (mean absolute difference)
    Nx.subtract(base_preds, modified_preds)
    |> Nx.abs()
    |> Nx.mean()
    |> Nx.to_number()
  end)
end

mlp_importance = compute_importance.(mlp_tab_pred_fn, mlp_tab_state, test_x_tab)
tabnet_importance = compute_importance.(tabnet_tab_pred_fn, tabnet_tab_state, test_x_tab)

# Normalize to percentages
mlp_total = Enum.sum(mlp_importance)
tabnet_total = Enum.sum(tabnet_importance)

chart_data =
  Enum.flat_map(Enum.zip([feature_names, mlp_importance, tabnet_importance]), fn {name, mlp_imp, tab_imp} ->
    [
      %{"Feature" => name, "Importance" => if(mlp_total > 0, do: mlp_imp / mlp_total * 100, else: 0), "Model" => "MLP"},
      %{"Feature" => name, "Importance" => if(tabnet_total > 0, do: tab_imp / tabnet_total * 100, else: 0), "Model" => "TabNet"}
    ]
  end)

Vl.new(width: 500, height: 300, title: "Feature Importance: Which Columns Drive Predictions?")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "Feature", type: :nominal, sort: feature_names, axis: %{label_angle: 0})
|> Vl.encode_field(:y, "Importance", type: :quantitative, title: "Relative Importance (%)")
|> Vl.encode_field(:color, "Model", type: :nominal)
|> Vl.encode_field(:x_offset, "Model", type: :nominal)

Key takeaway: TabNet is designed for exactly this kind of columnar data — it treats columns as discrete features and learns which ones to pay attention to. In practice (with dozens or hundreds of columns), TabNet’s ability to explain which columns drove each prediction is what makes it valuable in regulated industries. An MLP treats all inputs identically and doesn’t have this built-in feature selection.

4. “How Sure Are You?” — Uncertainty Estimation

Here’s a principle that good statisticians and actuaries already know: a prediction without a confidence interval is just a guess. When a model says “this applicant is high risk,” the natural follow-up is “how confident are you?”

Standard neural networks give you a single answer and no indication of confidence. MC Dropout (Monte Carlo Dropout) is a clever technique that makes a model express uncertainty. The idea:

  1. Dropout randomly “turns off” some neurons during training (a regularization trick)
  2. Normally you turn dropout off for predictions — you want deterministic answers
  3. MC Dropout leaves dropout on during prediction and runs the same input through the model many times (say 30 times)
  4. Each time, different neurons are turned off, giving a slightly different prediction
  5. If all 30 runs agree → the model is confident
  6. If the 30 runs disagree → the model is uncertain

It’s like asking 30 slightly different experts the same question and seeing whether they agree.
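The aggregation in steps 3–6 is just a mean and a variance across the stacked passes. A minimal Nx sketch, with three hand-written "passes" standing in for real dropout-perturbed outputs:

```elixir
# Three stochastic passes for one sample (values invented for illustration).
# Each row is the class-probability output of one dropout-perturbed pass.
passes =
  Nx.tensor([
    [0.9, 0.1],
    [0.7, 0.3],
    [0.8, 0.2]
  ])

# Mean across passes: what the 30 "experts" agree on
mean_pred = Nx.mean(passes, axes: [0])

# Variance across passes: how much they disagree
spread = Nx.variance(passes, axes: [0])

# A single uncertainty score: sum the per-class variances
uncertainty = Nx.sum(spread) |> Nx.to_number()
```

This is the same arithmetic the library call below performs for every test sample at once.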

# Reuse the insurance data from section 3
IO.puts("Building MC Dropout model...")

mc_model = Edifice.build(:mc_dropout,
  input_size: 5,
  hidden_sizes: [64, 32],
  output_size: 3,
  dropout_rate: 0.2,
  activation: :relu
) |> Axon.activation(:softmax)

IO.puts("Training (10 epochs)...")

mc_state =
  mc_model
  |> Axon.Loop.trainer(
    :categorical_cross_entropy,
    Polaris.Optimizers.adam(learning_rate: 1.0e-2)
  )
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data_tab, Axon.ModelState.empty(), epochs: 10)

IO.puts("Training complete!")

Now let’s run 30 forward passes and see where the model is confident vs uncertain:

What to look for: Each bar shows one prediction’s spread across the 30 passes. Bars should be short for “easy” predictions where the model has strong evidence, and tall for borderline cases. Points near the decision boundary between risk classes should show the widest spread — the model is saying “I could go either way.”

IO.puts("Running 30 stochastic forward passes for uncertainty estimation...")

# Run multiple passes with dropout still active
num_passes = 30

{mean_preds, variance} =
  Edifice.Probabilistic.MCDropout.predict_with_uncertainty(
    mc_model, mc_state, test_x_tab, num_samples: num_passes
  )

# Get predicted class and uncertainty for each test sample
predicted = Nx.argmax(mean_preds, axis: 1) |> Nx.to_flat_list()
# Variance across passes — high variance = model disagrees with itself
uncertainty = Nx.sum(variance, axes: [1]) |> Nx.to_flat_list()

# Sort by uncertainty to show a spectrum from confident to unsure
indexed = Enum.with_index(Enum.zip([predicted, uncertainty]))
sorted = Enum.sort_by(indexed, fn {{_pred, unc}, _idx} -> unc end)

# Take 60 samples across the spectrum for a clear visualization
n_show = min(60, length(sorted))
step = max(div(length(sorted), n_show), 1)
samples = Enum.take_every(sorted, step) |> Enum.take(n_show)

class_names = ["Low", "Medium", "High"]

max_unc = Enum.max(uncertainty)
scale = if max_unc > 0, do: 1.0 / max_unc, else: 1.0

chart_data =
  Enum.with_index(samples)
  |> Enum.map(fn {{{pred, unc}, _orig_idx}, rank} ->
    %{
      "Sample" => rank,
      "Relative Uncertainty" => unc * scale * 100,
      "Predicted Risk" => Enum.at(class_names, pred)
    }
  end)

Vl.new(width: 600, height: 300,
  title: "Model Uncertainty Per Prediction (sorted low → high)")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "Sample", type: :ordinal, title: "Sample (sorted by uncertainty)")
|> Vl.encode_field(:y, "Relative Uncertainty", type: :quantitative, title: "Relative Uncertainty")
|> Vl.encode_field(:color, "Predicted Risk", type: :nominal)

# Show some concrete examples
IO.puts("Examples of confident vs uncertain predictions:\n")
IO.puts("CONFIDENT (low uncertainty):")

confident = Enum.take(sorted, 3)

for {{pred, unc}, orig_idx} <- confident do
  features_list = Nx.to_flat_list(test_x_tab[orig_idx])
  true_label = Nx.to_number(test_labels_tab[orig_idx])
  IO.puts("  Predicted: #{Enum.at(class_names, pred)} risk (uncertainty: #{Float.round(unc, 4)})")
  IO.puts("  True:      #{Enum.at(class_names, true_label)} risk")
  IO.puts("  Features:  age=#{Float.round(Enum.at(features_list, 0), 2)} claims=#{Float.round(Enum.at(features_list, 2), 2)} coverage=#{Float.round(Enum.at(features_list, 3), 2)}")
  IO.puts("")
end

IO.puts("UNCERTAIN (high uncertainty):")

uncertain = Enum.take(Enum.reverse(sorted), 3)

for {{pred, unc}, orig_idx} <- uncertain do
  features_list = Nx.to_flat_list(test_x_tab[orig_idx])
  true_label = Nx.to_number(test_labels_tab[orig_idx])
  IO.puts("  Predicted: #{Enum.at(class_names, pred)} risk (uncertainty: #{Float.round(unc, 4)})")
  IO.puts("  True:      #{Enum.at(class_names, true_label)} risk")
  IO.puts("  Features:  age=#{Float.round(Enum.at(features_list, 0), 2)} claims=#{Float.round(Enum.at(features_list, 2), 2)} coverage=#{Float.round(Enum.at(features_list, 3), 2)}")
  IO.puts("")
end

Key takeaway: A good model says “I don’t know” when it doesn’t have enough information — just like a good actuary prices uncertainty into a premium. MC Dropout gives you that ability with almost zero extra code. In practice, you might flag uncertain predictions for human review rather than auto-approving them.

5. Predicting What Happens Next — Time Series

Insurance companies care a lot about trends over time: claim frequency by month, premium growth, loss ratios quarter by quarter. Time series forecasting is about learning temporal patterns and predicting the next value.

Mamba is a recent architecture for sequence modeling. If you know ARIMA or exponential smoothing, the idea is similar — the model learns the pattern in a sequence and predicts what comes next. The difference: ARIMA requires you to specify the pattern type (trend, seasonality, differencing order); Mamba figures it out automatically from the data.

# Generate a synthetic "monthly claims" signal:
# a seasonal cycle + a slower secondary oscillation + a slow upward trend
n_points = 1500
seq_len = 16

IO.puts("Generating synthetic monthly claims data (#{n_points} months)...")

ts = Nx.iota({n_points}) |> Nx.as_type(:f32)

# Seasonal component (like claims peaking in winter)
seasonal = Nx.sin(Nx.multiply(ts, 2 * :math.pi() / 50))

# Trend (slow upward drift — like claim inflation)
trend = Nx.multiply(ts, 0.002)

# Combine
signal = Nx.add(Nx.add(seasonal, Nx.multiply(Nx.sin(Nx.multiply(ts, 0.13)), 0.3)), trend)

# Create sliding windows: each input is seq_len months, target is next month
windows = n_points - seq_len - 1

IO.puts("Creating #{windows} sliding windows (#{seq_len} months each)...")

x_seq =
  for i <- 0..(windows - 1) do
    Nx.slice(signal, [i], [seq_len]) |> Nx.reshape({seq_len, 1})
  end
  |> Nx.stack()

y_seq =
  for i <- 0..(windows - 1) do
    Nx.slice(signal, [i + seq_len], [1])
  end
  |> Nx.stack()

# Split
n_train_seq = round(windows * 0.8)

train_x_seq = x_seq[0..(n_train_seq - 1)]
train_y_seq = y_seq[0..(n_train_seq - 1)]
test_x_seq = x_seq[n_train_seq..-1//1]
test_y_seq = y_seq[n_train_seq..-1//1]

train_data_seq =
  Enum.zip(
    Nx.to_batched(train_x_seq, 32) |> Enum.to_list(),
    Nx.to_batched(train_y_seq, 32) |> Enum.to_list()
  )

IO.puts("Ready: #{n_train_seq} train / #{windows - n_train_seq} test windows")

# Plot the signal to see the pattern we're asking the model to learn
signal_list = Nx.to_flat_list(signal)

signal_chart_data =
  signal_list
  |> Enum.with_index()
  |> Enum.map(fn {val, i} ->
    split = if i < n_train_seq + seq_len, do: "Train period", else: "Test period"
    %{"Month" => i, "Claims Index" => val, "Period" => split}
  end)

Vl.new(width: 700, height: 250, title: "Synthetic Monthly Claims (seasonal + trend)")
|> Vl.data_from_values(signal_chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "Month", type: :quantitative)
|> Vl.encode_field(:y, "Claims Index", type: :quantitative)
|> Vl.encode_field(:color, "Period", type: :nominal)

IO.puts("Training Mamba model on claims data...")

mamba_model =
  Edifice.build(:mamba,
    embed_dim: 32,
    hidden_size: 16,
    state_size: 8,
    num_layers: 2,
    seq_len: seq_len,
    window_size: seq_len,
    dropout: 0.0
  )
  |> Axon.dense(1, name: "mamba_claims_output")

mamba_state =
  mamba_model
  |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 1.0e-3))
  # 3 epochs for a quick demo; increase for better fit
  |> Axon.Loop.run(train_data_seq, Axon.ModelState.empty(), epochs: 3)

# Evaluate
{_init, mamba_pred_fn} = Axon.build(mamba_model)
mamba_preds = mamba_pred_fn.(mamba_state, test_x_seq)

mamba_mse =
  Nx.subtract(mamba_preds, test_y_seq)
  |> Nx.pow(2)
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("Mamba test MSE: #{Float.round(mamba_mse, 6)}")

What to look for: The predicted line should track the actual signal closely. Small lags or amplitude differences are normal with only 3 training epochs. The model is learning the seasonal oscillation and the upward trend simultaneously.

actual_list = Nx.to_flat_list(Nx.reshape(test_y_seq, {Nx.axis_size(test_y_seq, 0)}))
pred_list = Nx.to_flat_list(Nx.reshape(mamba_preds, {Nx.axis_size(mamba_preds, 0)}))

n_test_seq = length(actual_list)

pred_chart_data =
  Enum.flat_map(0..(n_test_seq - 1), fn i ->
    [
      %{"Month" => i, "Claims Index" => Enum.at(actual_list, i), "Series" => "Actual"},
      %{"Month" => i, "Claims Index" => Enum.at(pred_list, i), "Series" => "Mamba Predicted"}
    ]
  end)

Vl.new(width: 700, height: 300, title: "Mamba: Actual vs Predicted Claims")
|> Vl.data_from_values(pred_chart_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "Month", type: :quantitative, title: "Test Month")
|> Vl.encode_field(:y, "Claims Index", type: :quantitative)
|> Vl.encode_field(:color, "Series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "Series", type: :nominal)

Key takeaway: The model learned the seasonal pattern and trend from raw data — no need to specify the period, differencing order, or trend type like you would with ARIMA. It’s automatic pattern discovery.

6. Swapping Architectures in One Line

This is one of Edifice’s core value propositions. Every architecture uses the same Edifice.build(:name, opts) API, so trying a different approach is a one-line change. The rest of your code — data prep, training loop, evaluation — stays exactly the same.

It’s like changing a formula in one cell of a spreadsheet without touching the rest of the sheet.

# Same time series data, same training loop — just swap the architecture name
IO.puts("Training GRU on the same claims data...")

gru_model =
  Edifice.build(:gru,
    embed_dim: 32,
    hidden_size: 16,
    num_layers: 2,
    seq_len: seq_len,
    window_size: seq_len,
    dropout: 0.0
  )
  |> Axon.dense(1, name: "gru_claims_output")

gru_state =
  gru_model
  |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 1.0e-3))
  |> Axon.Loop.run(train_data_seq, Axon.ModelState.empty(), epochs: 3)

{_init, gru_pred_fn} = Axon.build(gru_model)
gru_preds = gru_pred_fn.(gru_state, test_x_seq)

gru_mse =
  Nx.subtract(gru_preds, test_y_seq)
  |> Nx.pow(2)
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("GRU test MSE: #{Float.round(gru_mse, 6)}")

# Compare
IO.puts("\n" <> String.duplicate("=", 40))
IO.puts("  Mamba MSE: #{Float.round(mamba_mse, 6)}")
IO.puts("  GRU MSE:   #{Float.round(gru_mse, 6)}")
IO.puts(String.duplicate("=", 40))

best_name = if mamba_mse < gru_mse, do: "Mamba", else: "GRU"
IO.puts("\nBetter fit: #{best_name}")
IO.puts("The only code change: :mamba → :gru")

What to look for: Both models should track the signal, possibly with different strengths. Mamba and GRU are fundamentally different architectures (state-space vs recurrent), yet you train them with identical code.

gru_pred_list = Nx.to_flat_list(Nx.reshape(gru_preds, {Nx.axis_size(gru_preds, 0)}))

compare_data =
  Enum.flat_map(0..(n_test_seq - 1), fn i ->
    [
      %{"Month" => i, "Claims Index" => Enum.at(actual_list, i), "Series" => "Actual"},
      %{"Month" => i, "Claims Index" => Enum.at(pred_list, i), "Series" => "Mamba"},
      %{"Month" => i, "Claims Index" => Enum.at(gru_pred_list, i), "Series" => "GRU"}
    ]
  end)

Vl.new(width: 700, height: 300, title: "Architecture Comparison: Same Data, Different Models")
|> Vl.data_from_values(compare_data)
|> Vl.mark(:line, stroke_width: 1.5)
|> Vl.encode_field(:x, "Month", type: :quantitative, title: "Test Month")
|> Vl.encode_field(:y, "Claims Index", type: :quantitative)
|> Vl.encode_field(:color, "Series", type: :nominal)
|> Vl.encode_field(:stroke_dash, "Series", type: :nominal)

Key takeaway: Edifice lets you try many different approaches to the same problem without rewriting your code. In practice, you’d try 5-10 architectures on your data and pick the one that performs best — like an automated model selection process. The unified API makes this practical rather than a multi-week engineering effort.
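That selection process can be written as a plain loop. The sketch below is a hypothetical helper, not part of Edifice: it trains each candidate backbone with the identical pipeline from sections 5–6 and keeps the lowest test MSE. It assumes `train_data_seq`, `test_x_seq`, `test_y_seq`, and `seq_len` from above, and that each candidate accepts the same option list (extra options are assumed to be ignored where irrelevant).

```elixir
# Automated model selection: same data, same training loop,
# only the architecture atom changes per iteration.
candidates = [:mamba, :gru]

scores =
  for arch <- candidates do
    model =
      Edifice.build(arch,
        embed_dim: 32,
        hidden_size: 16,
        state_size: 8,
        num_layers: 2,
        seq_len: seq_len,
        window_size: seq_len,
        dropout: 0.0
      )
      |> Axon.dense(1, name: "#{arch}_selection_output")

    state =
      model
      |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 1.0e-3))
      |> Axon.Loop.run(train_data_seq, Axon.ModelState.empty(), epochs: 3)

    {_init, pred_fn} = Axon.build(model)

    mse =
      Nx.subtract(pred_fn.(state, test_x_seq), test_y_seq)
      |> Nx.pow(2)
      |> Nx.mean()
      |> Nx.to_number()

    {arch, mse}
  end

{best, best_mse} = Enum.min_by(scores, fn {_arch, mse} -> mse end)
IO.puts("Best architecture: #{best} (MSE #{Float.round(best_mse, 6)})")
```

Extending the candidate list is a one-atom change per architecture — the loop body never changes.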

7. The Architecture Zoo — A Quick Visual Tour

Edifice includes architectures for every major type of data and problem. Here’s a quick overview of the families, organized by what kind of data they’re designed for.

families = Edifice.list_families()

# Sort by number of architectures (descending)
sorted_families =
  families
  |> Enum.sort_by(fn {_name, archs} -> -length(archs) end)

total = families |> Map.values() |> List.flatten() |> length()

IO.puts("Edifice includes #{total} architectures across #{map_size(families)} families:\n")

for {family, archs} <- sorted_families do
  count = length(archs)
  names = Enum.map_join(Enum.take(archs, 5), ", ", &Atom.to_string/1)
  suffix = if count > 5, do: ", ...", else: ""
  IO.puts("  #{String.pad_trailing(Atom.to_string(family), 16)} #{String.pad_leading(Integer.to_string(count), 2)} architectures  (#{names}#{suffix})")
end

Let’s build one representative from several families and compare their sizes. This shows the diversity — some architectures have a few hundred parameters, others have tens of thousands, depending on their design.

# Build a small model from each family and count parameters
defmodule ZooHelper do
  def count_params(%Axon.ModelState{} = state) do
    state
    |> Axon.ModelState.trainable_parameters()
    |> count_nested(0)
  end

  defp count_nested(%Nx.Tensor{} = t, acc), do: acc + Nx.size(t)
  defp count_nested(map, acc) when is_map(map) do
    Enum.reduce(map, acc, fn {_k, v}, a -> count_nested(v, a) end)
  end
  defp count_nested(_other, acc), do: acc

  def fmt(n) when n >= 1_000_000, do: "#{Float.round(n / 1_000_000, 1)}M"
  def fmt(n) when n >= 1_000, do: "#{Float.round(n / 1_000, 1)}K"
  def fmt(n), do: "#{n}"
end

# Small shared dimensions
batch = 2
embed = 32
hidden = 16
seq_len = 8
rand = fn shape -> elem(Nx.Random.normal(Nx.Random.key(42), shape: shape), 0) end

seq_opts = [
  embed_dim: embed, hidden_size: hidden, state_size: 8,
  num_layers: 2, seq_len: seq_len, window_size: seq_len,
  head_dim: 8, num_heads: 2, dropout: 0.0
]

zoo_specs = [
  {"MLP (feedforward)", fn -> Edifice.build(:mlp, input_size: embed, hidden_sizes: [hidden]) end,
    fn -> %{"input" => rand.({batch, embed})} end,
    "Tables, flat features"},
  {"TabNet (feedforward)", fn -> Edifice.build(:tabnet, input_size: embed, hidden_size: hidden, num_steps: 3) end,
    fn -> %{"input" => rand.({batch, embed})} end,
    "Tables with feature selection"},
  {"Mamba (state-space)", fn -> Edifice.build(:mamba, seq_opts) end,
    fn -> %{"state_sequence" => rand.({batch, seq_len, embed})} end,
    "Sequences, time series"},
  {"GRU (recurrent)", fn -> Edifice.build(:gru, seq_opts) end,
    fn -> %{"state_sequence" => rand.({batch, seq_len, embed})} end,
    "Sequences, time series"},
  {"RetNet (attention)", fn -> Edifice.build(:retnet, seq_opts) end,
    fn -> %{"state_sequence" => rand.({batch, seq_len, embed})} end,
    "Long sequences, text"},
  {"GCN (graph)", fn -> Edifice.build(:gcn, input_dim: 16, hidden_dim: hidden, num_classes: 4, num_layers: 2, num_heads: 2, dropout: 0.0) end,
    fn ->
      nodes = rand.({batch, 6, 16})
      adj = Nx.eye(6) |> Nx.broadcast({batch, 6, 6})
      %{"nodes" => nodes, "adjacency" => adj}
    end,
    "Networks, relationships"}
]

IO.puts("Building one model from each family...\n")

results =
  Enum.map(zoo_specs, fn {name, build_fn, input_fn, use_case} ->
    try do
      model = build_fn.()
      input = input_fn.()
      template = Map.new(input, fn {k, v} -> {k, Nx.template(Nx.shape(v), Nx.type(v))} end)
      {init_fn, predict_fn} = Axon.build(model)
      params = init_fn.(template, Axon.ModelState.empty())
      _output = predict_fn.(params, input)
      param_count = ZooHelper.count_params(params)
      IO.puts("  #{String.pad_trailing(name, 25)} #{String.pad_leading(ZooHelper.fmt(param_count), 8)} params — #{use_case}")
      %{"Architecture" => name, "Parameters" => param_count, "Use Case" => use_case}
    rescue
      e ->
        IO.puts("  #{String.pad_trailing(name, 25)} FAILED: #{Exception.message(e) |> String.slice(0, 50)}")
        nil
    end
  end)
  |> Enum.reject(&is_nil/1)

:ok

# Bar chart of parameter counts
Vl.new(width: 500, height: 300, title: "Parameter Count by Architecture")
|> Vl.data_from_values(results)
|> Vl.mark(:bar)
|> Vl.encode_field(:y, "Architecture", type: :nominal, sort: "-x", title: nil)
|> Vl.encode_field(:x, "Parameters", type: :quantitative, title: "Number of Parameters")
|> Vl.encode_field(:color, "Architecture", type: :nominal, legend: nil)

Key takeaway: Each architecture family is designed for a different kind of data. Feedforward models handle flat tables, recurrent/SSM models handle sequences, graph models handle network-structured data. Some are tiny (a few hundred parameters), others are larger — the right choice depends on your problem, not the model’s size.

8. Seeing the Decision — Boundary Comparison

Different architectures learn different decision boundaries. Let’s train three architectures on the same 2D classification data and compare the “lines” they draw through the data.

This is a visual way to understand how architectures differ — not just in accuracy numbers, but in the actual shape of their reasoning.

# Generate 3-class data (triangle pattern)
key = Nx.Random.key(55)
n_per = 250
centers = [[-1.5, -1.0], [1.5, -1.0], [0.0, 1.5]]

IO.puts("Generating 3-class data (#{n_per * 3} points)...")

{all_pts, all_labs, _key} =
  Enum.reduce(Enum.with_index(centers), {[], [], key}, fn {[cx, cy], class}, {pts, labs, k} ->
    {noise, k} = Nx.Random.normal(k, shape: {n_per, 2})
    points = Nx.add(Nx.multiply(noise, 0.6), Nx.tensor([cx, cy]))
    {pts ++ [points], labs ++ List.duplicate(class, n_per), k}
  end)

x_3c = Nx.concatenate(all_pts)
y_3c_raw = Nx.tensor(all_labs)

# Shuffle so the three classes are interleaved before the train/test split
{shuf, _} = Nx.Random.uniform(Nx.Random.key(88), shape: {n_per * 3})
shuf_idx = Nx.argsort(shuf)
x_3c = Nx.take(x_3c, shuf_idx)
y_3c_raw = Nx.take(y_3c_raw, shuf_idx)

# One-hot encode labels: class index k becomes a row with 1 at position k
y_3c = Nx.equal(Nx.new_axis(y_3c_raw, 1), Nx.tensor([[0, 1, 2]])) |> Nx.as_type(:f32)

n_train_3c = round(n_per * 3 * 0.8)
train_x_3c = x_3c[0..(n_train_3c - 1)]
train_y_3c = y_3c[0..(n_train_3c - 1)]
test_x_3c = x_3c[n_train_3c..-1//1]

train_data_3c =
  Enum.zip(
    Nx.to_batched(train_x_3c, 32) |> Enum.to_list(),
    Nx.to_batched(train_y_3c, 32) |> Enum.to_list()
  )

# Train three different architectures
architectures = [
  {:mlp, "MLP (2 layers, relu)", Edifice.build(:mlp, input_size: 2, hidden_sizes: [32, 16], activation: :relu, dropout: 0.0)},
  {:tabnet, "TabNet (attention)", Edifice.build(:tabnet, input_size: 2, hidden_size: 16, num_steps: 3, dropout: 0.0)},
  {:mlp_deep, "MLP (4 layers, tanh)", Edifice.build(:mlp, input_size: 2, hidden_sizes: [64, 32, 16, 8], activation: :tanh, dropout: 0.0)}
]

# Grid for decision boundary visualization
grid_3c =
  for gx <- 0..49, gy <- 0..49 do
    [-4.0 + 8.0 * gx / 49, -4.0 + 8.0 * gy / 49]
  end

grid_tensor_3c = Nx.tensor(grid_3c)

IO.puts("Training 3 architectures...")

boundary_results =
  Enum.map(architectures, fn {name, label, backbone} ->
    IO.puts("  Training #{label}...")

    full_model =
      backbone
      |> Axon.dense(3, name: "#{name}_boundary_out")
      |> Axon.activation(:softmax)

    state =
      full_model
      |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 1.0e-2))
      |> Axon.Loop.metric(:accuracy)
      |> Axon.Loop.run(train_data_3c, Axon.ModelState.empty(), epochs: 10)

    {_, pred_fn} = Axon.build(full_model)

    # Test accuracy
    test_preds_3c = pred_fn.(state, test_x_3c)
    test_y_3c = y_3c[n_train_3c..-1//1]
    acc = Nx.equal(Nx.argmax(test_preds_3c, axis: 1), Nx.argmax(test_y_3c, axis: 1))
      |> Nx.mean() |> Nx.to_number()

    # Grid predictions
    grid_preds = pred_fn.(state, grid_tensor_3c)
    grid_cls = Nx.argmax(grid_preds, axis: 1) |> Nx.to_flat_list()

    IO.puts("    → #{Float.round(acc * 100, 1)}% accuracy")

    {label, acc, grid_cls}
  end)

IO.puts("\nDone! Visualizing decision boundaries...")

What to look for: Each architecture draws different boundary shapes. The shallow relu MLP draws piecewise-linear boundaries built from straight segments. TabNet may draw sharper, axis-aligned boundaries (its attention mechanism selects individual features). The deeper tanh MLP has more capacity and a smooth activation function, so its boundaries tend to be curvier and can be more complex. All should classify well, but the shape of their reasoning differs.

# Plot all three decision boundaries side by side
class_names_3c = ["A", "B", "C"]

test_data_3c =
  Enum.zip_with(
    [Nx.to_flat_list(x_3c[[.., 0]]), Nx.to_flat_list(x_3c[[.., 1]]), Nx.to_flat_list(y_3c_raw)],
    fn [x, y, l] -> %{"x" => x, "y" => y, "class" => "Class #{Enum.at(class_names_3c, trunc(l))}"} end
  )

charts =
  Enum.map(boundary_results, fn {label, acc, grid_cls} ->
    grid_d =
      Enum.zip_with([grid_3c, grid_cls], fn [[x, y], cls] ->
        %{"x" => x, "y" => y, "class" => "Class #{Enum.at(class_names_3c, trunc(cls))}"}
      end)

    Vl.new(width: 200, height: 200, title: "#{label} (#{Float.round(acc * 100, 1)}%)")
    |> Vl.layers([
      Vl.new()
      |> Vl.data_from_values(grid_d)
      |> Vl.mark(:square, size: 15, opacity: 0.25)
      |> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
      |> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
      |> Vl.encode_field(:color, "class", type: :nominal),
      Vl.new()
      |> Vl.data_from_values(test_data_3c)
      |> Vl.mark(:circle, size: 20, stroke: "black", stroke_width: 0.5)
      |> Vl.encode_field(:x, "x", type: :quantitative)
      |> Vl.encode_field(:y, "y", type: :quantitative)
      |> Vl.encode_field(:color, "class", type: :nominal)
    ])
  end)

Vl.new()
|> Vl.concat(charts, :horizontal)

Key takeaway: Architecture choice isn’t just about accuracy — it’s about how the model reasons. Even within the same architecture family, changing the depth, width, or activation function produces different decision surfaces. Being able to quickly compare them visually helps you understand which approach suits your problem.

9. Making New Data — Generative Models

Sometimes you don’t just want to classify data — you want to create new data that looks realistic. This is valuable for:

  • Stress testing: “What would our portfolio look like under extreme scenarios?”
  • Simulation: “Generate 10,000 synthetic policyholders for testing our systems”
  • Data augmentation: “We have 500 fraud cases — can we create realistic synthetic examples to help our fraud detector learn?”

A Variational Autoencoder (VAE) learns the shape and structure of your data, then generates new samples that follow the same patterns. It’s like learning the distribution well enough to sample from it.

# Generate crescent-moon data — a classic test shape
IO.puts("Generating crescent moon dataset...")

n_pts = 1500
key = Nx.Random.key(42)

{angles_u, key} = Nx.Random.uniform(key, shape: {n_pts})
angles_u = Nx.multiply(angles_u, :math.pi())
{noise_u, key} = Nx.Random.normal(key, shape: {n_pts, 2})

upper_x = Nx.add(Nx.cos(angles_u), Nx.multiply(noise_u[[.., 0]], 0.1))
upper_y = Nx.add(Nx.sin(angles_u), Nx.multiply(noise_u[[.., 1]], 0.1))
upper = Nx.stack([upper_x, upper_y], axis: 1)

{angles_l, key} = Nx.Random.uniform(key, shape: {n_pts})
angles_l = Nx.multiply(angles_l, :math.pi())
{noise_l, _key} = Nx.Random.normal(key, shape: {n_pts, 2})

lower_x = Nx.add(Nx.subtract(1.0, Nx.cos(angles_l)), Nx.multiply(noise_l[[.., 0]], 0.1))
lower_y = Nx.subtract(Nx.subtract(0.0, Nx.sin(angles_l)), Nx.add(0.5, Nx.multiply(noise_l[[.., 1]], 0.1)))
lower = Nx.stack([lower_x, lower_y], axis: 1)

vae_data = Nx.concatenate([upper, lower])
n_vae = Nx.axis_size(vae_data, 0)

# Shuffle
{vae_shuf, _} = Nx.Random.uniform(Nx.Random.key(77), shape: {n_vae})
vae_data = Nx.take(vae_data, Nx.argsort(vae_shuf))

# For VAE, input = target (it learns to reconstruct its own input)
vae_batches =
  Nx.to_batched(vae_data, 64)
  |> Enum.to_list()
  |> Enum.map(fn batch -> {batch, batch} end)

IO.puts("Ready: #{n_vae} points for VAE training")
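For reference, the two crescents follow these parametric curves, with $\theta$ drawn uniformly from $[0, \pi]$ and Gaussian noise $\varepsilon \sim \mathcal{N}(0, 1)$ scaled by $0.1$, matching the cell above:

$$
\text{upper: } (x, y) = \left(\cos\theta + 0.1\,\varepsilon_1,\;\; \sin\theta + 0.1\,\varepsilon_2\right)
$$

$$
\text{lower: } (x, y) = \left(1 - \cos\theta + 0.1\,\varepsilon_1,\;\; -\sin\theta - 0.5 - 0.1\,\varepsilon_2\right)
$$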
IO.puts("Building VAE (2D data → 2D latent space → 2D reconstruction)...")

latent_size = 2
input_size = 2

# Encoder: compresses data into a compact representation
input = Axon.input("input", shape: {nil, input_size})

enc =
  input
  |> Axon.dense(64, name: "enc_0") |> Axon.activation(:relu)
  |> Axon.dense(32, name: "enc_1") |> Axon.activation(:relu)

# The encoder outputs two things:
# - mu: the "center" of where this data point maps in the compressed space
# - log_var: how "spread out" the encoding is (uncertainty about the location)
mu = Axon.dense(enc, latent_size, name: "mu")
log_var = Axon.dense(enc, latent_size, name: "log_var")

# Decoder: reconstructs data from the compressed representation
recon =
  mu
  |> Axon.dense(32, name: "dec_0") |> Axon.activation(:relu)
  |> Axon.dense(64, name: "dec_1") |> Axon.activation(:relu)
  |> Axon.dense(input_size, name: "dec_out")

# Combined model outputs all three pieces
vae_model = Axon.container(%{reconstruction: recon, mu: mu, log_var: log_var})

# Separate encoder model (same layer names = shares trained weights)
encoder_model =
  Axon.input("input", shape: {nil, input_size})
  |> Axon.dense(64, name: "enc_0") |> Axon.activation(:relu)
  |> Axon.dense(32, name: "enc_1") |> Axon.activation(:relu)
  |> Axon.dense(latent_size, name: "mu")

IO.puts("  Encoder: 2D data → 2D latent code")
IO.puts("  Decoder: 2D latent code → 2D reconstruction")

# Train the VAE with a two-part loss:
# 1. Reconstruction loss: "make the output match the input"
# 2. KL divergence: "keep the compressed representation organized"
beta = 0.5
epochs = 30

vae_loss = fn y_true, y_pred ->
  recon_loss = Nx.mean(Nx.pow(Nx.subtract(y_pred.reconstruction, y_true), 2))

  kl =
    Nx.subtract(
      Nx.add(1.0, y_pred.log_var),
      Nx.add(Nx.pow(y_pred.mu, 2), Nx.exp(y_pred.log_var))
    )
    |> Nx.sum(axes: [-1])
    |> Nx.mean()
    |> Nx.multiply(-0.5)

  Nx.add(recon_loss, Nx.multiply(beta, kl))
end

IO.puts("Training VAE (#{epochs} epochs — this takes a moment)...")

vae_state =
  vae_model
  |> Axon.Loop.trainer(vae_loss, Polaris.Optimizers.adam(learning_rate: 3.0e-3))
  |> Axon.Loop.run(vae_batches, Axon.ModelState.empty(), epochs: epochs)

IO.puts("Training complete!")
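Written out, the loss the trainer just minimized is the sum of the two parts defined in `vae_loss`, with $\sigma^2 = \exp(\texttt{log\_var})$ and $\beta = 0.5$:

$$
\mathcal{L} = \underbrace{\frac{1}{N}\sum_i \lVert \hat{x}_i - x_i \rVert^2}_{\text{reconstruction}} \;+\; \beta \cdot \underbrace{\left(-\frac{1}{2}\, \mathbb{E}_i\!\left[\sum_j \left(1 + \log\sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2\right)\right]\right)}_{\text{KL divergence}}
$$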

Now the payoff — let’s generate brand new data that the model has never seen, and compare it to the real data:

What to look for: The generated points (orange) should follow the same crescent-moon shapes as the real data (blue). They won’t be identical — they’re new synthetic samples. If they form the same curved pattern, the model has learned the data’s underlying structure.

IO.puts("Generating new samples from learned distribution...")

# Encode all real data to learn the latent distribution
{_enc_init, enc_fn} = Axon.build(encoder_model)
mu_all = enc_fn.(vae_state, vae_data)

# Build standalone decoder
latent_input = Axon.input("latent", shape: {nil, latent_size})

decoder =
  latent_input
  |> Axon.dense(32, name: "dec_0") |> Axon.activation(:relu)
  |> Axon.dense(64, name: "dec_1") |> Axon.activation(:relu)
  |> Axon.dense(input_size, name: "dec_out")

{_dec_init, dec_fn} = Axon.build(decoder)

# Sample from the learned distribution and decode
n_gen = 500
latent_mean = Nx.mean(mu_all, axes: [0])
latent_std = Nx.standard_deviation(mu_all, axes: [0])

{z_noise, _} = Nx.Random.normal(Nx.Random.key(456), shape: {n_gen, latent_size})
z_samples = Nx.add(latent_mean, Nx.multiply(latent_std, z_noise))
generated = dec_fn.(vae_state, %{"latent" => z_samples})

IO.puts("Generated #{n_gen} new samples!")

# Prepare visualization data
gen_chart =
  Enum.zip_with(
    [Nx.to_flat_list(generated[[.., 0]]), Nx.to_flat_list(generated[[.., 1]])],
    fn [x, y] -> %{"x" => x, "y" => y, "source" => "Generated"} end
  )

real_chart =
  Enum.zip_with(
    [Nx.to_flat_list(vae_data[[.., 0]]), Nx.to_flat_list(vae_data[[.., 1]])],
    fn [x, y] -> %{"x" => x, "y" => y, "source" => "Real"} end
  )

Vl.new(width: 500, height: 350, title: "Real vs Generated Samples")
|> Vl.data_from_values(real_chart ++ gen_chart)
|> Vl.mark(:circle, size: 15, opacity: 0.4)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "source", type: :nominal)

Key takeaway: The model learned the shape of the data well enough to create new realistic samples. In insurance, this technique is useful for generating synthetic scenarios for stress testing (“what if we had 3x the usual catastrophe claims?”) or for creating test data that has the same statistical properties as real data without exposing actual customer information.

10. Key Takeaways

Here’s a summary of what we’ve covered:

Neural networks are sophisticated pattern matching. They’re the same fundamental idea as regression — finding patterns in data and using them to make predictions. The difference is flexibility: neural networks can learn arbitrarily complex patterns without you having to specify the formula in advance.

Different architectures for different data shapes. Flat tables → MLP, TabNet. Time series → Mamba, GRU. Network/relationship data → GCN, GAT. Images → ViT, ConvNet. The architecture encodes your assumptions about the data’s structure. Choosing the right one is like choosing the right statistical test.

Uncertainty matters. A prediction without a confidence measure isn’t very useful for decision making. MC Dropout gives you uncertainty estimates cheaply — letting the model say “I’m not sure about this one” rather than giving a falsely confident answer. In regulated industries, this isn’t optional.

Edifice makes experimentation practical. With 100+ architectures behind a single API (Edifice.build(:name, opts)), trying a different approach is a one-line change. This turns model selection from a multi-week engineering project into something you can explore in an afternoon.
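As a sketch of that one-line change (reusing the seq_opts map defined in the Section 7 zoo cell, so this only runs in that context), swapping architectures is a single-atom edit:

```elixir
# All three of these were built earlier in this notebook with the same options.
# Swap the atom to swap the architecture; the rest of the pipeline is unchanged.
model = Edifice.build(:mamba, seq_opts)    # state-space model
# model = Edifice.build(:gru, seq_opts)    # recurrent
# model = Edifice.build(:retnet, seq_opts) # attention-based
```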

Generative models create realistic synthetic data. VAEs and other generative architectures learn the structure of your data well enough to produce new samples. Useful for stress testing, simulation, data augmentation, and creating privacy-safe test data.

11. Glossary

Quick reference for terms used in this notebook.

| Term | Plain English |
| --- | --- |
| Model / Architecture | A specific structure for learning patterns. Like choosing between linear regression, logistic regression, or a decision tree — each has different strengths. |
| Training | The process of showing the model many examples so it can adjust its internal numbers to make better predictions. Like fitting a regression line to data points. |
| Epoch | One complete pass through all the training data. If you have 1000 examples and train for 10 epochs, the model sees each example 10 times. |
| Loss | A number measuring how wrong the model’s predictions are. Training tries to make this smaller. Analogous to the sum of squared residuals in regression. |
| Accuracy | Percentage of predictions that match the true answer. Simple but sometimes misleading — a model that always predicts “no fraud” is 99% accurate if fraud is rare. |
| Batch | A subset of training data processed at once. Instead of updating after every single example (slow) or after all examples (memory-intensive), we compromise with batches of 32-64. |
| Parameters / Weights | The internal numbers the model adjusts during training. A linear regression has 2 parameters (slope and intercept). A neural network might have thousands. |
| Backbone | The main body of a model that extracts features/patterns from input data. In Edifice, this is what Edifice.build creates. |
| Head | A small layer added on top of the backbone for a specific task (e.g., classification, regression). Like the final output formula. |
| Overfitting | When a model memorizes the training data instead of learning generalizable patterns. Like a student who memorizes exam answers without understanding the material. |
| Dropout | Randomly disabling some neurons during training to prevent overfitting. Forces the model to be robust — like training with a randomly different team each day. |
| Latent space | A compressed representation the model learns internally. A VAE compresses 2D data into a 2D latent space — but in practice, you might compress 100 features into 10 latent dimensions. |
| MLP | Multi-Layer Perceptron. The simplest neural network — stacked layers of weighted sums with non-linear activations. The neural network equivalent of polynomial regression. |
| SSM | State-Space Model (e.g., Mamba). Processes sequences by maintaining and updating a hidden state. Similar in spirit to ARIMA or Kalman filters. |
| Recurrent | Models (e.g., GRU, LSTM) that process sequences one step at a time, maintaining a “memory” of what they’ve seen. Like reading a sentence word by word. |
| Attention | A mechanism that lets models focus on relevant parts of the input. TabNet uses attention to select important features; Transformers use it to relate parts of a sequence. |
| KL Divergence | A measure of how different two probability distributions are. In VAEs, it keeps the model’s internal representation organized and usable for generation. |