Custom Metrics
Mix.install([
{:axon, "~> 0.5.1"},
{:nx, "~> 0.5.2"},
{:kino, "~> 0.9.0"}
])
:ok
Writing Custom Metrics
Passing an atom to `Axon.Loop.metric/5` dispatches the function to a built-in function in `Axon.Metrics`. A custom metric can be defined as such:
defmodule CustomMetric do
  @moduledoc """
  Example of a user-defined metric that can be handed to `Axon.Loop.metric/5`.
  """

  import Nx.Defn

  @doc """
  Applies `Nx.atan2/2` to `y_true` and `y_pred`, then sums the result with
  `Nx.sum/1`.
  """
  defn my_custom_metric(y_true, y_pred) do
    angles = Nx.atan2(y_true, y_pred)
    Nx.sum(angles)
  end
end
{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
This can be passed to `Axon.Loop.metric/5`, provided you name your custom metric:
# Three dense layers (8, 4 and 1 units) with a ReLU after each of the first
# two, fed by a single input named "data".
model =
  "data"
  |> Axon.input()
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
#Axon<
inputs: %{"data" => nil}
outputs: "dense_2"
nodes: 6
>
# Register the custom metric under an explicit name. The accumulator is
# spelled out here; `:running_average` is what the loop uses when the
# argument is omitted.
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(
    &CustomMetric.my_custom_metric/2,
    "my_custom_metric",
    :running_average
  )
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>},
"my_custom_metric" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
&CustomMetric.my_custom_metric/2}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Axon will invoke this custom metric while the training loop is running and accumulate its value with the given aggregator:
# Infinite stream of `{xs, ys}` batches where `ys = sin(xs)`.
#
# The PRNG key is threaded through the stream: each step consumes the current
# key and hands the key returned by `Nx.Random.normal/2` to the next step.
# Re-creating `Nx.Random.key(42)` inside the generator and discarding the
# returned key would produce the exact same batch on every iteration.
# Seeding the stream once keeps the data reproducible across runs.
training_data =
  Stream.unfold(Nx.Random.key(42), fn key ->
    {xs, next_key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {{xs, ys}, next_key}
  end)

# The stream never ends, so the number of iterations bounds the run.
Axon.Loop.run(loop, training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.3437554 my_custom_metric: -3.9039953
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.11073245108127594, 0.0, 0.0, -0.0834122970700264, 0.0, 0.09130118042230606, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.1805345118045807, 0.06610244512557983, 0.06647324562072754, -0.43841561675071716, 0.16585946083068848, -0.35881850123405457, 0.7354707717895508, 0.8001827001571655]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.0, 0.0, -0.09117528796195984, -0.09919869154691696]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.1822679042816162, 0.019030332565307617, 0.5965149998664856, 0.5283539295196533],
[0.5678933262825012, 0.6286323666572571, 0.5543228983879089, 0.6253339648246765],
[-0.39208537340164185, -0.5719695091247559, 0.132618248462677, 0.6295531392097473],
[-0.5305015444755554, -0.6904010772705078, 0.4426437020301819, 0.3491297960281372],
[-0.07654678821563721, -0.05949413776397705, 0.33350545167922974, -0.15483957529067993],
[-0.4237406849861145, -0.282043993473053, -0.6067726016044617, -0.42265596985816956],
[-0.10928106307983398, -0.14439982175827026, -0.508826732635498, -0.20586729049682617],
[-0.3498375415802002, -0.4012027680873871, 0.26174354553222656, -0.4216057062149048]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.10541863739490509]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.4311734437942505],
[0.13204562664031982],
[1.0326485633850098],
[0.9910651445388794]
]
>
}
}
Metric defaults are designed with supervised training loops in mind, but they can be used for more flexible purposes.
By default, metrics look for the fields `:y_true` and `:y_pred` in the given loop's step state and then apply the given metric function to those inputs.
You can also define metrics which work on other fields. For example, you can track the running average of a given parameter with a metric by defining a custom output transform:
# Pull the first dense layer's kernel out of the step state's `:model_state`
# and wrap it in a list.
# NOTE(review): presumably the list elements become the arguments of the
# metric function (both metrics below are arity-1) — confirm against the
# `Axon.Loop.metric/5` docs.
output_transform = fn %{model_state: params} ->
  [get_in(params, ["dense_0", "kernel"])]
end

# Track the mean and variance of that kernel, each aggregated as a running
# average over the iterations.
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&Nx.mean/1, "dense_0_kernel_mean", :running_average, output_transform)
  |> Axon.Loop.metric(&Nx.variance/1, "dense_0_kernel_var", :running_average, output_transform)
