
Loop Event Handlers


Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok

Adding Event Handlers To Training Loops

Sometimes you want more fine-grained control over things that happen during loop execution. For example, you might want to:

  • Save loop state to a file every 500 iterations
  • Log some output to :stdout at the end of every epoch

Axon loops allow this kind of fine-grained control via events and event handlers. Axon emits a number of events during loop execution, which let you instrument various points in the loop execution cycle.

List of events you can attach handlers to:

  • :started - Loop state initialized
  • :epoch_started - Epoch start
  • :iteration_started - Iteration start
  • :iteration_completed - Iteration complete
  • :epoch_completed - Epoch complete
  • :epoch_halted - Epoch halt, if early halted
  • :halted - Loop halt, if early halted
  • :completed - Loop completion
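
Built-in handlers accept an :event option, and most also accept a :filter option that controls how often they fire on that event. Below is a minimal sketch of the "save loop state every 500 iterations" idea from above; the [every: n] filter form is an assumption based on the Axon.Loop documentation, so check the docs for your installed version.

# A hypothetical model just for this sketch; any Axon model works.
model =
  Axon.input("data")
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # Fire the checkpoint handler on every 500th completed iteration
  # rather than at the end of each epoch. The [every: 500] filter form
  # follows the Axon.Loop docs; treat it as an assumption.
  |> Axon.Loop.checkpoint(event: :iteration_completed, filter: [every: 500])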

Axon comes with a number of common loop event handlers out of the box, and you can also implement custom event handlers of your own.
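
As a minimal sketch of a custom handler: a handler receives the %Axon.Loop.State{} struct and returns a {status, state} tuple, where status is :continue, :halt_epoch, or :halt_loop. The sketch assumes Axon.Loop.handle_event/4 is available; some older Axon releases name this function Axon.Loop.handle/4 instead, so consult the docs for the version you have installed.

# A hypothetical custom handler that prints progress after each iteration.
print_progress = fn %Axon.Loop.State{epoch: epoch, iteration: iteration} = state ->
  IO.puts("epoch #{inspect(epoch)}, iteration #{inspect(iteration)}")
  # Returning {:continue, state} tells Axon to keep looping.
  {:continue, state}
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # handle_event/4 may be named handle/4 in older Axon versions.
  |> Axon.Loop.handle_event(:iteration_completed, print_progress)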

For example, you can use Axon.Loop.checkpoint/2 to checkpoint loop state at the end of every epoch:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(event: :epoch_completed)

training_data =
  Stream.repeatedly(fn ->
    # Note: re-creating the key with the same seed makes every batch
    # identical; derive a fresh key per batch for varied training data.
    key = Nx.Random.key(12)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, training_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.1117618
Epoch: 1, Batch: 50, loss: 0.0589791
Epoch: 2, Batch: 50, loss: 0.0413947
Epoch: 3, Batch: 50, loss: 0.0324704
Epoch: 4, Batch: 50, loss: 0.0268887
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.0244132149964571, -0.008191048167645931, 0.054081644862890244, 0.1797507405281067, -0.02473694272339344, -0.08548571914434433, -0.011347572319209576, 0.0666012093424797]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.40107280015945435, -0.5721201300621033, 0.281085342168808, -0.566183865070343, -0.03017902374267578, -0.11561445146799088, 0.285394549369812, -0.7834172248840332]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.11261016130447388, 0.009246677160263062, -0.07615330070257187, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.055334143340587616, -0.017638223245739937, 0.528224766254425, -0.01342695951461792],
        [0.044971831142902374, -0.6853394508361816, -0.6868073344230652, -0.28147637844085693],
        [-0.45176321268081665, -0.20226849615573883, 0.08495717495679855, -0.30261707305908203],
        [0.7845761775970459, 0.5105684995651245, 0.23007333278656006, -0.6273692846298218],
        [-0.11239924281835556, 0.2928570508956909, -0.24946483969688416, -0.18267256021499634],
        [-0.3222370743751526, -0.24477338790893555, 0.1783885359764099, 0.35081833600997925],
        [-0.19781574606895447, 0.015273826196789742, 0.553498387336731, -0.2342393696308136],
        [0.32456278800964355, -0.3851204514503479, -0.5878689289093018, -0.4877678155899048]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.31968435645103455]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-1.0996549129486084],
        [-0.5105833411216736],
        [-0.1323658525943756],
        [0.27660489082336426]
      ]
    >
  }
}
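
Once a checkpoint has been written, you can resume training from it. The following is a sketch, assuming the checkpoint file was written by Axon.Loop.checkpoint/2 using Axon.Loop.serialize_state/2 (checkpoints typically land in a checkpoints/ directory); the exact file name shown here is hypothetical and depends on the checkpoint options you used.

# Hypothetical checkpoint file name; the real name depends on the
# options passed to Axon.Loop.checkpoint/2.
serialized = File.read!("checkpoints/checkpoint_4_99.ckpt")
state = Axon.Loop.deserialize_state(serialized)

# Attach the saved state to the loop and continue training from it.
loop
|> Axon.Loop.from_state(state)
|> Axon.Loop.run(training_data, %{}, epochs: 5, iterations: 100)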