
Custom Metrics

Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok

Writing Custom Metrics

Passing an atom to Axon.Loop.metric/5 dispatches it to a built-in metric function in Axon.Metrics.
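For example, a minimal sketch (assuming a model is already defined; :mean_absolute_error is one of the built-in metrics in Axon.Metrics):

# Attach a built-in metric by name; the atom is dispatched to Axon.Metrics
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)

When the built-ins don't cover your use case, a custom metric can be defined as such: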

defmodule CustomMetric do
  import Nx.Defn

  defn my_custom_metric(y_true, y_pred) do
    Nx.atan2(y_true, y_pred) |> Nx.sum()
  end
end
{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}

This can be passed to Axon.Loop.metric/5, provided you name your custom metric:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_2"
  nodes: 6
>
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&CustomMetric.my_custom_metric/2, "my_custom_metric")
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>},
    "my_custom_metric" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     &CustomMetric.my_custom_metric/2}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Axon invokes this custom metric while the loop runs and accumulates the value with the given aggregator:

training_data =
  Stream.repeatedly(fn ->
    {xs, _key} = Nx.Random.normal(Nx.Random.key(42), shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.3437554 my_custom_metric: -3.9039953
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.11073245108127594, 0.0, 0.0, -0.0834122970700264, 0.0, 0.09130118042230606, 0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.1805345118045807, 0.06610244512557983, 0.06647324562072754, -0.43841561675071716, 0.16585946083068848, -0.35881850123405457, 0.7354707717895508, 0.8001827001571655]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, 0.0, -0.09117528796195984, -0.09919869154691696]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.1822679042816162, 0.019030332565307617, 0.5965149998664856, 0.5283539295196533],
        [0.5678933262825012, 0.6286323666572571, 0.5543228983879089, 0.6253339648246765],
        [-0.39208537340164185, -0.5719695091247559, 0.132618248462677, 0.6295531392097473],
        [-0.5305015444755554, -0.6904010772705078, 0.4426437020301819, 0.3491297960281372],
        [-0.07654678821563721, -0.05949413776397705, 0.33350545167922974, -0.15483957529067993],
        [-0.4237406849861145, -0.282043993473053, -0.6067726016044617, -0.42265596985816956],
        [-0.10928106307983398, -0.14439982175827026, -0.508826732635498, -0.20586729049682617],
        [-0.3498375415802002, -0.4012027680873871, 0.26174354553222656, -0.4216057062149048]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.10541863739490509]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.4311734437942505],
        [0.13204562664031982],
        [1.0326485633850098],
        [0.9910651445388794]
      ]
    >
  }
}

Metric defaults are designed with supervised training loops in mind, but they can be used for more flexible purposes.

By default, metrics look for the fields :y_true and :y_pred in the given loop's step state and apply the given metric function to those inputs.
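In other words, the default behaves roughly like an output transform that pulls those two fields out of the step state (an illustrative sketch, not Axon's actual internals):

# Roughly what the default does: extract :y_true and :y_pred from the
# step state and hand them to the metric function
default_transform = fn %{y_true: y_true, y_pred: y_pred} ->
  [y_true, y_pred]
end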

You can also define metrics that work on other fields. For example, you can track the running average of a given parameter by defining a custom output transform:

output_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&Nx.mean/1, "dense_0_kernel_mean", :running_average, output_transform)
  |> Axon.Loop.metric(&Nx.variance/1, "dense_0_kernel_var", :running_average, output_transform)
#Axon.Loop<
  metrics: %{
    "dense_0_kernel_mean" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     &Nx.mean/1},
    "dense_0_kernel_var" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     &Nx.variance/1},
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
loop |> Axon.Loop.run(training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_mean: 0.2376680 dense_0_kernel_var: 0.2236163 loss: 0.0259660
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.03513307124376297, -0.08354594558477402, -0.030258579179644585, -0.018628861755132675, 0.07863052934408188, 0.008323105983436108, 0.04463294520974159, 0.014289754442870617]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.24527588486671448, -0.5853115916252136, -0.21179081499576569, 0.6120597124099731, 0.771464467048645, 0.2930528223514557, 0.7350512146949768, 0.4327988922595978]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.31833764910697937, -0.06402218341827393, 0.15779836475849152, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.35651424527168274, -0.3132264316082001, 0.21152308583259583, -0.3248668909072876],
        [-0.7181836366653442, 0.15176981687545776, -0.27426889538764954, -0.36469903588294983],
        [0.04260712116956711, -0.5604554414749146, -0.5009223818778992, -0.3872546851634979],
        [0.10704389959573746, -0.12545841932296753, -0.6765053868293762, -0.5981687903404236],
        [0.23998293280601501, -0.2473406195640564, 0.6491498947143555, 0.1822124719619751],
        [0.3256910741329193, 0.6383374333381653, 0.35975417494773865, 0.6543763279914856],
        [0.2979046106338501, 0.3126327395439148, 0.6056838035583496, -0.24469786882400513],
        [-0.05811864510178566, 0.1268378496170044, 0.4055303931236267, -0.5605016350746155]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.6836467981338501]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.0598795413970947],
        [-0.9941802024841309],
        [0.7693229913711548],
        [-0.5177003145217896]
      ]
    >
  }
}

Custom accumulation functions can also be used inside an Axon.Loop.

An accumulator must be an arity-3 function that accepts the currently accumulated value, the current observation, and the current iteration, and returns the updated accumulated value. The function can then be passed directly as the accumulator argument of Axon.Loop.metric/5.

defmodule CustomAccumulator do
  import Nx.Defn

  defn running_ema(acc, obs, _i, opts \\ []) do
    opts = keyword!(opts, alpha: 0.9)
    obs * opts[:alpha] + acc * (1 - opts[:alpha])
  end
end
{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
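Before wiring it into a loop, you can sanity-check the accumulator on scalar tensors (illustrative values; with the default alpha of 0.9 this returns roughly 0.9):

# new EMA = 1.0 * 0.9 + 0.0 * (1 - 0.9)
CustomAccumulator.running_ema(Nx.tensor(0.0), Nx.tensor(1.0), 0)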
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(
    &Nx.mean/1,
    "dense_0_kernel_ema_mean",
    &CustomAccumulator.running_ema/3,
    output_transform
  )
#Axon.Loop<
  metrics: %{
    "dense_0_kernel_ema_mean" => {#Function<15.98177675/3 in Axon.Loop.build_metric_fn/3>,
     &Nx.mean/1},
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Now we can run the training loop with our custom accumulator:

loop |> Axon.Loop.run(training_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_ema_mean: -0.1315047 loss: 0.0247939
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.038095198571681976, -0.035317715257406235, 0.10211353749036789, 0.04737640172243118, -0.03097797930240631, 0.30161386728286743, -0.027741162106394768, -0.05312076210975647]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.8210890889167786, -0.30364781618118286, 0.09091158956289291, -0.3785530924797058, -0.47972002625465393, 0.7027488946914673, 0.20813235640525818, -0.07428683340549469]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.013171808794140816, 0.15184389054775238, 0.10273980349302292, 0.046200577169656754]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.5861905217170715, -0.1695214956998825, -0.2709495723247528, -0.5087029337882996],
        [0.5260724425315857, -0.03290720283985138, -0.6362824440002441, 0.04115370288491249],
        [-0.10268115997314453, 0.14633500576019287, 0.2654731571674347, 0.1771700382232666],
        [-0.3430379033088684, 0.4363419711589813, 0.1411917358636856, 0.12639537453651428],
        [-0.18108224868774414, -0.5600121021270752, 0.4511440396308899, 0.5870872139930725],
        [0.2919427752494812, 0.21837227046489716, 0.6466642618179321, 0.5795313715934753],
        [-0.2304805964231491, 0.2866005599498749, -0.6626536846160889, -0.15277843177318573],
        [-0.6399926543235779, -0.015960492193698883, 0.07125378400087357, 0.36992549896240234]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.7076235413551331]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.2274997979402542],
        [1.0127692222595215],
        [0.9669621586799622],
        [0.6742429137229919]
      ]
    >
  }
}