Saving and Loading Models

Mix.install([
  {:axon, "~> 0.8"}
])

Overview

Axon recommends a parameters-only approach to saving models: serialize only the trained parameters (weights) with Nx.serialize/2, restore them with Nx.deserialize/2, and keep the model definition in your code. This approach:

  • Avoids serialization issues with anonymous functions and complex model structures
  • Makes the model structure explicit and version-controlled in code
  • Works reliably across processes and deployments

The model itself is just code: you define it once and reuse it. Only the learned parameters need to be persisted.
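As a minimal illustration of why this works, Nx.serialize/2 accepts any Nx container, including a nested map of tensors like the one that holds a model's parameters, and Nx.deserialize/2 restores the same nesting. The layer and parameter names below are made up for the example; a real Axon model generates its own names.

```elixir
# Hypothetical parameter map; real Axon layer names are generated by the model.
params = %{
  "dense_0" => %{
    "kernel" => Nx.tensor([[0.5, -1.0]]),
    "bias" => Nx.tensor([0.1, 0.2])
  }
}

# Serialize the whole nested map in one call, then restore it.
serialized = Nx.serialize(params)
restored = Nx.deserialize(serialized)

# The nesting and the tensor values survive the round trip.
Nx.to_flat_list(restored["dense_0"]["kernel"])
```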

Saving a Model After Training

First, train a model. A loop built with Axon.Loop.trainer/3 returns the trained model state when you run it with Axon.Loop.run/4:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)

The returned value is an Axon.ModelState struct. For inference you need its data field, which holds the nested map of parameter tensors. Extract it and serialize it:

# Extract parameters - trained_model_state.data contains the nested map of weights
params = trained_model_state.data

# Serialize and save
params_bytes = Nx.serialize(params)
File.write!("model_params.axon", params_bytes)
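If you save and load parameters in more than one place, it can help to wrap the two steps in a small module. This is a sketch that relies only on Nx and the Elixir standard library; the module name ParamsIO and the temporary file path are hypothetical.

```elixir
defmodule ParamsIO do
  # Hypothetical helper: persist a parameter map (any Nx container) to disk.
  def save!(params, path) do
    # Nx.serialize/2 returns iodata, which File.write!/2 accepts directly.
    File.write!(path, Nx.serialize(params))
  end

  # Read the file back and rebuild the original nested structure.
  def load!(path) do
    path
    |> File.read!()
    |> Nx.deserialize()
  end
end

# Usage with a small stand-in parameter map written to a temporary file.
path = Path.join(System.tmp_dir!(), "params_io_demo.axon")
ParamsIO.save!(%{"dense_0" => %{"bias" => Nx.tensor([1.0, 2.0])}}, path)
loaded = ParamsIO.load!(path)
```

With trained parameters, the calls become ParamsIO.save!(params, "model_params.axon") and ParamsIO.load!("model_params.axon").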

Loading a Model for Inference

To load and run inference, you need:

  1. The model definition (in code, with the same structure you trained)
  2. The saved parameters

# 1. Define the same model structure (must match training)
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# 2. Load parameters
params = File.read!("model_params.axon") |> Nx.deserialize()

# 3. Run inference
input = Nx.tensor([[1.0]])  # shape {1, 1}: 1 sample with 1 feature (matches model input)
Axon.predict(model, params, %{"data" => input})
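The model definition has to match because the saved map is keyed by layer name, and by default Axon names layers from their position in the model (for example dense_0, dense_1). One way to catch a mismatch before running inference is to initialize a throwaway set of parameters from the model and compare layer names with the loaded map. This is a sketch: ParamsCheck.check_params/3 is a hypothetical helper, and it assumes that Axon.build/2 returns an {init_fn, predict_fn} pair whose init function takes an input template plus an Axon.ModelState and returns an Axon.ModelState. It compares layer names only, not tensor shapes.

```elixir
defmodule ParamsCheck do
  # Hypothetical helper: compare the layer names in `params` with the layers
  # that `model` would create, using a freshly initialized parameter set.
  def check_params(model, params, input_template) do
    {init_fn, _predict_fn} = Axon.build(model)
    expected_state = init_fn.(input_template, Axon.ModelState.empty())

    expected = expected_state.data |> Map.keys() |> Enum.sort()
    loaded = params |> Map.keys() |> Enum.sort()

    if expected == loaded do
      :ok
    else
      {:error, "model expects layers #{inspect(expected)}, params have #{inspect(loaded)}"}
    end
  end
end

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# A model with one fewer dense layer, so its layer names differ.
smaller_model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# Input template: 1 sample with 1 feature, as in the inference example.
template = Nx.template({1, 1}, {:f, 32})

# Stand-in for parameters read from disk: a fresh initialization of `model`.
{init_fn, _predict_fn} = Axon.build(model)
saved_params = init_fn.(template, Axon.ModelState.empty()).data

ParamsCheck.check_params(model, saved_params, template)
ParamsCheck.check_params(smaller_model, saved_params, template)
```

In practice you would pass the map returned by Nx.deserialize/2 in place of saved_params.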

Checkpointing During Training

To save checkpoints during training (for example, after every epoch or whenever a validation metric improves), use Axon.Loop.checkpoint/2. It serializes the full loop state, including model parameters and optimizer state, so you can resume training later.

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)

Checkpoints are written to the checkpoints/ directory configured above. With the default file naming, each file is called checkpoint_<epoch>_<iteration>.ckpt and contains the loop state serialized with Axon.Loop.serialize_state/2.
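The loop above writes a file after every epoch. To save only when a validation metric improves, Axon.Loop.checkpoint/2 also accepts :criteria and :mode options. The sketch below assumes that Axon.Loop.validate/4 evaluates the trainer's loss on held-out data and records it under the name "validation_loss"; the fixed validation batch and the checkpoints_best directory are hypothetical.

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# Hypothetical held-out data: a single {input, target} batch.
{val_x, _key} = Nx.Random.normal(Nx.Random.key(42), shape: {32, 1})
validation_data = [{val_x, Nx.sin(val_x)}]

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # Run an evaluation pass at the end of each epoch (attached before checkpoint).
  |> Axon.Loop.validate(model, validation_data)
  # Write a checkpoint only when "validation_loss" reaches a new minimum.
  |> Axon.Loop.checkpoint(
    path: "checkpoints_best",
    event: :epoch_completed,
    criteria: "validation_loss",
    mode: :min
  )

trained = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```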

Resuming from a Checkpoint

To resume training from a saved checkpoint:

  1. Load the checkpoint with Axon.Loop.deserialize_state/2
  2. Attach it to your loop with Axon.Loop.from_state/2
  3. Run the loop as usual

# Load the checkpoint (use the path from your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Resume training
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.from_state(state)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
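Instead of hard-coding a file name such as checkpoint_2_50.ckpt, you can pick the most recent checkpoint from the directory. This sketch assumes the default checkpoint_<epoch>_<iteration>.ckpt naming, and CheckpointFinder.latest_checkpoint/1 is a hypothetical helper. Sorting the names as plain strings would put checkpoint_10_50 before checkpoint_2_50, so the numbers are parsed first.

```elixir
defmodule CheckpointFinder do
  # Hypothetical helper: return the path of the checkpoint with the highest
  # {epoch, iteration} in `dir`, or nil when there is none. Assumes the
  # checkpoint_<epoch>_<iteration>.ckpt naming pattern.
  def latest_checkpoint(dir) do
    dir
    |> File.ls!()
    |> Enum.flat_map(fn name ->
      case Regex.run(~r/^checkpoint_(\d+)_(\d+)\.ckpt$/, name) do
        [_, epoch, iteration] ->
          [{{String.to_integer(epoch), String.to_integer(iteration)}, name}]

        nil ->
          []
      end
    end)
    # Compare {epoch, iteration} tuples numerically; fall back to nil if empty.
    |> Enum.max_by(fn {order, _name} -> order end, &>=/2, fn -> nil end)
    |> case do
      nil -> nil
      {_order, name} -> Path.join(dir, name)
    end
  end
end

# Demo with a throwaway directory of empty files.
demo_dir = Path.join(System.tmp_dir!(), "checkpoint_finder_demo")
File.rm_rf!(demo_dir)
File.mkdir_p!(demo_dir)

for name <- ["checkpoint_2_50.ckpt", "checkpoint_10_50.ckpt", "notes.txt"] do
  File.write!(Path.join(demo_dir, name), "")
end

CheckpointFinder.latest_checkpoint(demo_dir)
```

With the checkpoints/ directory from the training run above, CheckpointFinder.latest_checkpoint("checkpoints") returns a path you can use in place of the hard-coded checkpoint_path.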

Saving Only Parameters from a Checkpoint

If you have a checkpoint file and want to extract parameters for inference (without optimizer state):

checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()

# Extract model parameters from step_state
%{model_state: model_state} = state.step_state
params = model_state.data

# Save for inference
File.write!("model_params.axon", Nx.serialize(params))
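If you only ever need parameters for inference, you can skip the full checkpoint and write the parameters directly from an event handler, using the same step_state.model_state.data extraction. A sketch, assuming handlers attached with Axon.Loop.handle_event/4 receive the Axon.Loop.State and return {:continue, state}; the params_by_epoch directory and file names are hypothetical.

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# Hypothetical output directory for this example.
params_dir = "params_by_epoch"
File.mkdir_p!(params_dir)

save_params = fn %Axon.Loop.State{epoch: epoch, step_state: step_state} = state ->
  # Same extraction as above: the trainer keeps an Axon.ModelState under :model_state.
  %{model_state: model_state} = step_state
  path = Path.join(params_dir, "params_epoch_#{epoch}.axon")
  File.write!(path, Nx.serialize(model_state.data))
  # Handlers tell the loop how to proceed; :continue keeps training.
  {:continue, state}
end

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.handle_event(:epoch_completed, save_params)
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 2, iterations: 20)
```

Unlike a checkpoint, these files hold no optimizer state, so they are suitable for inference but not for resuming training.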

Summary

| Use case | Save | Load |
|---|---|---|
| Inference only | Nx.serialize(params) → file | Nx.deserialize(file) + model in code |
| Checkpoint (resume training) | Axon.Loop.checkpoint/2 or Axon.Loop.serialize_state/2 | Axon.Loop.deserialize_state/2 + Axon.Loop.from_state/2 |
| Extract params from checkpoint | state.step_state.model_state.data → Nx.serialize | Use with model in code |