Training and inference mode
Mix.install([
{:axon, ">= 0.5.0"}
])
:ok
Executing models in inference mode
Some layers behave differently during model training than during model inference. For example, dropout layers are intended to be used only during training, as a form of model regularization. Certain stateful layers, like batch normalization, keep a running internal state that changes during training mode but remains fixed during inference mode. Axon supports mode-dependent execution behavior via the :mode option, which is accepted by all building, compilation, and execution methods. By default, all models build in inference mode. You can see this behavior by adding a dropout layer with a rate close to 1 (here 0.99). In inference mode, this layer has no effect:
inputs = Nx.iota({2, 8}, type: :f32)
model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
f32[2][1]
[
[0.6900148391723633],
[1.1159517765045166]
]
>
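You can also specify the mode explicitly. The values below differ from the previous output only because init_fn draws fresh random parameters on each call; the dropout layer is still a no-op: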
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
f32[2][1]
[
[-1.1250841617584229],
[-1.161189317703247]
]
>
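To confirm that the default mode really is :inference, the sketch below initializes the dropout model once and runs it three ways with the same parameters: with a default build, with an explicit mode: :inference build, and through Axon.predict/4, which builds and runs the model in one call and forwards :mode to the build step. Because dropout is disabled in inference mode, all three should agree. Mix.install is repeated only so the snippet runs on its own.

```elixir
# Repeated only to keep the snippet self-contained; skip it if the setup cell already ran
Mix.install([{:axon, ">= 0.5.0"}])

inputs = Nx.iota({2, 8}, type: :f32)

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)

# Initialize once so that every call below sees the same parameters
{init_fn, default_fn} = Axon.build(model)
{_init_fn, inference_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})

# Default build versus explicit inference build
default_out = default_fn.(params, inputs)
inference_out = inference_fn.(params, inputs)

# Axon.predict/4 builds and runs in one step, forwarding :mode to the build
predict_out = Axon.predict(model, params, inputs, mode: :inference)
```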
It’s important that you know which mode your models were compiled for, as a model built in :inference mode will behave very differently from a model built in :train mode.
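To make the difference concrete, here is a sketch that applies a lone dropout layer to a constant input and builds it in both modes. The rate of 0.5 and the input of 1000 ones are arbitrary illustrative choices. In inference mode the layer should pass its input through unchanged; in training mode the built function returns a map, and each entry of :prediction is either dropped to 0.0 or kept and scaled by 1 / (1 - rate), which is 2.0 here.

```elixir
# Repeated only to keep the snippet self-contained; skip it if the setup cell already ran
Mix.install([{:axon, ">= 0.5.0"}])

# A constant input makes the effect of dropout easy to read off
x = Nx.broadcast(1.0, {1, 1000})

# A lone dropout layer; rate 0.5 is only an illustrative value
model = Axon.input("data") |> Axon.dropout(rate: 0.5)

{init_fn, train_fn} = Axon.build(model, mode: :train)
{_init_fn, infer_fn} = Axon.build(model, mode: :inference)
params = init_fn.(x, %{})

# Inference mode: a plain tensor, identical to the input
inference_out = infer_fn.(params, x)

# Training mode: a map; entries are either 0.0 or 1.0 / (1 - 0.5) = 2.0
%{prediction: train_out, state: _new_state} = train_fn.(params, x)

# Roughly half of the entries should have been dropped
zero_fraction = train_out |> Nx.equal(0.0) |> Nx.mean() |> Nx.to_number()
```

With 1000 entries, zero_fraction should land close to 0.5 on any given run, while inference_out is exactly x every time.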
Executing models in training mode
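By specifying mode: :train, you tell your models to execute in training mode. You can see the effects of this behavior here: with a rate of 0.99, the dropout layer now zeroes out nearly all of its inputs, so the prediction is typically 0.0: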
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
prediction: #Nx.Tensor<
f32[2][1]
[
[0.0],
[0.0]
]
>,
state: %{
"dropout_0" => %{
"key" => #Nx.Tensor<
u32[2]
[309162766, 2699730300]
>
}
}
}
First, notice that your model now returns a map with the keys :prediction and :state. :prediction contains the actual model prediction, while :state contains the updated state for any stateful layers. In the example above, the only stateful layer is dropout, whose state is the random "key" it uses to sample its mask. When writing custom training loops, you should extract :state and use it in conjunction with the updates API to ensure your stateful layers are updated correctly. If your model has stateful layers such as batch norm, :state will look similar to your model’s parameter map:
model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
prediction: #Nx.Tensor<
f32[2][1]
[
[0.4891311526298523],
[-0.4891311228275299]
]
>,
state: %{
"batch_norm_0" => %{
"mean" => #Nx.Tensor<
f32[4]
[0.525083601474762, 0.8689039349555969, 0.03931800276041031, 0.0021854371298104525]
>,
"var" => #Nx.Tensor<
f32[4]
[0.13831248879432678, 0.10107331722974777, 0.10170891880989075, 0.10000484436750412]
>
}
}
}
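Because batch norm normalizes with batch statistics in training mode but with its stored running statistics in inference mode, the same parameters can produce different predictions depending on the mode. The sketch below builds the batch norm model in both modes and shares one set of parameters between them. Since predict_fn is a pure function, params still hold the initial running statistics after the training-mode call; only the returned :state carries the updated mean and variance.

```elixir
# Repeated only to keep the snippet self-contained; skip it if the setup cell already ran
Mix.install([{:axon, ">= 0.5.0"}])

inputs = Nx.iota({2, 8}, type: :f32)

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)

# Compile the same graph once per mode and share a single set of parameters
{init_fn, train_fn} = Axon.build(model, mode: :train)
{_init_fn, infer_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})

# Training mode: batch norm uses the mean and variance of this batch
%{prediction: train_pred} = train_fn.(params, inputs)

# Inference mode: batch norm uses the running statistics stored in params
infer_pred = infer_fn.(params, inputs)
```

With a batch of only two rows, training-mode batch norm centers each feature on the batch mean, so (with the default zero biases) the two training-mode predictions come out as near mirror images, as in the output above. The inference-mode predictions show no such symmetry.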