First Evaluation Loop
Mix.install([
{:axon, "~> 0.5.1"},
{:nx, "~> 0.5.2"},
{:kino, "~> 0.9.0"}
])
:ok
Creating An Axon Evaluation Loop
After training a model, it’s necessary to test the trained model on some test data. Axon’s loop abstraction is general enough to work for both training and evaluating models.
Axon implements a canned Axon.Loop.trainer/3
factory and a canned Axon.Loop.evaluator/1
factory.
Axon.Loop.evaluator/1
creates an evaluation loop which you can instrument with metrics to measure the performance of a trained model on test data.
First you need a trained model:
# Feed-forward regression network: input "data" -> dense(8) -> relu
# -> dense(4) -> relu -> dense(1).
model =
  "data"
  |> Axon.input()
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
#Axon<
inputs: %{"data" => nil}
outputs: "dense_2"
nodes: 6
>
# Supervised training loop: mean squared error loss, optimized with SGD.
train_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Running a loop created with Axon.Loop.trainer/3
returns the trained model state, which you can then use to evaluate your model.
# Infinite stream of {inputs, targets} batches for fitting y = sin(x).
# Each batch is an {8, 1} f32 tensor of normally distributed inputs and
# the element-wise sine of those inputs as targets.
#
# A new PRNG key is seeded for every batch. Building the key from a fixed
# seed inside the callback would return the exact same 8 samples on every
# iteration, so the loops would only ever train and evaluate on one batch.
data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(:rand.uniform(9999))
    {xs, _next_key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)
#Function<51.124013645/2 in Stream.repeatedly/1>
# Run the training loop for 1000 iterations, starting from an empty
# initial state; the result is the trained model state.
trained_model_state =
  train_loop
  |> Axon.Loop.run(data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0077294
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.11962035298347473, 0.011666420847177505, 0.14720754325389862, 0.11118166148662567, -0.1345834732055664, 0.021384937688708305, -0.006526924669742584, 0.009598659351468086]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.44526207447052, -0.26303717494010925, -0.2527385950088501, -0.4609968960285187, 0.32450586557388306, -0.7373477220535278, 0.36933162808418274, -0.7535463571548462]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.036067795008420944, -0.039795882999897, 0.10334350168704987, 0.1229325458407402]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.4655035436153412, -0.5213115215301514, -0.5214490294456482, -0.61592698097229],
[-0.19082710146903992, 0.2520129978656769, 0.21478496491909027, -0.3245740830898285],
[0.6956093311309814, 0.6102574467658997, 0.40522530674934387, 0.18446679413318634],
[-0.057972073554992676, 0.41674870252609253, 0.4854397773742676, -0.4421413838863373],
[0.4250357747077942, 0.1620386838912964, 0.03752923756837845, -0.390951007604599],
[0.2880731225013733, 0.048166424036026, 0.08799079805612564, -0.0927252247929573],
[-0.3562321960926056, 0.1249040961265564, -0.6461682915687561, 0.3380368947982788],
[0.10793770104646683, -0.11531627178192139, 0.1639910340309143, -0.0112814512103796]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.3604468107223511]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.8113043904304504],
[-0.9013109803199768],
[-0.9164133667945862],
[0.36915871500968933]
]
>
}
}
To construct an evaluation loop, use Axon.Loop.evaluator/1
with your pre-trained model:
# Build an evaluation loop around the same model definition used for training.
eval_loop =
  model
  |> Axon.Loop.evaluator()
#Axon.Loop<
metrics: %{},
handlers: %{
completed: [],
epoch_completed: [],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Instrument your evaluation loop with the metrics you’d like to aggregate:
# Attach mean absolute error as a metric to aggregate during evaluation.
eval_loop = Axon.Loop.metric(eval_loop, :mean_absolute_error)
#Axon.Loop<
metrics: %{
"mean_absolute_error" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
:mean_absolute_error}
},
handlers: %{
completed: [],
epoch_completed: [],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Run your evaluation loop on the test data. Because you want to test your trained model, pass the trained model state as the initial state of the evaluation loop:
# Evaluate for 1000 iterations, using the trained parameters as the
# loop's initial state.
Axon.Loop.run(eval_loop, data, trained_model_state, iterations: 1000)
Batch: 999, mean_absolute_error: 0.0277175
%{
0 => %{
"mean_absolute_error" => #Nx.Tensor<
f32
0.027717459946870804
>
}
}