Instrumenting loops with metrics
Mix.install([
{:axon, ">= 0.5.0"}
])
:ok
Adding metrics to training loops
Oftentimes, when executing a loop, you want to keep track of various metrics such as accuracy or precision. For training loops, Axon by default only tracks loss; however, you can instrument the loop with additional built-in metrics. For example, you might want to track mean absolute error on top of a mean squared error loss:
# Feed-forward network built in named stages: two ReLU-activated hidden
# layers (8 and 4 units) followed by a single-unit dense output layer.
inputs = Axon.input("data")
hidden_1 = inputs |> Axon.dense(8) |> Axon.relu()
hidden_2 = hidden_1 |> Axon.dense(4) |> Axon.relu()
model = Axon.dense(hidden_2, 1)
# Training loop with a mean-squared-error loss and the SGD optimizer,
# instrumented with mean absolute error as an additional tracked metric.
trainer = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
loop = Axon.Loop.metric(trainer, :mean_absolute_error)
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
#Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>},
"mean_absolute_error" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
:mean_absolute_error}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.37390314/1 in Axon.Loop.log/3>,
#Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.37390314/1 in Axon.Loop.log/3>,
#Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
When specifying a metric, you can specify an atom which maps to any of the metrics defined in `Axon.Metrics`. You can also define custom metrics. For more information on custom metrics, see Writing custom metrics.
When you run a loop with metrics, Axon will aggregate that metric over the course of the loop execution. For training loops, Axon will also report the aggregate metric in the training logs:
# Infinite stream of `{inputs, targets}` batches for learning y = sin(x).
# Each element is a pair of {8, 1} tensors: normally distributed inputs and
# their element-wise sine as targets.
train_data =
  Stream.repeatedly(fn ->
    # Seed a fresh Nx PRNG key for every batch. `:rand` is used because
    # Erlang's `:random` module is deprecated; `:rand` also seeds itself
    # automatically on first use in a process.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    ys = Nx.sin(xs)
    {xs, ys}
  end)
# Run the instrumented loop for 1000 iterations, starting from an empty
# initial state map.
loop
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0590630 mean_absolute_error: 0.1463431
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.015203186310827732, 0.1997198462486267, 0.09740892797708511, -0.007404750678688288, 0.11397464573383331, 0.3608400523662567, 0.07219560444355011, -0.06638865917921066]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[0.07889414578676224, 0.30445051193237305, 0.1377921849489212, 0.015571207739412785, 0.7115736603736877, -0.6404237151145935, 0.25553327798843384, 0.057831913232803345]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.10809992998838425, 0.0, 0.47775307297706604, -0.1641010195016861]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.040330830961465836, -0.36995524168014526, 0.001599793671630323, 0.6012424826622009],
[0.21044284105300903, -0.39482879638671875, -0.5866784453392029, 0.15573620796203613],
[-0.09234675765037537, 0.27758270502090454, -0.6663768291473389, 0.6017312407493591],
[-0.4454570412635803, 0.1304328441619873, -0.31381309032440186, 0.1906844824552536],
[0.3460652530193329, -0.3017694056034088, -0.1680794507265091, -0.47811293601989746],
[0.28633055090904236, -0.34003201127052307, 0.6202688813209534, 0.18027405440807343],
[0.5729941129684448, 0.32222074270248413, 0.20647864043712616, 0.02462891861796379],
[-0.13146185874938965, -0.06700503826141357, 0.6600251793861389, -0.06442582607269287]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.4863035976886749]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.41491562128067017],
[-0.948100209236145],
[-1.2559744119644165],
[1.0097774267196655]
]
>
}
}
By default, the metric will have a name which matches the string form of the given metric. You can give metrics semantic meaning by providing an explicit name:
# Same trainer as before, but the metric is registered under the explicit
# name "model error", which is the label that appears in the training log.
named_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(:mean_absolute_error, "model error")

Axon.Loop.run(named_loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0607362 model error: 0.1516546
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.2577069401741028, 0.16761353611946106, 0.11587327718734741, 0.28539595007896423, -0.2071152776479721, -0.02039412036538124, -0.11152249574661255, 0.2389308214187622]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.1265750676393509, 0.6902633309364319, -0.10233660787343979, -0.2544037103652954, -0.26677289605140686, -0.31035077571868896, 0.3845033347606659, -0.33032187819480896]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.0, 0.16427761316299438, 0.02123815007507801, 0.22260485589504242]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.3859425485134125, 0.49959924817085266, -0.34108400344848633, 0.6222119331359863],
[-0.43326857686042786, -0.42272067070007324, 0.04245679825544357, -0.4357914626598358],
[-0.3065953850746155, 0.587925374507904, 0.2960704267024994, -0.31594154238700867],
[-0.35595524311065674, 0.6649497747421265, 0.4832736849784851, 0.3025558590888977],
[0.048333823680877686, -0.17023107409477234, 0.09139639884233475, -0.6511918902397156],
[-0.12099027633666992, -0.02014642395079136, 0.025831595063209534, -0.09945832937955856],
[0.3415437340736389, 0.41694650053977966, 0.24677544832229614, 0.06690020114183426],
[-0.1977071762084961, 0.39345067739486694, 0.26068705320358276, 0.35502269864082336]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.8329466581344604]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.23763614892959595],
[-1.031561255455017],
[0.1092313677072525],
[-0.7191486358642578]
]
>
}
}
Axon’s default aggregation behavior is to aggregate metrics with a running average; however, you can customize this behavior by specifying an explicit accumulation function. Built-in accumulation functions are `:running_average` and `:running_sum`:
# Register the metric as "total error" and accumulate it with
# `:running_sum` instead of the default running average.
summed_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(:mean_absolute_error, "total error", :running_sum)

Axon.Loop.run(summed_loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0688004 total error: 151.4876404
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.34921368956565857, 0.2217460423707962, 0.274880051612854, 0.016405446454882622, -0.11720903217792511, -0.20693546533584595, 0.14232252538204193, -0.07956698536872864]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.37851807475090027, -0.17135880887508392, -0.3878959119319916, 0.19248774647712708, 0.12453905493021011, -0.2750281095504761, 0.5614567995071411, 0.6186240315437317]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[-0.28566694259643555, 0.27262070775032043, -0.2875851094722748, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.23161421716213226, 0.8222984671592712, 0.09437259286642075, -0.4825701117515564],
[-0.38828352093696594, 0.6247998476028442, 0.5035035610198975, 0.0026152729988098145],
[0.5202338099479675, 0.7906754612922668, 0.08624745905399323, -0.5285568833351135],
[0.47950035333633423, -0.07571044564247131, 0.32921522855758667, -0.7011756896972656],
[-0.3601212203502655, 0.44817543029785156, 0.13981425762176514, -0.01014477014541626],
[-0.3157005310058594, -0.6309216618537903, 0.5622371435165405, 0.27447545528411865],
[-0.5749425292015076, -0.5073797702789307, -0.3527824282646179, 0.08027392625808716],
[-0.5331286191940308, 0.15432128310203552, -0.015716910362243652, -0.5225256681442261]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.8275660872459412]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.45810666680336],
[-1.0092405080795288],
[0.5322748422622681],
[-0.5989866852760315]
]
>
}
}