Training an MLP with Edifice

notebooks/training_mlp.livemd

Setup

Choose one of the two cells below depending on how you started Livebook.

Standalone (default)

Use this if you started Livebook normally (livebook server). Uncomment the EXLA lines for GPU acceleration.

edifice_dep =
  if File.dir?(Path.expand("~/edifice")) do
    {:edifice, path: Path.expand("~/edifice")}
  else
    {:edifice, "~> 0.2.0"}
  end

Mix.install([
  edifice_dep,
  # {:exla, "~> 0.10"},
  {:kino_vega_lite, "~> 0.1"},
  {:pythonx, "~> 0.4.2"},
  {:kino_pythonx, "~> 0.1.0"}
])

# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl

Attached to project (recommended for Nix/CUDA)

Use this if you started Livebook via ./scripts/livebook.sh. See the Architecture Zoo notebook for full setup instructions.

Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")

Introduction

This notebook demonstrates end-to-end neural network training using Edifice for model architecture, Axon for the training loop, and Polaris for optimization.

What you’ll learn:

  • How Edifice models compose with Axon’s training API
  • Generating and batching data for training
  • Building a training loop with loss, optimizer, and metrics
  • Evaluating and visualizing a trained classifier
  • Swapping architectures with one line of code

No external datasets required — we generate synthetic data so the notebook runs instantly with zero extra dependencies.

Generate Synthetic Data

We create a 3-class classification problem: three Gaussian clusters arranged in a triangle pattern. Each cluster has 300 points in 2D space.

key = Nx.Random.key(42)
n_per_class = 300

# Three cluster centers forming a triangle
centers = [[-1.5, -1.0], [1.5, -1.0], [0.0, 1.5]]

IO.puts("Generating #{n_per_class * length(centers)} points across #{length(centers)} classes...")

# Generate points with Gaussian noise (std=0.6) around each center
{all_points, all_labels, _key} =
  Enum.reduce(Enum.with_index(centers), {[], [], key}, fn {[cx, cy], class}, {pts, labs, k} ->
    {noise, k} = Nx.Random.normal(k, shape: {n_per_class, 2})
    points = Nx.add(Nx.multiply(noise, 0.6), Nx.tensor([cx, cy]))
    {pts ++ [points], labs ++ List.duplicate(class, n_per_class), k}
  end)

x_all = Nx.concatenate(all_points)
y_all = Nx.tensor(all_labels)

IO.puts("Shuffling and encoding...")

# Shuffle so classes aren't in order
n_total = n_per_class * length(centers)
shuffle_idx = Nx.tensor(Enum.shuffle(0..(n_total - 1)))
x_all = Nx.take(x_all, shuffle_idx)
y_all = Nx.take(y_all, shuffle_idx)

# One-hot encode labels for cross-entropy loss: {900} -> {900, 3}
y_onehot =
  Nx.equal(Nx.new_axis(y_all, 1), Nx.tensor([[0, 1, 2]]))
  |> Nx.as_type(:f32)

# 80/20 train/test split
n_train = div(n_total * 4, 5)

train_x = x_all[0..(n_train - 1)]
train_y = y_onehot[0..(n_train - 1)]
test_x = x_all[n_train..-1//1]
test_y = y_onehot[n_train..-1//1]
test_labels = y_all[n_train..-1//1]

# Batch training data for the training loop
batch_size = 32

IO.puts("Batching training data...")
train_data =
  Enum.zip(
    Nx.to_batched(train_x, batch_size) |> Enum.to_list(),
    Nx.to_batched(train_y, batch_size) |> Enum.to_list()
  )

IO.puts("Ready: #{n_train} train / #{Nx.axis_size(test_x, 0)} test samples, #{length(train_data)} batches/epoch")

Let’s visualize the clusters to see what the model needs to learn:

to_chart_data = fn x_tensor, labels_tensor ->
  xs = Nx.to_flat_list(x_tensor[[.., 0]])
  ys = Nx.to_flat_list(x_tensor[[.., 1]])
  labs = Nx.to_flat_list(labels_tensor)

  Enum.zip_with([xs, ys, labs], fn [x, y, label] ->
    %{"x" => x, "y" => y, "class" => "Class #{trunc(label)}"}
  end)
end

Vl.new(width: 500, height: 400, title: "Synthetic 3-Class Dataset")
|> Vl.data_from_values(to_chart_data.(x_all, y_all))
|> Vl.mark(:circle, size: 40, opacity: 0.6)
|> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
|> Vl.encode_field(:color, "class", type: :nominal)

Build the Model

Edifice builds the backbone — the feature-extraction layers. We then pipe it into a task-specific head (here, a 3-class softmax classifier).

This is the core pattern: Edifice backbone → Axon head → Axon training.

# MLP backbone: 2 input features → 32 → 16 hidden units
backbone =
  Edifice.Feedforward.MLP.build(
    input_size: 2,
    hidden_sizes: [32, 16],
    activation: :relu,
    dropout: 0.0
  )

# Classification head: 16 → 3 classes with softmax
model =
  backbone
  |> Axon.dense(3, name: "output")
  |> Axon.activation(:softmax)

model

Train

Training in Elixir uses three composable pieces:

| Piece     | Library            | Role                                 |
|-----------|--------------------|--------------------------------------|
| Model     | Axon (via Edifice) | Defines the computation graph        |
| Loss      | Axon.Losses        | Measures prediction error            |
| Optimizer | Polaris            | Updates weights via gradient descent |

Axon.Loop.trainer/3 wires these together into a training loop. Axon.Loop.run/4 executes it, returning the trained parameters.

IO.puts("Training MLP (10 epochs)...")

trained_state =
  model
  |> Axon.Loop.trainer(
    :categorical_cross_entropy,
    Polaris.Optimizers.adam(learning_rate: 1.0e-2)
  )
  |> Axon.Loop.metric(:accuracy)
  # Increase to 30-50 for lower loss; 10 is enough to see convergence
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)

IO.puts("Training complete!")

Evaluate

Let’s see how the trained model performs on the held-out test set.

# Build a prediction function (inference mode disables dropout)
{_init_fn, predict_fn} = Axon.build(model)

# Run predictions on all test samples at once
test_preds = predict_fn.(trained_state, test_x)

# Compare predicted vs true class
predicted_classes = Nx.argmax(test_preds, axis: 1)
true_classes = Nx.argmax(test_y, axis: 1)

accuracy =
  Nx.equal(predicted_classes, true_classes)
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("Test accuracy: #{Float.round(accuracy * 100, 1)}%")

Now let’s visualize the decision boundary — the regions where the model predicts each class. We create a grid of points, classify each one, and overlay the test data on top.

# Create a grid covering the data space
resolution = 60

grid_points =
  for gx <- 0..(resolution - 1), gy <- 0..(resolution - 1) do
    [-4.0 + 8.0 * gx / (resolution - 1), -4.0 + 8.0 * gy / (resolution - 1)]
  end

grid_tensor = Nx.tensor(grid_points)
grid_preds = predict_fn.(trained_state, grid_tensor)
grid_classes = Nx.argmax(grid_preds, axis: 1) |> Nx.to_flat_list()

grid_data =
  Enum.zip_with([grid_points, grid_classes], fn [[x, y], class] ->
    %{"x" => x, "y" => y, "class" => "Class #{trunc(class)}"}
  end)

test_chart_data = to_chart_data.(test_x, test_labels)

Vl.new(width: 500, height: 400, title: "Decision Boundary")
|> Vl.layers([
  # Background: predicted class regions
  Vl.new()
  |> Vl.data_from_values(grid_data)
  |> Vl.mark(:square, size: 30, opacity: 0.3)
  |> Vl.encode_field(:x, "x", type: :quantitative, scale: %{domain: [-4, 4]})
  |> Vl.encode_field(:y, "y", type: :quantitative, scale: %{domain: [-4, 4]})
  |> Vl.encode_field(:color, "class", type: :nominal),
  # Foreground: actual test points
  Vl.new()
  |> Vl.data_from_values(test_chart_data)
  |> Vl.mark(:circle, size: 50, stroke: "black", stroke_width: 1)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
  |> Vl.encode_field(:color, "class", type: :nominal)
])

Swap Architecture

One of Edifice’s key features is the unified registry. Every architecture uses the same Edifice.build(:name, opts) API, so swapping backbones is a one-line change while the training loop stays identical.

# These two are equivalent:
mlp_direct = Edifice.Feedforward.MLP.build(input_size: 2, hidden_sizes: [32, 16])
mlp_registry = Edifice.build(:mlp, input_size: 2, hidden_sizes: [32, 16])

IO.puts("Direct build == Registry build: same API, same result")
IO.puts("Registry has #{length(Edifice.list_architectures())} architectures available")

Let’s try TabNet — an attention-based architecture designed for tabular data. Same training loop, different backbone:

# Swap the backbone: MLP → TabNet
IO.puts("Building TabNet model...")
tabnet_model =
  Edifice.build(:tabnet, input_size: 2, hidden_size: 16, num_steps: 3)
  |> Axon.dense(3, name: "output")
  |> Axon.activation(:softmax)

IO.puts("Training TabNet (10 epochs)...")
tabnet_state =
  tabnet_model
  |> Axon.Loop.trainer(
    :categorical_cross_entropy,
    Polaris.Optimizers.adam(learning_rate: 1.0e-2)
  )
  |> Axon.Loop.metric(:accuracy)
  # Increase to 30-50 for lower loss; 10 is enough to see convergence
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)

# Evaluate TabNet
{_init_fn, tabnet_predict} = Axon.build(tabnet_model)
tabnet_preds = tabnet_predict.(tabnet_state, test_x)

tabnet_acc =
  Nx.equal(Nx.argmax(tabnet_preds, axis: 1), Nx.argmax(test_y, axis: 1))
  |> Nx.mean()
  |> Nx.to_number()

IO.puts("MLP test accuracy:    #{Float.round(accuracy * 100, 1)}%")
IO.puts("TabNet test accuracy: #{Float.round(tabnet_acc * 100, 1)}%")
IO.puts("\nSame data, same training loop — just a different backbone.")

Why does TabNet underperform here? TabNet’s core innovation is sparse feature attention — at each decision step, it learns which input features to focus on. This is powerful when you have dozens or hundreds of features (e.g. medical records, financial data) and only a few matter for each prediction. But with just 2 input features, the attention mechanism has almost nothing to select between, so it adds complexity without benefit. An MLP handles low-dimensional classification trivially with simple dense layers.

This is a key insight: architecture choice depends on the problem. TabNet would outshine an MLP on a 50-feature dataset where only a handful of features are relevant per sample. The unified Edifice.build/2 API makes it easy to experiment and find the right fit.
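One way to run that experiment is to sweep over backbones in a loop, reusing the exact training and evaluation steps from above. This is a sketch built only from calls already shown in this notebook; it assumes `train_data`, `test_x`, and `test_y` from the earlier cells are in scope, and the option lists per architecture are the ones used above:

```elixir
# Sweep over backbones; everything after Edifice.build/2 is identical.
results =
  for {name, opts} <- [
        mlp: [input_size: 2, hidden_sizes: [32, 16]],
        tabnet: [input_size: 2, hidden_size: 16, num_steps: 3]
      ] do
    model =
      Edifice.build(name, opts)
      |> Axon.dense(3, name: "output")
      |> Axon.activation(:softmax)

    state =
      model
      |> Axon.Loop.trainer(
        :categorical_cross_entropy,
        Polaris.Optimizers.adam(learning_rate: 1.0e-2)
      )
      |> Axon.Loop.metric(:accuracy)
      |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 10)

    {_init_fn, predict_fn} = Axon.build(model)
    preds = predict_fn.(state, test_x)

    acc =
      Nx.equal(Nx.argmax(preds, axis: 1), Nx.argmax(test_y, axis: 1))
      |> Nx.mean()
      |> Nx.to_number()

    {name, acc}
  end

for {name, acc} <- results do
  IO.puts("#{name}: #{Float.round(acc * 100, 1)}%")
end
```

The only per-architecture knowledge the loop needs is the option list; the head, loss, optimizer, and evaluation code never change.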

What’s Next?

Now that you’ve seen the pattern, here are some things to explore:

  • Try other architectures: Edifice.list_architectures() shows all 90+ options. Sequence models like :mamba or :griffin expect 3D input {batch, seq_len, features} — reshape your data accordingly.

  • Add EXLA for GPU acceleration: Add {:exla, "~> 0.10"} to Mix.install and set Nx.global_default_backend(EXLA.Backend) for 10-100x speedups on real datasets.

  • Use real data: Add {:scidata, "~> 0.1"} and load MNIST, CIFAR-10, or other standard benchmarks.

  • Customize the training loop: Axon.Loop supports custom metrics, learning rate schedules, early stopping, checkpointing, and more. See the Axon guides.
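The reshape mentioned for sequence models can be sketched with plain Nx. This is a minimal illustration of the shape change only, not a working `:mamba` setup; the tensor values are placeholders:

```elixir
# Tabular batch: {batch, features} = {32, 2}
x = Nx.iota({32, 2}, type: :f32)

# Sequence models expect {batch, seq_len, features}.
# Simplest option: treat each sample as a length-1 sequence.
x_seq = Nx.new_axis(x, 1)
# Nx.shape(x_seq) => {32, 1, 2}

# If each row is actually a flattened sequence of per-step features,
# reshape explicitly instead (here: 2 steps of 1 feature each):
x_steps = Nx.reshape(x, {32, 2, 1})
# Nx.shape(x_steps) => {32, 2, 1}
```

Which reshape is right depends on what the features mean: `Nx.new_axis/2` keeps all features in one time step, while `Nx.reshape/2` spreads them across steps.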