#Axon.Loop<
metrics: %{
"dense_0_kernel_mean" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
&Nx.mean/1},
"dense_0_kernel_var" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
&Nx.variance/1},
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
# Train for 1000 iterations; the kernel metrics are logged next to the loss.
Axon.Loop.run(loop, training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_mean: 0.2376680 dense_0_kernel_var: 0.2236163 loss: 0.0259660
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.03513307124376297, -0.08354594558477402, -0.030258579179644585, -0.018628861755132675, 0.07863052934408188, 0.008323105983436108, 0.04463294520974159, 0.014289754442870617]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.24527588486671448, -0.5853115916252136, -0.21179081499576569, 0.6120597124099731, 0.771464467048645, 0.2930528223514557, 0.7350512146949768, 0.4327988922595978]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.31833764910697937, -0.06402218341827393, 0.15779836475849152, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.35651424527168274, -0.3132264316082001, 0.21152308583259583, -0.3248668909072876],
[-0.7181836366653442, 0.15176981687545776, -0.27426889538764954, -0.36469903588294983],
[0.04260712116956711, -0.5604554414749146, -0.5009223818778992, -0.3872546851634979],
[0.10704389959573746, -0.12545841932296753, -0.6765053868293762, -0.5981687903404236],
[0.23998293280601501, -0.2473406195640564, 0.6491498947143555, 0.1822124719619751],
[0.3256910741329193, 0.6383374333381653, 0.35975417494773865, 0.6543763279914856],
[0.2979046106338501, 0.3126327395439148, 0.6056838035583496, -0.24469786882400513],
[-0.05811864510178566, 0.1268378496170044, 0.4055303931236267, -0.5605016350746155]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.6836467981338501]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[1.0598795413970947],
[-0.9941802024841309],
[0.7693229913711548],
[-0.5177003145217896]
]
>
}
}
Custom accumulation functions can also be used inside an `Axon.Loop`.
An accumulator must be an arity-3 function which accepts the current accumulated value, the current observation, and the current iteration, and returns the aggregated metric. The function can be passed directly to the metric function:
defmodule CustomAccumulator do
  @moduledoc """
  Example of a user-defined accumulator that can be handed to
  `Axon.Loop.metric/5`.
  """

  import Nx.Defn

  @doc """
  Blends the newest observation `obs` into the accumulated value `acc`.

  The observation is weighted by `:alpha` and the previous accumulated value
  by `1 - :alpha`. The iteration argument is accepted but not used.

  ## Options

    * `:alpha` - weight given to the new observation. Defaults to `0.9`.
  """
  defn running_ema(acc, obs, _i, opts \\ []) do
    opts = keyword!(opts, alpha: 0.9)
    alpha = opts[:alpha]

    obs * alpha + acc * (1 - alpha)
  end
end
{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
# Track the mean of the value selected by `output_transform`, aggregated with
# the custom accumulator function instead of a built-in accumulator atom.
loop =
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(
&Nx.mean/1,
"dense_0_kernel_ema_mean",
# Captured at arity 3, so `opts` takes its default (`[]`) and `:alpha`
# falls back to 0.9.
&CustomAccumulator.running_ema/3,
output_transform
)
#Axon.Loop<
metrics: %{
"dense_0_kernel_ema_mean" => {#Function<15.98177675/3 in Axon.Loop.build_metric_fn/3>,
&Nx.mean/1},
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Now we can run the training loop with our custom accumulator
# Train for 1000 iterations; the EMA metric is logged next to the loss.
Axon.Loop.run(loop, training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_ema_mean: -0.1315047 loss: 0.0247939
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.038095198571681976, -0.035317715257406235, 0.10211353749036789, 0.04737640172243118, -0.03097797930240631, 0.30161386728286743, -0.027741162106394768, -0.05312076210975647]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.8210890889167786, -0.30364781618118286, 0.09091158956289291, -0.3785530924797058, -0.47972002625465393, 0.7027488946914673, 0.20813235640525818, -0.07428683340549469]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.013171808794140816, 0.15184389054775238, 0.10273980349302292, 0.046200577169656754]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.5861905217170715, -0.1695214956998825, -0.2709495723247528, -0.5087029337882996],
[0.5260724425315857, -0.03290720283985138, -0.6362824440002441, 0.04115370288491249],
[-0.10268115997314453, 0.14633500576019287, 0.2654731571674347, 0.1771700382232666],
[-0.3430379033088684, 0.4363419711589813, 0.1411917358636856, 0.12639537453651428],
[-0.18108224868774414, -0.5600121021270752, 0.4511440396308899, 0.5870872139930725],
[0.2919427752494812, 0.21837227046489716, 0.6466642618179321, 0.5795313715934753],
[-0.2304805964231491, 0.2866005599498749, -0.6626536846160889, -0.15277843177318573],
[-0.6399926543235779, -0.015960492193698883, 0.07125378400087357, 0.36992549896240234]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.7076235413551331]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.2274997979402542],
[1.0127692222595215],
[0.9669621586799622],
[0.6742429137229919]
]
>
}
}