Axon: Training & Inference Mode

```elixir
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
```

```
:ok
```

Executing Models In Inference Mode

Some layers behave differently during model training than during model inference.

Dropout layers are intended to be used only during training, as a form of model regularization. Stateful layers such as batch normalization keep a running internal state that is updated in training mode but stays fixed in inference mode.
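
To make the training/inference difference concrete, here is a hand-rolled sketch of dropout in plain Nx. It assumes the common "inverted dropout" formulation (during training, surviving activations are scaled by `1 / (1 - rate)`; at inference the layer is the identity). `DropoutSketch` is a hypothetical helper for illustration only, not Axon's implementation.

```elixir
# Same dependencies as the setup cell at the top of this notebook.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

defmodule DropoutSketch do
  # Hypothetical helper illustrating inverted dropout; not Axon's own code.
  def dropout(input, key, rate, :train) do
    keep_prob = 1.0 - rate
    # One uniform sample per element; keep the element when the sample < keep_prob.
    {noise, _next_key} = Nx.Random.uniform(key, shape: Nx.shape(input))
    mask = Nx.less(noise, keep_prob)
    # Scale survivors so the expected activation matches inference mode.
    Nx.select(mask, Nx.divide(input, keep_prob), 0.0)
  end

  # At inference time dropout is the identity.
  def dropout(input, _key, _rate, :inference), do: input
end

x = Nx.iota({2, 8}, type: :f32)
key = Nx.Random.key(42)

DropoutSketch.dropout(x, key, 0.5, :train)
DropoutSketch.dropout(x, key, 0.5, :inference)
```

With `rate = 0.5`, every element of the training output is either `0.0` or twice the corresponding input, while the inference call returns `x` unchanged.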

By default, all models are built in inference mode. You can see this by adding a dropout layer with a very high dropout rate (0.99 below): in inference mode the layer has no effect and its input passes through unchanged.

```elixir
inputs = Nx.iota({2, 8}, type: :f32)
```

```
#Nx.Tensor<
  f32[2][8]
  [
    [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
    [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
  ]
>
```

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)
```

```
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_1"
  nodes: 5
>
```

```elixir
# Building a model without a mode defaults to `:inference`
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
```

```
#Nx.Tensor<
  f32[2][1]
  [
    [0.8433471322059631],
    [0.9331091046333313]
  ]
>
```

The mode a model is built for matters: a model built in `:inference` mode can behave very differently from the same model built in `:train` mode.

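```elixir
# Building a model with a specific mode, `:inference`.
# init_fn draws fresh random parameters on each call, so these values
# differ from the previous cell even though both builds use inference mode.
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
```

```
#Nx.Tensor<
  f32[2][1]
  [
    [-0.07459478825330734],
    [0.6787327527999878]
  ]
>
```

```elixir
# Building a model with a specific mode, `:train`.
# Dropout is now active: at `rate: 0.99` it zeroes (nearly) every activation
# feeding the final dense layer, which explains the 0.0 predictions below.
{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
```

```
%{
  prediction: #Nx.Tensor<
    f32[2][1]
    [
      [0.0],
      [0.0]
    ]
  >,
  state: %{
    "dropout_0" => %{
      "key" => #Nx.Tensor<
        u32[2]
        [2064909249, 2249250809]
      >
    }
  }
}
```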

When built in `:train` mode, the model returns a map with the keys `:prediction` and `:state`. `:prediction` holds the actual model output, while `:state` holds the updated state of any stateful layers, such as batch normalization. In the example above, the only state is the dropout layer's random `"key"`.
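
Because each call to `init_fn` draws fresh random parameters (which is why the two inference-mode outputs above differ), comparing separately initialized builds mixes two effects. A minimal sketch, reusing the dropout model, that initializes parameters once and passes the same `params` to both an `:inference` and a `:train` build, pattern-matching the `:train` result into `:prediction` and `:state`:

```elixir
# Same dependencies as the setup cell at the top of this notebook.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)

inputs = Nx.iota({2, 8}, type: :f32)

# Two builds of the same model; only the mode differs.
{init_fn, inference_fn} = Axon.build(model, mode: :inference)
{_init_fn, train_fn} = Axon.build(model, mode: :train)

# Initialize once and share the parameters between both modes.
params = init_fn.(inputs, %{})

# Inference mode returns the prediction tensor directly.
inference_out = inference_fn.(params, inputs)

# Train mode returns a map; split it into the prediction and the layer state.
%{prediction: train_out, state: state} = train_fn.(params, inputs)
```

Calling `inference_fn` again with the same `params` returns the same tensor, since dropout is inactive in that mode.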

When writing custom training loops, extract `:state` and use it together with the updates API to make sure your stateful layers are updated correctly. If your model has stateful layers, `:state` is keyed by layer name, just like your model's parameter map. For example, with a batch normalization layer it holds the layer's running `"mean"` and `"var"`:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
```

```
%{
  prediction: #Nx.Tensor<
    f32[2][1]
    [
      [-1.0132904052734375],
      [1.0132904052734375]
    ]
  >,
  state: %{
    "batch_norm_0" => %{
      "mean" => #Nx.Tensor<
        f32[4]
        [0.033600956201553345, 1.6433799464721233e-4, 0.5530032515525818, 5.167277413420379e-4]
      >,
      "var" => #Nx.Tensor<
        f32[4]
        [0.1007978543639183, 0.10000003129243851, 0.23339000344276428, 0.10000029951334]
      >
    }
  }
}
```
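
A sketch of one step of a custom training loop for the batch normalization model, assuming Axon 0.5's `Axon.Optimizers.sgd/1` and `Axon.Updates.apply_updates/3` (which takes the parameters, the optimizer updates, and the new layer state). The regression `targets` and the learning rate are made-up values for illustration. After the step, the updated parameters are reused with an `:inference` build, which normalizes with the stored running statistics rather than the current batch's.

```elixir
# Same dependencies as the setup cell at the top of this notebook.
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)

inputs = Nx.iota({2, 8}, type: :f32)
# Hypothetical regression targets, one per input row.
targets = Nx.tensor([[0.0], [1.0]])

{init_fn, train_fn} = Axon.build(model, mode: :train)
{_init_fn, inference_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})

# Axon 0.5 optimizers are {init_fn, update_fn} pairs; the learning rate is arbitrary.
{optimizer_init, optimizer_update} = Axon.Optimizers.sgd(1.0e-2)
optimizer_state = optimizer_init.(params)

train_step =
  Nx.Defn.jit(fn params, optimizer_state, inputs, targets ->
    # Differentiate only the loss, but also return the layer state from the forward pass.
    {{loss, new_state}, gradients} =
      Nx.Defn.value_and_grad(
        params,
        fn model_params ->
          %{prediction: preds, state: state} = train_fn.(model_params, inputs)
          {Axon.Losses.mean_squared_error(targets, preds, reduction: :mean), state}
        end,
        fn {loss, _state} -> loss end
      )

    {updates, new_optimizer_state} = optimizer_update.(gradients, optimizer_state, params)

    # Apply the gradient updates, then overwrite the stateful entries
    # (batch norm "mean"/"var") with the values from :state.
    new_params = Axon.Updates.apply_updates(params, updates, new_state)
    {loss, new_params, new_optimizer_state}
  end)

{loss, params, optimizer_state} = train_step.(params, optimizer_state, inputs, targets)

# Inference uses the stored running mean/var and returns a plain tensor.
inference_fn.(params, inputs)
```

In practice `train_step` would be called repeatedly over batches; Axon's built-in `Axon.Loop.trainer` handles this state bookkeeping for you.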