Training an MLP with Edifice
Setup
Choose one of the two cells below depending on how you started Livebook.
Standalone (default)
Use this if you started Livebook normally (livebook server).
Uncomment the EXLA lines for GPU acceleration.
edifice_dep =
if File.dir?(Path.expand("~/edifice")) do
{:edifice, path: Path.expand("~/edifice")}
else
{:edifice, "~> 0.2.0"}
end
Mix.install([
edifice_dep,
# {:exla, "~> 0.10"},
{:kino_vega_lite, "~> 0.1"},
{:pythonx, "~> 0.4.2"},
{:kino_pythonx, "~> 0.1.0"}
])
# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
Attached to project (recommended for Nix/CUDA)
Use this if you started Livebook via ./scripts/livebook.sh.
See the Architecture Zoo notebook for full setup instructions.
Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")
Introduction
This notebook demonstrates end-to-end neural network training using Edifice for model architecture, Axon for the training loop, and Polaris for optimization.
What you’ll learn:
- How Edifice models compose with Axon’s training API
- Generating and batching data for training
- Building a training loop with loss, optimizer, and metrics
- Evaluating and visualizing a trained classifier
- Swapping architectures with one line of code
No external datasets required — we generate synthetic data so the notebook runs instantly with zero extra dependencies.
Generate Synthetic Data
We create a 3-class classification problem: three Gaussian clusters arranged in a triangle pattern. Each cluster has 300 points in 2D space.
key = Nx.Random.key(42)
n_per_class = 300
# Three cluster centers forming a triangle
centers = [[-1.5, -1.0], [1.5, -1.0], [0.0, 1.5]]
IO.puts("Generating #{n_per_class * length(centers)} points across #{length(centers)} classes...")
# Generate points with Gaussian noise (std=0.6) around each center
{all_points, all_labels, _key} =
Enum.reduce(Enum.with_index(centers), {[], [], key}, fn {[cx, cy], class}, {pts, labs, k} ->
{noise, k} = Nx.Random.normal(k, shape: {n_per_class, 2})
points = Nx.add(Nx.multiply(noise, 0.6), Nx.tensor([cx, cy]))
{pts ++ [points], labs ++ List.duplicate(class, n_per_class), k}
end)
x_all = Nx.concatenate(all_points)
y_all = Nx.tensor(all_labels)
IO.puts("Shuffling and encoding...")
# Shuffle so classes aren't in order
shuffle_idx = Nx.tensor(Enum.shuffle(0..(n_per_class * length(centers) - 1)))
x_all = Nx.take(x_all, shuffle_idx)
y_all = Nx.take(y_all, shuffle_idx)
# One-hot encode labels for cross-entropy loss: {900} -> {900, 3}
y_onehot =
Nx.equal(Nx.new_axis(y_all, 1), Nx.tensor([[0, 1, 2]]))
|> Nx.as_type(:f32)
# 80/20 train/test split
n_train = 720
train_x = x_all[0..(n_train - 1)]
train_y = y_onehot[0..(n_train - 1)]
test_x = x_all[n_train..-1//1]
test_y = y_onehot[n_train..-1//1]
test_labels = y_all[n_train..-1//1]
# Batch training data for the training loop
batch_size = 32
IO.puts("Batching training data...")
train_data =
Enum.zip(
Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
Nx.to_batched(train_y, batch_size) |> Enum.to_list()
)
IO.puts("Ready: #{n_train} train / #{Nx.axis_size(test_x, 0)} test samples, #{length(train_data)} batches/epoch")
Let’s visualize the clusters to see what the model needs to learn:
to_chart_data = fn x_tensor, labels_tensor ->
xs = Nx.to_flat_list(x_tensor[[.., 0]])
ys = Nx.to_flat_list(x_tensor[[.., 1]])
labs = Nx.to_flat_list(labels_tensor)
Enum.zip_with([xs, ys, labs], fn [x, y, label] ->
%{"x" => x, "y" => y, "class" => "Class #{trunc(label)}"}
end)
end
Vl.new(width: 500, height: 400, title: "Synthetic 3-Class Dataset")
|> Vl.data_from_values(to_chart_data.(x_all, y_all))
|> Vl.mark(:circle, size: 40, opacity: 0.6)
|> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:color, "class", type: :nominal)
Build the Model
Edifice builds the backbone — the feature-extraction layers. We then pipe it into a task-specific head (here, a 3-class softmax classifier).
This is the core pattern: Edifice backbone → Axon head → Axon training.
# MLP backbone: 2 input features → 32 → 16 hidden units
backbone =
Edifice.Feedforward.MLP.build(
input_size: 2,
hidden_sizes: [32, 16],
activation: :relu,
dropout: 0.0
)
# Classification head: 16 → 3 classes with softmax
model =
backbone
|> Axon.dense(3, name: "output")
|> Axon.activation(:softmax)
model
Train
Training in Elixir uses three composable pieces:
| Piece | Library | Role |
|---|---|---|
| Model | Axon (via Edifice) | Defines the computation graph |
| Loss | Axon.Losses | Measures prediction error |
| Optimizer | Polaris | Updates weights via gradient descent |
Axon.Loop.trainer/3 wires these together into a training loop.
Axon.Loop.run/4 executes it, returning the trained parameters.
IO.puts("Training MLP (10 epochs)...")
trained_state =
model
|> Axon.Loop.trainer(
:categorical_cross_entropy,
Polaris.Optimizers.adam(learning_rate: 1.0e-2)
)
|> Axon.Loop.metric(:accuracy)
# Increase to 30-50 for lower loss; 10 is enough to see convergence
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)
IO.puts("Training complete!")
Evaluate
Let’s see how the trained model performs on the held-out test set.
# Build a prediction function (inference mode disables dropout)
{_init_fn, predict_fn} = Axon.build(model)
# Run predictions on all test samples at once
test_preds = predict_fn.(trained_state, test_x)
# Compare predicted vs true class
predicted_classes = Nx.argmax(test_preds, axis: 1)
true_classes = Nx.argmax(test_y, axis: 1)
accuracy =
Nx.equal(predicted_classes, true_classes)
|> Nx.mean()
|> Nx.to_number()
IO.puts("Test accuracy: #{Float.round(accuracy * 100, 1)}%")
Now let’s visualize the decision boundary — the regions where the model predicts each class. We create a grid of points, classify each one, and overlay the test data on top.
# Create a grid covering the data space
resolution = 60
grid_points =
for gx <- 0..(resolution - 1), gy <- 0..(resolution - 1) do
[-4.0 + 8.0 * gx / (resolution - 1), -4.0 + 8.0 * gy / (resolution - 1)]
end
grid_tensor = Nx.tensor(grid_points)
grid_preds = predict_fn.(trained_state, grid_tensor)
grid_classes = Nx.argmax(grid_preds, axis: 1) |> Nx.to_flat_list()
grid_data =
Enum.zip_with([grid_points, grid_classes], fn [[x, y], class] ->
%{"x" => x, "y" => y, "class" => "Class #{trunc(class)}"}
end)
test_chart_data = to_chart_data.(test_x, test_labels)
Vl.new(width: 500, height: 400, title: "Decision Boundary")
|> Vl.layers([
# Background: predicted class regions
Vl.new()
|> Vl.data_from_values(grid_data)
|> Vl.mark(:square, size: 30, opacity: 0.3)
|> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:color, "class", type: :nominal),
# Foreground: actual test points
Vl.new()
|> Vl.data_from_values(test_chart_data)
|> Vl.mark(:circle, size: 50, stroke: "black", stroke_width: 1)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "class", type: :nominal)
])
Swap Architecture
One of Edifice’s key features is the unified registry. Every architecture
uses the same Edifice.build(:name, opts) API, so swapping backbones is a
one-line change while the training loop stays identical.
# These two are equivalent:
mlp_direct = Edifice.Feedforward.MLP.build(input_size: 2, hidden_sizes: [32, 16])
mlp_registry = Edifice.build(:mlp, input_size: 2, hidden_sizes: [32, 16])
IO.puts("Direct build == Registry build: same API, same result")
IO.puts("Registry has #{length(Edifice.list_architectures())} architectures available")
Let’s try TabNet — an attention-based architecture designed for tabular data. Same training loop, different backbone:
# Swap the backbone: MLP → TabNet
IO.puts("Building TabNet model...")
tabnet_model =
Edifice.build(:tabnet, input_size: 2, hidden_size: 16, num_steps: 3)
|> Axon.dense(3, name: "output")
|> Axon.activation(:softmax)
IO.puts("Training TabNet (10 epochs)...")
tabnet_state =
tabnet_model
|> Axon.Loop.trainer(
:categorical_cross_entropy,
Polaris.Optimizers.adam(learning_rate: 1.0e-2)
)
|> Axon.Loop.metric(:accuracy)
# Increase to 30-50 for lower loss; 10 is enough to see convergence
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)
# Evaluate TabNet
{_init_fn, tabnet_predict} = Axon.build(tabnet_model)
tabnet_preds = tabnet_predict.(tabnet_state, test_x)
tabnet_acc =
Nx.equal(Nx.argmax(tabnet_preds, axis: 1), Nx.argmax(test_y, axis: 1))
|> Nx.mean()
|> Nx.to_number()
IO.puts("MLP test accuracy: #{Float.round(accuracy * 100, 1)}%")
IO.puts("TabNet test accuracy: #{Float.round(tabnet_acc * 100, 1)}%")
IO.puts("\nSame data, same training loop — just a different backbone.")
Why does TabNet underperform here? TabNet’s core innovation is sparse feature attention — at each decision step, it learns which input features to focus on. This is powerful when you have dozens or hundreds of features (e.g. medical records, financial data) and only a few matter for each prediction. But with just 2 input features, the attention mechanism has almost nothing to select between, so it adds complexity without benefit. An MLP handles low-dimensional classification trivially with simple dense layers.
This is a key insight: architecture choice depends on the problem. TabNet
would outshine an MLP on a 50-feature dataset where only a handful of features
are relevant per sample. The unified Edifice.build/2 API makes it easy to
experiment and find the right fit.
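To make that experiment loop concrete, here is a hedged sketch that trains and scores several registry backbones against the same data and head. It reuses only the calls and option lists already shown in this notebook; other architectures will need their own options.

```elixir
# A sketch of an architecture-comparison loop. The option lists mirror
# the :mlp and :tabnet builds used earlier in this notebook.
configs = [
  {:mlp, [input_size: 2, hidden_sizes: [32, 16]]},
  {:tabnet, [input_size: 2, hidden_size: 16, num_steps: 3]}
]

for {arch, opts} <- configs do
  # Same head and training loop for every backbone
  model =
    Edifice.build(arch, opts)
    |> Axon.dense(3, name: "output")
    |> Axon.activation(:softmax)

  state =
    model
    |> Axon.Loop.trainer(
      :categorical_cross_entropy,
      Polaris.Optimizers.adam(learning_rate: 1.0e-2)
    )
    |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)

  # Score on the held-out test set
  {_init_fn, predict_fn} = Axon.build(model)

  acc =
    predict_fn.(state, test_x)
    |> Nx.argmax(axis: 1)
    |> Nx.equal(Nx.argmax(test_y, axis: 1))
    |> Nx.mean()
    |> Nx.to_number()

  IO.puts("#{arch}: #{Float.round(acc * 100, 1)}% test accuracy")
end
```

Adding a new entry to `configs` is the whole cost of trying another backbone.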
What’s Next?
Now that you’ve seen the pattern, here are some things to explore:
- Try other architectures: Edifice.list_architectures() shows all 90+ options. Sequence models like :mamba or :griffin expect 3D input {batch, seq_len, features} — reshape your data accordingly.
- Add EXLA for GPU acceleration: Add {:exla, "~> 0.10"} to Mix.install and set Nx.global_default_backend(EXLA.Backend) for 10-100x speedups on real datasets.
- Use real data: Add {:scidata, "~> 0.1"} and load MNIST, CIFAR-10, or other standard benchmarks.
- Customize the training loop: Axon.Loop supports custom metrics, learning rate schedules, early stopping, checkpointing, and more. See the Axon guides.
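The sequence-model reshape mentioned above can be sketched in one line. This assumes the tensors from earlier cells; the commented :mamba options are assumptions, not the library's confirmed API, so check the module docs before using them.

```elixir
# Lift tabular {batch, features} data to {batch, seq_len, features}
# with seq_len = 1 so a sequence backbone can consume it.
seq_train_x = Nx.new_axis(train_x, 1)
IO.inspect(Nx.shape(seq_train_x), label: "sequence input shape")

# Hypothetical build — option names for :mamba are assumptions:
# seq_model = Edifice.build(:mamba, input_size: 2, hidden_size: 16)
```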