Instrumenting Loops With Metrics

Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok

Adding Metrics To Training Loops

Often when executing a loop you want to keep track of various metrics such as accuracy or precision. For training loops, Axon by default only tracks loss; however, you can instrument the loop with additional built-in metrics. For example, you might want to track mean-absolute error on top of a mean-squared error loss:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(:mean_absolute_error)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>},
    "mean_absolute_error" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     :mean_absolute_error}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
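
Because the loop is an ordinary data structure, you can inspect which metrics are attached before running it. A small sketch, not part of the original guide, reading the metrics map shown in the output above:

# Sketch: the metrics map is keyed by metric name, with the loss tracked
# alongside the metric we just attached.
Map.keys(loop.metrics)
#=> ["loss", "mean_absolute_error"]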

You can also define custom metrics. In either case, Axon will aggregate the metric over the course of the loop execution, and for training loops it will report the aggregated metric in the training logs:

train_data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(12)
    {normal, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(normal)
    {normal, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1050)
Epoch: 0, Batch: 1000, loss: 0.0888227 mean_absolute_error: 0.1962541
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.4943743944168091, -0.0640738233923912, -0.04723932966589928, -0.03521409630775452, -0.08424084633588791, 0.09580688923597336, -0.05233990028500557, -0.07276623696088791]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.9043301343917847, -0.044184163212776184, -0.05026792362332344, 0.07078481465578079, 0.6130713820457458, -0.6679699420928955, -0.03548271954059601, -0.04985514655709267]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.4905800521373749, -0.18391939997673035, 0.0, -0.14162522554397583]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.9534140825271606, 0.4382154047489166, -0.08498793840408325, -0.5544222593307495],
        [0.6379451155662537, -0.06168210506439209, -0.5657287240028381, 0.4666915237903595],
        [0.206525057554245, -0.6354535818099976, 0.08223742246627808, -0.07512380927801132],
        [-0.5516510009765625, 0.5059526562690735, 0.5206479430198669, 0.632449209690094],
        [-0.13534902036190033, -0.4325638711452484, -0.5679105520248413, -0.1079898476600647],
        [-0.512778639793396, -0.6849693655967712, -0.40579289197921753, 0.13528507947921753],
        [-0.5750519037246704, 0.24609410762786865, 0.3988637328147888, 0.5886136889457703],
        [-0.13678640127182007, -0.6555773615837097, 0.40399235486984253, 0.6027266383171082]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.937074601650238]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.029587745666504],
        [-0.11171838641166687],
        [0.2759326696395874],
        [0.6078693866729736]
      ]
    >
  }
}
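
As noted above, metrics are not limited to the built-in atoms. Below is a minimal sketch, not from the original guide, assuming Axon.Loop.metric/3 also accepts an arity-2 function of the targets and predictions:

# Sketch: a custom metric given as an arity-2 function of y_true and y_pred.
# The function name and the "max abs error" label are illustrative only.
max_abs_error = fn y_true, y_pred ->
  y_true
  |> Nx.subtract(y_pred)
  |> Nx.abs()
  |> Nx.reduce_max()
end

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(max_abs_error, "max abs error")
|> Axon.Loop.run(train_data, %{}, iterations: 100)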

By default, the metric will have a name which matches the string form of the given metric. You can give metrics semantic meaning by providing an explicit name:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "model error")
|> Axon.Loop.run(train_data, %{}, iterations: 1050)
Epoch: 0, Batch: 1000, loss: 0.0157658 model error: 0.0586748
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.059182167053222656, 0.022478463128209114, -0.17774473130702972, 0.2213028371334076, -0.09098079800605774, 0.10855386406183243, 0.21272781491279602, 0.008826442062854767]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.11995638161897659, -0.15990103781223297, 0.3607592284679413, -0.6160328388214111, 0.18441420793533325, 0.12754520773887634, 0.5297830700874329, 0.699347734451294]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, 0.17139868438243866, 0.10636390000581741, 0.12659816443920135]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.11982297897338867, 0.06972303986549377, 0.4129018187522888, -0.38851556181907654],
        [0.2401699423789978, 0.15984180569648743, 0.3136225640773773, 0.4181918203830719],
        [0.32672661542892456, 0.6614510416984558, 0.05174720287322998, -0.3830629885196686],
        [-0.18236470222473145, 0.7005639672279358, 0.25697338581085205, -0.2589600384235382],
        [-0.3322817087173462, 0.4810338318347931, -0.44291195273399353, -0.42512592673301697],
        [0.40222984552383423, 0.22190448641777039, -0.3441125750541687, 0.6704994440078735],
        [-0.049820780754089355, -0.35739046335220337, -0.43987011909484863, 0.5273388028144836],
        [-0.4311349391937256, -0.11345041543245316, -0.41187217831611633, 0.03270530700683594]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.01986159197986126]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.5068124532699585],
        [-0.8305477499961853],
        [-0.3594668209552765],
        [0.9469127058982849]
      ]
    >
  }
}
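
The same attachment and naming applies to evaluation loops. As a sketch, not in the original guide, you can bind the parameter map returned by a training run (the runs above do not bind it) and feed it to a loop built with Axon.Loop.evaluator/1:

# Sketch: reuse trained parameters in an evaluation loop with a named metric.
# train_data doubles as a stand-in evaluation set here.
trained_params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.run(train_data, %{}, iterations: 1000)

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:mean_absolute_error, "model error")
|> Axon.Loop.run(train_data, trained_params, iterations: 100)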

By default, metrics are aggregated with :running_average. You can customize this behavior by specifying an explicit accumulation function; the built-in accumulation functions are :running_average and :running_sum:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "total error", :running_sum)
|> Axon.Loop.run(train_data, %{}, iterations: 1050)
Epoch: 0, Batch: 1000, loss: 0.0170177 total error: 60.6604500
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.06552094221115112, 0.04294583946466446, 0.18359117209911346, -0.016857502982020378, 0.06263178586959839, 0.05358298495411873, -0.014969314448535442, 0.03554303199052811]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.20268964767456055, 0.34049317240715027, -0.4355761706829071, -0.01127737108618021, -0.6687942743301392, -0.10873300582170486, -0.23162637650966644, -0.5642837882041931]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.03859004005789757, -0.2329762578010559, 0.34153735637664795, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.3286629021167755, 0.6469427347183228, 0.5788326859474182, 0.4313785433769226],
        [0.4343699812889099, 0.5978931784629822, 0.7139090299606323, -0.6280259490013123],
        [0.7432681322097778, -0.3315945267677307, -0.6036061644554138, -0.4142929017543793],
        [-0.5269539952278137, 0.17591893672943115, 0.6815183758735657, 0.48768287897109985],
        [0.537314772605896, 0.4009576439857483, -0.2885156273841858, -0.23214280605316162],
        [0.5944074988365173, 0.603762149810791, -0.23866161704063416, 0.5527281165122986],
        [0.04886094480752945, -0.006500720977783203, -0.2641407549381256, -0.3267301321029663],
        [0.3275148868560791, -0.5457772016525269, 0.09545277059078217, 0.5026876330375671]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.20475178956985474]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.5014837980270386],
        [-0.9767601490020752],
        [1.1795883178710938],
        [-0.18678522109985352]
      ]
    >
  }
}
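
Accumulation is configured per metric, so you can track the same underlying metric under two names with different accumulators. A short sketch, not part of the original guide:

# Sketch: the same built-in metric accumulated two ways, distinguished by name.
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "average error", :running_average)
|> Axon.Loop.metric(:mean_absolute_error, "total error", :running_sum)
|> Axon.Loop.run(train_data, %{}, iterations: 1050)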