Loop Event Handlers
Mix.install([
{:axon, "~> 0.5.1"},
{:nx, "~> 0.5.2"},
{:kino, "~> 0.9.0"}
])
:ok
Adding Event Handlers To Training Loops
You may want more fine-grained control over things that happen during loop execution. For example, you might want to:
- Save loop state to a file every 500 iterations
- Log some output to :stdout at the end of every epoch
Axon loops allow this kind of fine-grained control via events and event handlers.
Axon emits a number of events during loop execution allowing you to instrument various points in the loop execution cycle.
List of events you can attach handlers to:
- :started - Loop state initialized
- :epoch_started - Epoch start
- :iteration_started - Iteration start
- :iteration_completed - Iteration complete
- :epoch_completed - Epoch complete
- :epoch_halted - Epoch halt, if early halted
- :halted - Loop halt, if early halted
- :completed - Loop completion
Axon comes with a number of common loop event handlers. You can also implement custom event handlers.
Example:
Checkpoint loop state at the end of every epoch, using Axon.Loop.checkpoint/2
# Small feed-forward regressor: a single named input ("data") followed by
# dense layers of 8, 4 and 1 units, with a ReLU after each hidden layer.
model =
  "data"
  |> Axon.input()
  # Hidden layer 1
  |> Axon.dense(8)
  |> Axon.relu()
  # Hidden layer 2
  |> Axon.dense(4)
  |> Axon.relu()
  # Output layer (no activation)
  |> Axon.dense(1)
# Supervised training loop (mean squared error loss, SGD optimizer) with a
# checkpoint handler attached to the :epoch_completed event, so loop state is
# checkpointed at the end of every epoch.
loop =
  Axon.Loop.checkpoint(
    Axon.Loop.trainer(model, :mean_squared_error, :sgd),
    event: :epoch_completed
  )
# Infinite stream of {inputs, targets} batches for fitting y = sin(x).
#
# The PRNG key is threaded through Stream.unfold/2 so that each batch is drawn
# with a fresh key. Re-creating Nx.Random.key(12) inside the generator and
# discarding the key returned by Nx.Random.normal would produce the exact same
# batch on every iteration. The fixed seed (12) keeps the sequence reproducible.
training_data =
  Stream.unfold(Nx.Random.key(12), fn key ->
    # One batch: an {8, 1} f32 tensor of normally distributed samples, plus the
    # key to use for the next batch.
    {xs, next_key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {{xs, ys}, next_key}
  end)
# Run the loop over the data stream, starting from an empty initial state (%{}),
# for 5 epochs capped at 100 iterations each. The stream is infinite, so the
# :iterations option is what ends each epoch.
loop
|> Axon.Loop.run(training_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.1117618
Epoch: 1, Batch: 50, loss: 0.0589791
Epoch: 2, Batch: 50, loss: 0.0413947
Epoch: 3, Batch: 50, loss: 0.0324704
Epoch: 4, Batch: 50, loss: 0.0268887
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.0244132149964571, -0.008191048167645931, 0.054081644862890244, 0.1797507405281067, -0.02473694272339344, -0.08548571914434433, -0.011347572319209576, 0.0666012093424797]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.40107280015945435, -0.5721201300621033, 0.281085342168808, -0.566183865070343, -0.03017902374267578, -0.11561445146799088, 0.285394549369812, -0.7834172248840332]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.11261016130447388, 0.009246677160263062, -0.07615330070257187, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.055334143340587616, -0.017638223245739937, 0.528224766254425, -0.01342695951461792],
[0.044971831142902374, -0.6853394508361816, -0.6868073344230652, -0.28147637844085693],
[-0.45176321268081665, -0.20226849615573883, 0.08495717495679855, -0.30261707305908203],
[0.7845761775970459, 0.5105684995651245, 0.23007333278656006, -0.6273692846298218],
[-0.11239924281835556, 0.2928570508956909, -0.24946483969688416, -0.18267256021499634],
[-0.3222370743751526, -0.24477338790893555, 0.1783885359764099, 0.35081833600997925],
[-0.19781574606895447, 0.015273826196789742, 0.553498387336731, -0.2342393696308136],
[0.32456278800964355, -0.3851204514503479, -0.5878689289093018, -0.4877678155899048]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.31968435645103455]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-1.0996549129486084],
[-0.5105833411216736],
[-0.1323658525943756],
[0.27660489082336426]
]
>
}
}