Axon: Training & Inference Mode
Mix.install([
{:axon, "~> 0.5.1"},
{:nx, "~> 0.5.2"},
{:kino, "~> 0.9.0"}
])
:ok
Executing Models In Inference Mode
Layers have different considerations and behavior when running during model training versus model inference. For example, dropout layers are intended to be used only during training, as a form of model regularization. Stateful layers such as batch normalization keep a running internal state which changes during training mode but remains fixed during inference mode.
By default, all models build in inference mode. You can see this by adding a dropout layer with a dropout rate of 0.99; in inference mode this layer will have no effect:
inputs = Nx.iota({2, 8}, type: :f32)
#Nx.Tensor<
f32[2][8]
[
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
]
>
model =
Axon.input("data")
|> Axon.dense(8)
|> Axon.sigmoid()
|> Axon.dropout(rate: 0.99)
|> Axon.dense(1)
#Axon<
inputs: %{"data" => nil}
outputs: "dense_1"
nodes: 5
>
# Building a model without a mode will default to `:inference`
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
f32[2][1]
[
[0.8433471322059631],
[0.9331091046333313]
]
>
Which mode your model is compiled for is important, as a model built in :inference mode will behave drastically differently from the same model built in :train mode.
# Building a model with specific mode, `:inference`
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
f32[2][1]
[
[-0.07459478825330734],
[0.6787327527999878]
]
>
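Because dropout is a no-op in inference mode, repeated calls with the same parameters and inputs are deterministic. A quick sanity check, reusing the predict_fn and params from the cell above:
# Dropout has no effect in inference mode, so two calls agree exactly
out1 = predict_fn.(params, inputs)
out2 = predict_fn.(params, inputs)
Nx.all_close(out1, out2)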
# Building a model with specific mode, `:train`
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
prediction: #Nx.Tensor<
f32[2][1]
[
[0.0],
[0.0]
]
>,
state: %{
"dropout_0" => %{
"key" => #Nx.Tensor<
u32[2]
[2064909249, 2249250809]
>
}
}
}
When you specify :train mode, the model returns a map with keys :prediction and :state. The :prediction key contains the actual model prediction, while :state contains the updated state for any stateful layers such as batch normalization.
When writing custom training loops, you should extract :state and use it in conjunction with the updates API to ensure your stateful layers are updated correctly (see the sketch at the end of this section). If your model has stateful layers, :state will look similar to your model's parameter map:
model =
Axon.input("data")
|> Axon.dense(4)
|> Axon.sigmoid()
|> Axon.batch_norm()
|> Axon.dense(1)
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
prediction: #Nx.Tensor<
f32[2][1]
[
[-1.0132904052734375],
[1.0132904052734375]
]
>,
state: %{
"batch_norm_0" => %{
"mean" => #Nx.Tensor<
f32[4]
[0.033600956201553345, 1.6433799464721233e-4, 0.5530032515525818, 5.167277413420379e-4]
>,
"var" => #Nx.Tensor<
f32[4]
[0.1007978543639183, 0.10000003129243851, 0.23339000344276428, 0.10000029951334]
>
}
}
}
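To tie this together, here is a minimal sketch of a single step of a hand-written training loop that threads the updated :state back into the parameter map before the next step. The deep Map.merge is our assumption about how to fold per-layer state back into params; in real training code, Axon.Loop handles this bookkeeping for you.
# A minimal sketch of one custom train step, reusing the batch-norm
# model built above (the deep merge strategy is an assumption, not
# Axon's internal implementation)
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})

%{prediction: preds, state: state} = predict_fn.(params, inputs)

# Fold the updated stateful-layer entries (e.g. "batch_norm_0") back
# into the parameter map so the next step sees the fresh mean/var
params =
  Map.merge(params, state, fn _layer, layer_params, layer_state ->
    Map.merge(layer_params, layer_state)
  end)

preds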