Saving and Loading Models
```elixir
Mix.install([
  {:axon, "~> 0.8"}
])
```
Overview
Axon recommends a parameters-only approach to saving models: serialize only the trained parameters (weights) with Nx.serialize/2, load them back with Nx.deserialize/2, and keep the model definition in your code. This approach:
- Avoids serialization issues with anonymous functions and complex model structures
- Makes the model structure explicit and version-controlled in code
- Works reliably across processes and deployments
The model itself is just code: you define it once and reuse it. Only the learned parameters need to be persisted.
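Because the model is plain code, one way to guarantee that training and inference use the same structure is to put the definition in a function and call it from both places. This is a minimal sketch, assuming the setup cell above has installed Axon; the module name MyApp.Regressor is hypothetical.

```elixir
# Hypothetical module name: any module works, as long as training and
# inference both build the model by calling the same function.
defmodule MyApp.Regressor do
  @doc "Builds the network used in this guide: 1 input feature, 1 output."
  def build do
    Axon.input("data")
    |> Axon.dense(8)
    |> Axon.relu()
    |> Axon.dense(4)
    |> Axon.relu()
    |> Axon.dense(1)
  end
end
```

Both the training cell and the inference cell would then start with `model = MyApp.Regressor.build()`, so a change to the architecture cannot silently diverge between them.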
Saving a Model After Training
When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them:
```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

# Endless stream of {input, target} batches: 8 samples of 1 feature, target = sin(x)
train_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      Nx.Random.key(System.unique_integer([:positive]))
      |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

trained_model_state =
  Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
```
The returned value is an Axon.ModelState struct, the default output of a loop built with Axon.Loop.trainer/3. For inference you need only the parameters, which are stored in its data field:
```elixir
# Extract the parameters: trained_model_state.data is a nested map of weight tensors
params = trained_model_state.data

# Serialize the parameter map and write it to disk
params_bytes = Nx.serialize(params)
File.write!("model_params.axon", params_bytes)
```
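Nx.serialize/2 accepts a container of tensors, such as the nested map that makes up a parameter set, and Nx.deserialize/2 restores the same nesting. The sketch below checks this with a small hand-built map; the layer names ("dense_0", "dense_1") and shapes are illustrative assumptions, and the example_* variable names avoid overwriting the trained params.

```elixir
# Illustrative parameter map; layer names and shapes are assumptions.
example_params = %{
  "dense_0" => %{"kernel" => Nx.iota({1, 8}, type: :f32), "bias" => Nx.broadcast(0.0, {8})},
  "dense_1" => %{"kernel" => Nx.iota({8, 1}, type: :f32), "bias" => Nx.tensor([0.5])}
}

# Serialize to a binary, then immediately deserialize it again.
restored_params = example_params |> Nx.serialize() |> Nx.deserialize()

# Check that every tensor came back at the same place with the same values.
round_trip_ok =
  for {layer, tensors} <- example_params, {name, tensor} <- tensors, reduce: true do
    acc -> acc and Nx.to_number(Nx.all_close(tensor, restored_params[layer][name])) == 1
  end
```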
Loading a Model for Inference
To load and run inference, you need:
- The model definition, in code, with the same structure you trained
- The saved parameters
```elixir
# 1. Define the same model structure (must match training)
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# 2. Load the parameters
params = File.read!("model_params.axon") |> Nx.deserialize()

# 3. Run inference on 1 sample with 1 feature: shape {1, 1}, matching the model input
input = Nx.tensor([[1.0]])
Axon.predict(model, params, %{"data" => input})
```
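To confirm that the parameters-only round trip loses nothing, you can compare predictions from a set of parameters before and after saving and loading them. This sketch assumes Axon is installed by the setup cell; it initializes untrained parameters with Axon.build/2 instead of training, and the file name demo_params.axon and the demo_* variables are illustrative. The input batch has shape {3, 1}, showing that the batch dimension can be any size.

```elixir
# Same architecture as above, rebuilt so the cell stands alone.
demo_model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# Initialize (untrained) parameters from a {batch, features} template.
{init_fn, _predict_fn} = Axon.build(demo_model)
demo_state = init_fn.(Nx.template({1, 1}, :f32), Axon.ModelState.empty())

# Save only the parameter map, then load it back from disk.
demo_path = Path.join(System.tmp_dir!(), "demo_params.axon")
File.write!(demo_path, Nx.serialize(demo_state.data))
loaded_params = demo_path |> File.read!() |> Nx.deserialize()

# A batch of 3 samples with 1 feature each: shape {3, 1}.
batch = Nx.tensor([[-1.0], [0.0], [1.0]])
original_out = Axon.predict(demo_model, demo_state, %{"data" => batch})
reloaded_out = Axon.predict(demo_model, loaded_params, %{"data" => batch})
```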
Checkpointing During Training
To save checkpoints during training (for example, every epoch or when validation improves), use Axon.Loop.checkpoint/2. This serializes the full loop state, including the model parameters and optimizer state, so you can resume training later.
```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # Save the loop state to the "checkpoints" directory after every epoch
  |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)

train_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      Nx.Random.key(System.unique_integer([:positive]))
      |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```
Checkpoints are written to the checkpoints/ directory configured above, one file per epoch, named checkpoint_<epoch>_<iteration>.ckpt by default. Each file contains the loop state serialized with Axon.Loop.serialize_state/2.
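The loop above saves a checkpoint after every epoch. To save only when a validation metric improves, Axon.Loop.checkpoint/2 also accepts :criteria and :mode options. This is a sketch under two assumptions: that Axon.Loop.validate/4 records the trainer's loss under the name "validation_loss", and that model and train_data come from the previous cell. The directory name best_checkpoints and the five-batch validation set are illustrative choices.

```elixir
# Finite validation set: 5 batches drawn like train_data (must be finite,
# otherwise validation would never finish).
validation_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      Nx.Random.key(System.unique_integer([:positive]))
      |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)
  |> Enum.take(5)

best_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # Attach validation first so "validation_loss" exists when the checkpoint
  # handler evaluates its criteria at the end of each epoch.
  |> Axon.Loop.validate(model, validation_data)
  # Save only when validation loss reaches a new minimum.
  |> Axon.Loop.checkpoint(path: "best_checkpoints", criteria: "validation_loss", mode: :min)

Axon.Loop.run(best_loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```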
Resuming from a Checkpoint
To resume training from a saved checkpoint:
- Load the checkpoint with Axon.Loop.deserialize_state/2.
- Attach it to your loop with Axon.Loop.from_state/2.
- Run the loop as usual.
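Instead of hardcoding a file name such as checkpoint_2_50.ckpt, you could pick the newest checkpoint in the directory. This hypothetical helper is plain Elixir with no Axon calls, and it assumes the default checkpoint_<epoch>_<iteration>.ckpt naming; if your checkpoint files are named differently, adjust the regular expression.

```elixir
# Hypothetical helper: assumes files named checkpoint_<epoch>_<iteration>.ckpt.
defmodule CheckpointFinder do
  @doc "Returns the path of the newest checkpoint in `dir`, or nil if there is none."
  def latest(dir) do
    candidates =
      for name <- File.ls!(dir),
          # Non-matching names (Regex.run returns nil) are skipped by the pattern.
          [_, epoch, iteration] <- [Regex.run(~r/^checkpoint_(\d+)_(\d+)\.ckpt$/, name)] do
        # Compare numerically so that epoch 10 sorts after epoch 2.
        {{String.to_integer(epoch), String.to_integer(iteration)}, name}
      end

    case candidates |> Enum.sort() |> List.last() do
      nil -> nil
      {_key, name} -> Path.join(dir, name)
    end
  end
end
```

With this helper, `checkpoint_path = CheckpointFinder.latest("checkpoints")` selects the file with the highest epoch (and, within an epoch, the highest iteration), and returns nil when no matching file exists.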
```elixir
# Load the checkpoint (use the path of one of your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Rebuild the same model and attach the restored state to a new loop
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.from_state(state)

# Resume training (train_data is the stream defined in the previous section)
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
```
Saving Only Parameters from a Checkpoint
If you have a checkpoint file and want to extract parameters for inference (without optimizer state):
```elixir
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()

# The trainer's step state holds the model state alongside the optimizer state
%{model_state: model_state} = state.step_state
params = model_state.data

# Save only the parameters for inference
File.write!("model_params.axon", Nx.serialize(params))
```
Summary
| Use Case | Save | Load |
|---|---|---|
| Inference only | Nx.serialize(params) → file | Nx.deserialize(file) + model in code |
| Checkpoint (resume training) | Axon.Loop.checkpoint/2 or Axon.Loop.serialize_state/2 | Axon.Loop.deserialize_state/2 + Axon.Loop.from_state/2 |
| Extract params from checkpoint | state.step_state.model_state.data → Nx.serialize | Use with model in code |