Graph Classification with Edifice
Setup
Choose one of the two cells below depending on how you started Livebook.
Standalone (default)
Use this if you started Livebook normally (livebook server).
Uncomment the EXLA lines for GPU acceleration.
edifice_dep =
if File.dir?(Path.expand("~/edifice")) do
{:edifice, path: Path.expand("~/edifice")}
else
{:edifice, "~> 0.2.0"}
end
Mix.install([
edifice_dep,
# {:exla, "~> 0.10"},
{:kino_vega_lite, "~> 0.1"},
{:kino, "~> 0.14"}
])
# Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
Attached to project (recommended for Nix/CUDA)
Use this if you started Livebook via ./scripts/livebook.sh.
See the Architecture Zoo notebook for full setup instructions.
Nx.global_default_backend(EXLA.Backend)
alias VegaLite, as: Vl
IO.puts("Attached mode — using EXLA backend from project node")
Introduction
Graph Neural Networks (GNNs) learn from structured, relational data — nodes connected by edges. Unlike images (grids) or text (sequences), graphs have irregular topology: each node can have a different number of neighbors.
This notebook trains 3 GNN architectures from Edifice on a synthetic graph classification task: given a small graph, predict which class it belongs to.
What you’ll learn:
- How to represent graphs as node features + adjacency matrices
- Building GCN, GAT, and GIN models with Edifice.build/2
- Graph-level classification using global pooling
- Comparing GNN architectures on the same task
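Before generating the dataset, here is the representation in miniature — a sketch with made-up values showing what “node features + adjacency matrix” means in Nx (available via the setup cell above):

```elixir
# A toy undirected graph: 3 nodes, 2 features each (illustrative values).
# Row i of `nodes` is node i's feature vector.
nodes = Nx.tensor([[1.0, 0.5], [0.2, -0.3], [0.8, 0.1]])

# adjacency[i][j] == 1.0 means an edge between nodes i and j.
# Symmetric (undirected) with a zero diagonal (no self-loops).
adjacency = Nx.tensor([
  [0.0, 1.0, 1.0],
  [1.0, 0.0, 0.0],
  [1.0, 0.0, 0.0]
])

# One round of raw message passing: each node sums its neighbors' features.
# Node 0 receives from nodes 1 and 2; nodes 1 and 2 each receive from node 0.
Nx.dot(adjacency, nodes)
```

Every GNN in this notebook consumes exactly this pair, batched as `{batch, num_nodes, input_dim}` and `{batch, num_nodes, num_nodes}`.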
Generate Synthetic Graph Dataset
We create 3 classes of graphs that share random topologies but differ in how node features relate to graph structure. This forces the GNN to actually perform message passing — node features alone aren’t enough.
- Class 0 — “Smooth”: Connected neighbors have similar features (low variation across edges)
- Class 1 — “Noisy”: Features are random with no structural correlation
- Class 2 — “Alternating”: Connected neighbors have opposite-sign features (high variation)
All three classes use the same random graph structure, so the model can’t cheat by looking at topology alone.
num_nodes = 8
input_dim = 4
num_per_class = 100
num_classes = 3
num_graphs = num_per_class * num_classes
edge_prob = 0.4
IO.puts("Generating #{num_graphs} synthetic graphs (#{num_nodes} nodes, #{input_dim} features, #{num_classes} classes)...")
key = Nx.Random.key(42)
# Helper: generate a random connected graph adjacency matrix
make_random_adj = fn k ->
{rand, k} = Nx.Random.uniform(k, shape: {num_nodes, num_nodes})
# Threshold random values to get candidate edges
edges = Nx.less(rand, edge_prob) |> Nx.as_type(:f32)
# Symmetrize (undirected graph) and clamp to {0, 1}; diagonal zeroed below
sym = Nx.add(edges, Nx.transpose(edges)) |> Nx.min(1.0)
diag_mask = Nx.subtract(1.0, Nx.eye(num_nodes))
adj = Nx.multiply(sym, diag_mask)
# Ensure connectivity: add a path through all nodes
adj = Enum.reduce(0..(num_nodes - 2), adj, fn i, a ->
a
|> Nx.put_slice([i, i + 1], Nx.tensor([[1.0]]))
|> Nx.put_slice([i + 1, i], Nx.tensor([[1.0]]))
end)
{adj, k}
end
{all_nodes, all_adjs, all_labels, _key} =
Enum.reduce(0..(num_graphs - 1), {[], [], [], key}, fn i, {nodes_acc, adj_acc, lab_acc, k} ->
class = rem(i, num_classes)
# Same random topology for all classes
{adj, k} = make_random_adj.(k)
# Generate base features
{base_feats, k} = Nx.Random.normal(k, shape: {num_nodes, input_dim})
node_feats =
case class do
0 ->
# Smooth: propagate features through adjacency (like a low-pass filter)
# Neighbors end up with correlated features
degree = Nx.sum(adj, axes: [1], keep_axes: true) |> Nx.max(1.0)
smoothed = Nx.dot(adj, base_feats) |> Nx.divide(degree)
Nx.add(Nx.multiply(smoothed, 0.7), Nx.multiply(base_feats, 0.3))
1 ->
# Noisy: pure random features, no structural correlation
Nx.multiply(base_feats, 0.8)
2 ->
# Alternating: assign +/- pattern, neighbors get opposite signs
# Use a simple 2-coloring approximation
signs = Nx.iota({num_nodes, 1}) |> Nx.remainder(2) |> Nx.multiply(2) |> Nx.subtract(1) |> Nx.as_type(:f32)
Nx.multiply(Nx.abs(base_feats), signs)
end
{nodes_acc ++ [node_feats], adj_acc ++ [adj], lab_acc ++ [class], k}
end)
IO.puts(" Stacking tensors...")
nodes_tensor = Nx.stack(all_nodes)
adj_tensor = Nx.stack(all_adjs)
labels_tensor = Nx.tensor(all_labels)
# One-hot encode
y_onehot =
Nx.equal(Nx.new_axis(labels_tensor, 1), Nx.tensor([Enum.to_list(0..(num_classes - 1))]))
|> Nx.as_type(:f32)
# Shuffle
{shuffle_noise, _k} = Nx.Random.uniform(Nx.Random.key(99), shape: {num_graphs})
shuffle_idx = Nx.argsort(shuffle_noise)
nodes_tensor = Nx.take(nodes_tensor, shuffle_idx)
adj_tensor = Nx.take(adj_tensor, shuffle_idx)
y_onehot = Nx.take(y_onehot, shuffle_idx)
labels_tensor = Nx.take(labels_tensor, shuffle_idx)
# 80/20 split
n_train = round(num_graphs * 0.8)
train_nodes = nodes_tensor[0..(n_train - 1)]
train_adj = adj_tensor[0..(n_train - 1)]
train_y = y_onehot[0..(n_train - 1)]
test_nodes = nodes_tensor[n_train..-1//1]
test_adj = adj_tensor[n_train..-1//1]
test_y = y_onehot[n_train..-1//1]
test_labels = labels_tensor[n_train..-1//1]
# Batch for training
batch_size = 16
IO.puts(" Batching training data...")
train_data =
Enum.zip([
Nx.to_batched(train_nodes, batch_size) |> Enum.to_list(),
Nx.to_batched(train_adj, batch_size) |> Enum.to_list(),
Nx.to_batched(train_y, batch_size) |> Enum.to_list()
])
|> Enum.map(fn {nodes, adj, labels} ->
{%{"nodes" => nodes, "adjacency" => adj}, labels}
end)
IO.puts("Ready: #{n_train} train / #{num_graphs - n_train} test graphs, #{length(train_data)} batches/epoch")
Visualize Example Graphs
Let’s look at one example graph from each class. Nodes are arranged in a circle and edges show the random connectivity. Node color represents the mean feature value — this is what the GNN sees.
Notice:
- Smooth: Connected nodes have similar colors (correlated features)
- Noisy: Colors are random with no pattern relative to edges
- Alternating: Connected nodes tend to have opposite colors
# Pick one example from each class
examples =
Enum.map(0..(num_classes - 1), fn c ->
idx = Nx.to_flat_list(labels_tensor) |> Enum.find_index(&(trunc(&1) == c))
%{
"class" => c,
"nodes" => nodes_tensor[idx],
"adj" => adj_tensor[idx]
}
end)
class_names = %{0 => "Smooth", 1 => "Noisy", 2 => "Alternating"}
# Lay out nodes in a circle
node_positions =
Enum.map(0..(num_nodes - 1), fn i ->
angle = 2 * :math.pi() * i / num_nodes
{Float.round(:math.cos(angle), 3), Float.round(:math.sin(angle), 3)}
end)
graph_charts =
Enum.map(examples, fn ex ->
adj_list = Nx.to_flat_list(Nx.reshape(ex["adj"], {num_nodes * num_nodes}))
node_means = ex["nodes"] |> Nx.mean(axes: [1]) |> Nx.to_flat_list()
# Build edge data from adjacency matrix (upper triangle only to avoid duplicates)
edge_data =
for i <- 0..(num_nodes - 2), j <- (i + 1)..(num_nodes - 1),
Enum.at(adj_list, i * num_nodes + j) > 0.5 do
{x1, y1} = Enum.at(node_positions, i)
{x2, y2} = Enum.at(node_positions, j)
%{"x" => x1, "y" => y1, "x2" => x2, "y2" => y2}
end
# Build node data with feature coloring
node_data =
Enum.with_index(node_positions, fn {x, y}, i ->
%{"x" => x, "y" => y, "node" => "N#{i}",
"feature" => Float.round(Enum.at(node_means, i), 2)}
end)
n_edges = length(edge_data)
Vl.new(width: 220, height: 220,
title: "#{class_names[ex["class"]]} (#{n_edges} edges)")
|> Vl.layers([
Vl.new()
|> Vl.data_from_values(edge_data)
|> Vl.mark(:rule, color: "#ccc", stroke_width: 1.5)
|> Vl.encode_field(:x, "x", type: :quantitative,
scale: %{domain: [-1.5, 1.5]}, axis: nil)
|> Vl.encode_field(:y, "y", type: :quantitative,
scale: %{domain: [-1.5, 1.5]}, axis: nil)
|> Vl.encode_field(:x2, "x2")
|> Vl.encode_field(:y2, "y2"),
Vl.new()
|> Vl.data_from_values(node_data)
|> Vl.mark(:circle, size: 350, stroke: "black", stroke_width: 1)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "feature", type: :quantitative,
scale: %{scheme: "redblue", domain_mid: 0},
legend: %{title: "Feature"}),
Vl.new()
|> Vl.data_from_values(node_data)
|> Vl.mark(:text, font_size: 10, font_weight: "bold", dy: 0)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:text, "node", type: :nominal)
])
end)
Vl.new()
|> Vl.concat(graph_charts, :horizontal)
Feature Distribution per Class
The bar charts below show per-node mean feature values for the same example graphs. This makes the class differences even clearer:
feature_charts =
Enum.map(examples, fn ex ->
node_means =
ex["nodes"]
|> Nx.mean(axes: [1])
|> Nx.to_flat_list()
|> Enum.with_index()
|> Enum.map(fn {val, i} -> %{"node" => "N#{i}", "feature_mean" => val} end)
Vl.new(width: 200, height: 180, title: "Class #{ex["class"]}: #{class_names[ex["class"]]}")
|> Vl.data_from_values(node_means)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "node", type: :nominal)
|> Vl.encode_field(:y, "feature_mean", type: :quantitative, title: "Mean Feature")
|> Vl.encode_field(:color, "feature_mean", type: :quantitative,
scale: %{scheme: "redblue", domain_mid: 0})
end)
Vl.new()
|> Vl.concat(feature_charts, :horizontal)
Helper: Train and Evaluate a Graph Model
GNN models output per-node features {batch, num_nodes, hidden_dim}.
For graph-level classification, we apply global mean pooling to get
a single vector per graph, then feed it through a classification head.
defmodule GraphTrainer do
def train_and_eval(model, train_data, test_nodes, test_adj, test_y, opts \\ []) do
epochs = Keyword.get(opts, :epochs, 10)
lr = Keyword.get(opts, :lr, 1.0e-2)
trained_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: lr))
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: epochs)
{_init_fn, predict_fn} = Axon.build(model)
preds = predict_fn.(trained_state, %{"nodes" => test_nodes, "adjacency" => test_adj})
accuracy =
Nx.equal(Nx.argmax(preds, axis: 1), Nx.argmax(test_y, axis: 1))
|> Nx.mean()
|> Nx.to_number()
{trained_state, preds, accuracy}
end
end
Train GCN (Graph Convolutional Network)
GCN learns node representations by aggregating features from neighbors using spectral-style convolutions with symmetric normalization.
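The propagation rule behind a GCN layer, H′ = σ(D̂⁻¹ᐟ² Â D̂⁻¹ᐟ² H W), can be written out directly in Nx. This is a sketch of the math on toy values, not Edifice’s internal implementation; the learned weight matrix and activation are omitted for clarity:

```elixir
# Toy inputs: 3-node graph, 2 features per node (illustrative values)
h = Nx.tensor([[1.0, 0.5], [0.2, -0.3], [0.8, 0.1]])
adj = Nx.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

# A_hat = A + I: add self-loops so each node keeps its own features
a_hat = Nx.add(adj, Nx.eye(3))

# Symmetric normalization: D^{-1/2} A_hat D^{-1/2}
d_inv_sqrt = a_hat |> Nx.sum(axes: [1]) |> Nx.rsqrt()

norm_adj =
  a_hat
  |> Nx.multiply(Nx.new_axis(d_inv_sqrt, 1))  # scale row i by 1/sqrt(deg_i)
  |> Nx.multiply(Nx.new_axis(d_inv_sqrt, 0))  # scale col j by 1/sqrt(deg_j)

# One propagation step (weights and activation omitted)
Nx.dot(norm_adj, h)
```

High-degree nodes are down-weighted on both sides of the aggregation, which is what makes GCN act like a low-pass filter over the graph — well matched to the “smooth” class.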
IO.puts("Building GCN model...")
gcn_backbone =
Edifice.build(:gcn,
input_dim: input_dim,
hidden_dims: [32, 16],
activation: :relu,
dropout: 0.0
)
# Global mean pool + classification head
gcn_model =
gcn_backbone
|> Axon.nx(fn x -> Nx.mean(x, axes: [1]) end, name: "gcn_pool")
|> Axon.dense(num_classes, name: "gcn_head")
|> Axon.activation(:softmax)
# Increase to 30-50 for better accuracy; 10 is enough to see convergence
IO.puts("Training GCN (10 epochs)...")
{gcn_state, gcn_preds, gcn_acc} =
GraphTrainer.train_and_eval(gcn_model, train_data, test_nodes, test_adj, test_y, epochs: 10)
IO.puts("GCN test accuracy: #{Float.round(gcn_acc * 100, 1)}%")
Train GAT (Graph Attention Network)
GAT uses multi-head attention to learn which neighbors matter more. Instead of treating all neighbors equally (like GCN), it computes attention weights for each edge.
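The core mechanism — softmax-normalized attention restricted to each node’s neighbors — can be sketched with Nx. A real GAT layer computes scores from learned projections and a LeakyReLU; here a plain dot-product score stands in, so this is illustration only:

```elixir
h = Nx.tensor([[1.0, 0.5], [0.2, -0.3], [0.8, 0.1]])
adj = Nx.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

# Raw pairwise scores (GAT uses a learned attention function here)
scores = Nx.dot(h, Nx.transpose(h))

# Mask non-edges with a large negative value so softmax assigns them ~0 weight
masked = Nx.select(Nx.greater(adj, 0.0), scores, -1.0e9)

# Row-wise softmax: attention weights over each node's neighbors
exp = Nx.exp(Nx.subtract(masked, Nx.reduce_max(masked, axes: [1], keep_axes: true)))
alpha = Nx.divide(exp, Nx.sum(exp, axes: [1], keep_axes: true))

# Weighted neighbor aggregation: each row of alpha sums to 1
Nx.dot(alpha, h)
```

Compare with GCN: the aggregation weights there are fixed by node degrees, while `alpha` here depends on the features themselves.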
IO.puts("Building GAT model...")
gat_backbone =
Edifice.build(:gat,
input_dim: input_dim,
hidden_size: 8,
num_heads: 4,
# backbone output width here, not the label count — the real class
# count is applied by the dense head below
num_classes: 16,
num_layers: 2,
activation: :elu,
dropout: 0.0
)
# Global mean pool + classification head
gat_model =
gat_backbone
|> Axon.nx(fn x -> Nx.mean(x, axes: [1]) end, name: "gat_pool")
|> Axon.dense(num_classes, name: "gat_head")
|> Axon.activation(:softmax)
IO.puts("Training GAT (10 epochs)...")
{gat_state, gat_preds, gat_acc} =
GraphTrainer.train_and_eval(gat_model, train_data, test_nodes, test_adj, test_y, epochs: 10)
IO.puts("GAT test accuracy: #{Float.round(gat_acc * 100, 1)}%")
Train GIN (Graph Isomorphism Network)
GIN is provably as expressive as the WL (Weisfeiler-Leman) graph isomorphism test — the upper bound for message-passing GNNs. It uses MLPs instead of simple linear transforms and a learnable self-weight parameter epsilon.
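GIN’s node update, h′ = MLP((1 + ε)·h + Σ over neighbors of their features, is easy to sketch in Nx. This shows the aggregation only — the MLP and the learned ε live inside Edifice’s layer:

```elixir
h = Nx.tensor([[1.0, 0.5], [0.2, -0.3], [0.8, 0.1]])
adj = Nx.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

epsilon = 0.1  # a learned parameter when epsilon_learnable: true

# Sum (not mean) aggregation preserves multiset information — the key
# to GIN's expressiveness; (1 + epsilon) weights the node's own features.
neighbor_sum = Nx.dot(adj, h)
agg = Nx.add(Nx.multiply(1.0 + epsilon, h), neighbor_sum)

# In a full GIN layer, `agg` then passes through an MLP (dense → relu → dense).
agg
```

Mean or max aggregation would map different neighbor multisets to the same result; summation keeps them distinct, which is where GIN’s theoretical edge comes from.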
IO.puts("Building GIN model...")
gin_backbone =
Edifice.build(:gin,
input_dim: input_dim,
hidden_dims: [32, 16],
epsilon_learnable: true,
activation: :relu,
dropout: 0.0
)
# Global mean pool + classification head
gin_model =
gin_backbone
|> Axon.nx(fn x -> Nx.mean(x, axes: [1]) end, name: "gin_pool")
|> Axon.dense(num_classes, name: "gin_head")
|> Axon.activation(:softmax)
IO.puts("Training GIN (10 epochs)...")
{gin_state, gin_preds, gin_acc} =
GraphTrainer.train_and_eval(gin_model, train_data, test_nodes, test_adj, test_y, epochs: 10)
IO.puts("GIN test accuracy: #{Float.round(gin_acc * 100, 1)}%")
Compare Results
All three GNNs were trained on the same data with the same hyperparameters. The only difference is the message passing mechanism — how each model aggregates information from neighboring nodes.
The task requires detecting feature-topology correlations:
- Smooth graphs have similar features on connected nodes
- Noisy graphs have random, uncorrelated features
- Alternating graphs have opposite-sign features on connected nodes
A model that can’t look at neighbor relationships (e.g., a simple MLP on pooled features) would struggle. GNNs that better capture edge-level patterns should perform better.
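To see why, consider an edge-level statistic a GNN can learn but a pooled-features MLP cannot: the mean squared feature difference across edges. A sketch on made-up feature vectors (not drawn from the dataset above):

```elixir
adj = Nx.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # path graph

smooth = Nx.tensor([[0.5], [0.5], [0.5]])        # neighbors agree
alternating = Nx.tensor([[0.5], [-0.5], [0.5]])  # neighbors disagree

edge_variation = fn feats ->
  # Pairwise differences: {n, n, d}
  diff = Nx.subtract(Nx.new_axis(feats, 1), Nx.new_axis(feats, 0))
  # Squared differences summed over features, masked to edges, averaged
  sq = diff |> Nx.pow(2) |> Nx.sum(axes: [-1])
  Nx.divide(Nx.sum(Nx.multiply(adj, sq)), Nx.sum(adj)) |> Nx.to_number()
end

{edge_variation.(smooth), edge_variation.(alternating)}
# → {0.0, 1.0}: low for "smooth", high for "alternating" — a signal only
# visible by comparing features across edges, i.e. via message passing.
```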
results = [
{"GCN", "spectral conv", gcn_acc},
{"GAT", "attention", gat_acc},
{"GIN", "WL-expressive MLP", gin_acc}
]
ranked = Enum.sort_by(results, &elem(&1, 2), :desc)
IO.puts("=" |> String.duplicate(55))
IO.puts(" #{String.pad_trailing("Rank", 6)}#{String.pad_trailing("Model", 10)}#{String.pad_trailing("Mechanism", 22)}Accuracy")
IO.puts(" " <> String.duplicate("-", 50))
ranked
|> Enum.with_index(1)
|> Enum.each(fn {{name, mechanism, acc}, rank} ->
IO.puts(
" #{String.pad_trailing("##{rank}", 6)}" <>
"#{String.pad_trailing(name, 10)}" <>
"#{String.pad_trailing(mechanism, 22)}" <>
"#{Float.round(acc * 100, 1)}%"
)
end)
IO.puts("=" |> String.duplicate(55))
The bar chart below shows test accuracy for each architecture. Higher bars mean the model better distinguishes the three feature pattern classes. The color indicates the aggregation mechanism used.
chart_data =
results
|> Enum.map(fn {name, mechanism, acc} ->
%{"Model" => name, "Accuracy" => Float.round(acc * 100, 1), "Mechanism" => mechanism}
end)
Vl.new(width: 400, height: 250, title: "Graph Classification: Detecting Feature-Topology Patterns")
|> Vl.data_from_values(chart_data)
|> Vl.mark(:bar)
|> Vl.encode_field(:x, "Model", type: :nominal, sort: "-y")
|> Vl.encode_field(:y, "Accuracy", type: :quantitative, scale: %{domain: [0, 100]}, title: "Test Accuracy (%)")
|> Vl.encode_field(:color, "Mechanism", type: :nominal)
Key Takeaways
- Message passing matters: All three classes share the same random topology. The signal is in how features relate across edges — the model must aggregate neighbor information to distinguish smooth, noisy, and alternating patterns.
- Different aggregation, different strengths: GCN, GAT, and GIN all do message passing but weight neighbors differently:
  - GCN: Fixed weights from degree normalization — good at detecting smoothness
  - GAT: Learned attention weights — adapts to which neighbors are informative
  - GIN: MLP-based with learnable self-weight — most theoretically expressive
- Global pooling bridges node → graph: Node-level GNNs produce per-node features. For graph-level tasks, global mean pooling compresses all node features into a single graph representation.
- Same Edifice API: Every model was built with Edifice.build/2 and trained with the same loop — swapping GNN architectures is trivial.
What’s Next?
- Try more GNN architectures: :graph_sage, :pna, :graph_transformer, :schnet — all use the same input format.
- Node classification: Skip the global pooling step and classify individual nodes (e.g., community detection in social networks).
- Larger graphs: Scale up num_nodes and add more complex topologies (trees, grids, random graphs with planted communities).
- Real datasets: Use molecular graphs, citation networks, or social graphs.