First Evaluation Loop

Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok

Creating An Axon Evaluation Loop

After training a model, you need to evaluate it on test data. Axon’s loop abstraction is general enough to handle both training and evaluation.

Axon provides two ready-made loop factories: Axon.Loop.trainer/3 for training and Axon.Loop.evaluator/1 for evaluation.

Axon.Loop.evaluator/1 creates an evaluation loop which you can instrument with metrics to measure the performance of a trained model on test data.
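To make the idea concrete, here is a rough sketch of the kind of bookkeeping an evaluation loop takes off your hands: run a forward pass per batch, compute a metric, and average it across batches. This is not Axon's internal implementation. The names sketch_model, sketch_params, and batches are hypothetical, and the parameters are freshly initialized rather than trained.

```elixir
# Same dependencies as the setup cell, repeated so the sketch runs on its own.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

# Hypothetical small model; its parameters are initialized, not trained.
sketch_model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(sketch_model)
sketch_params = init_fn.(Nx.template({8, 1}, :f32), %{})

# Four hypothetical test batches of {inputs, targets}.
batches =
  for seed <- 0..3 do
    {xs, _key} = Nx.Random.normal(Nx.Random.key(seed), shape: {8, 1}, type: :f32)
    {xs, Nx.sin(xs)}
  end

# Forward pass on each batch, then accumulate the per-batch metric.
{total, count} =
  Enum.reduce(batches, {0.0, 0}, fn {xs, ys}, {total, count} ->
    y_pred = predict_fn.(sketch_params, xs)
    batch_mae = ys |> Axon.Metrics.mean_absolute_error(y_pred) |> Nx.to_number()
    {total + batch_mae, count + 1}
  end)

# Average of the per-batch mean absolute errors.
manual_mae = total / count
```

An evaluation loop replaces the hand-written reduce with a reusable structure that also handles logging and event handlers.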

First you need a trained model:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_2"
  nodes: 6
>
train_loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Running a loop created with Axon.Loop.trainer/3 returns the trained model state, which you can then use to evaluate your model.

data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(12)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)
#Function<51.124013645/2 in Stream.repeatedly/1>
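Note that the stream above calls Nx.Random.key(12) inside the function, so every batch contains the same eight samples. If you want batches that differ, one possible approach is to thread the key through Stream.unfold/2. The following is a minimal sketch under that assumption; varied_data is a hypothetical name, and the rest of this guide keeps using data as defined above.

```elixir
# Same dependencies as the setup cell, repeated so the sketch runs on its own.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

# Each step consumes the current key and hands the new key to the next step,
# so successive batches are drawn from different random states.
varied_data =
  Stream.unfold(Nx.Random.key(12), fn key ->
    {xs, next_key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    {{xs, Nx.sin(xs)}, next_key}
  end)

# Take two batches to confirm that they are not identical.
[{first_xs, _}, {second_xs, _}] = Enum.take(varied_data, 2)
batches_identical? = Nx.to_number(Nx.all(Nx.equal(first_xs, second_xs))) == 1
```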
trained_model_state = Axon.Loop.run(train_loop, data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0077294
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.11962035298347473, 0.011666420847177505, 0.14720754325389862, 0.11118166148662567, -0.1345834732055664, 0.021384937688708305, -0.006526924669742584, 0.009598659351468086]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.44526207447052, -0.26303717494010925, -0.2527385950088501, -0.4609968960285187, 0.32450586557388306, -0.7373477220535278, 0.36933162808418274, -0.7535463571548462]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.036067795008420944, -0.039795882999897, 0.10334350168704987, 0.1229325458407402]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.4655035436153412, -0.5213115215301514, -0.5214490294456482, -0.61592698097229],
        [-0.19082710146903992, 0.2520129978656769, 0.21478496491909027, -0.3245740830898285],
        [0.6956093311309814, 0.6102574467658997, 0.40522530674934387, 0.18446679413318634],
        [-0.057972073554992676, 0.41674870252609253, 0.4854397773742676, -0.4421413838863373],
        [0.4250357747077942, 0.1620386838912964, 0.03752923756837845, -0.390951007604599],
        [0.2880731225013733, 0.048166424036026, 0.08799079805612564, -0.0927252247929573],
        [-0.3562321960926056, 0.1249040961265564, -0.6461682915687561, 0.3380368947982788],
        [0.10793770104646683, -0.11531627178192139, 0.1639910340309143, -0.0112814512103796]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.3604468107223511]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.8113043904304504],
        [-0.9013109803199768],
        [-0.9164133667945862],
        [0.36915871500968933]
      ]
    >
  }
}
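Before building an evaluation loop, it can help to spot-check the trained state directly. The sketch below assumes Axon.predict/3 accepts the state returned by the training loop; the inputs tensor is a hypothetical set of sample points. The model, data, and training cells are repeated in shortened form (100 iterations) so the block runs on its own.

```elixir
# Same dependencies as the setup cell, repeated so the sketch runs on its own.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(12)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    {xs, Nx.sin(xs)}
  end)

# Shortened training run, only to obtain a model state to predict with.
trained_model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.run(data, %{}, iterations: 100)

# Hypothetical sample points; each row is one input of shape {1}.
inputs = Nx.tensor([[0.0], [0.5], [1.0]])
predictions = Axon.predict(model, trained_model_state, inputs)
```

The predictions tensor has one row per input row, which you can compare by eye against Nx.sin(inputs).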

To construct an evaluation loop, pass your model to Axon.Loop.evaluator/1. The trained state is supplied later, when you run the loop:

eval_loop = Axon.Loop.evaluator(model)
#Axon.Loop<
  metrics: %{},
  handlers: %{
    completed: [],
    epoch_completed: [],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Instrument your evaluation loop with the metrics you’d like to aggregate:

eval_loop =
  eval_loop
  |> Axon.Loop.metric(:mean_absolute_error)
#Axon.Loop<
  metrics: %{
    "mean_absolute_error" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     :mean_absolute_error}
  },
  handlers: %{
    completed: [],
    epoch_completed: [],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
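You can attach more than one metric, and a metric can be given a custom name or be a function of the targets and predictions. The sketch below assumes the Axon.Loop.metric argument order of loop, metric, name. EvalSketch.max_absolute_error/2 is a hypothetical metric written for illustration, and the model state is freshly initialized rather than trained, so the values themselves are not meaningful.

```elixir
# Same dependencies as the setup cell, repeated so the sketch runs on its own.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

defmodule EvalSketch do
  import Nx.Defn

  # Hypothetical custom metric: largest absolute error within one batch.
  defn max_absolute_error(y_true, y_pred) do
    Nx.reduce_max(Nx.abs(y_true - y_pred))
  end
end

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# Untrained parameters, used only to exercise the loop.
{init_fn, _predict_fn} = Axon.build(model)
untrained_state = init_fn.(Nx.template({8, 1}, :f32), %{})

data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(12)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    {xs, Nx.sin(xs)}
  end)

results =
  model
  |> Axon.Loop.evaluator()
  # Built-in metric reported under a custom name.
  |> Axon.Loop.metric(:mean_absolute_error, "mae")
  # Function metric; a name is required because it cannot be derived from an atom.
  |> Axon.Loop.metric(&EvalSketch.max_absolute_error/2, "max_abs_error")
  |> Axon.Loop.run(data, untrained_state, iterations: 10)
```

With the default accumulation, "max_abs_error" is the running average of the per-batch maxima, not the maximum over all batches.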

Run your evaluation loop on the test data. Evaluation starts from the trained model state, so pass trained_model_state as the loop's initial state. Note that this example reuses the training stream data as its test data:

eval_loop |> Axon.Loop.run(data, trained_model_state, iterations: 1000)
Batch: 999, mean_absolute_error: 0.0277175
%{
  0 => %{
    "mean_absolute_error" => #Nx.Tensor<
      f32
      0.027717459946870804
    >
  }
}
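The loop returns a map keyed by epoch index, with each metric stored as a scalar tensor. A minimal sketch of pulling out a plain number follows. The results map is a literal copy of the output shown above, so the block runs without retraining.

```elixir
# Same dependencies as the setup cell, repeated so the sketch runs on its own.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

# Literal stand-in for the value returned by Axon.Loop.run/4 above.
results = %{
  0 => %{"mean_absolute_error" => Nx.tensor(0.027717459946870804, type: :f32)}
}

# Epoch 0, then the metric name, then convert the scalar tensor to a float.
mae =
  results
  |> get_in([0, "mean_absolute_error"])
  |> Nx.to_number()

IO.puts("MAE: #{Float.round(mae, 4)}")
```

Converting to a float makes the value easy to log, compare against a threshold, or assert on in a test.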