Custom Models, Loss Functions, Optimizers

Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok

Using Custom Models In Training Loops

A typical Axon model and training loop looks like this:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

Some problems, however, will require more flexibility than this example offers.

If a model cannot be cleanly represented as an %Axon{} struct, you can instead define custom initialization and forward functions and pass those to Axon.Loop.trainer/3. In fact, Axon.Loop.trainer/3 does exactly this under the hood; the ability to pass an %Axon{} struct directly is just a convenience:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

lowered_model = {init_fn, predict_fn} = Axon.build(model)

loop = Axon.Loop.trainer(lowered_model, :mean_squared_error, :sgd)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Axon.Loop.trainer/3 handles the lowered_model exactly as it would an %Axon{} struct. With this in mind, you can build custom models entirely with Nx.Defn, or mix Axon models into custom workflows, without worrying about compatibility with the Axon.Loop API:

defmodule CustomModel do
  import Nx.Defn

  defn custom_predict_fn(model_predict_fn, params, input) do
    %{prediction: pred} = out = model_predict_fn.(params, input)
    %{out | prediction: Nx.cos(pred)}
  end
end
{:module, CustomModel, <<70, 79, 82, 49, 0, 0, 10, ...>>, true}
train_data =
  Stream.repeatedly(fn ->
    {xs, _key} = Nx.Random.normal(Nx.Random.key(12), shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
custom_predict_fn = &CustomModel.custom_predict_fn(predict_fn, &1, &2)

loop = Axon.Loop.trainer({init_fn, custom_predict_fn}, :mean_squared_error, :sgd)

Axon.Loop.run(loop, train_data, %{}, iterations: 500)
Epoch: 0, Batch: 450, loss: 0.1479121
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.02175050973892212, -0.024910716339945793, -0.027416784316301346, 0.033370327204465866, -0.034723542630672455, -0.004877300467342138, 0.3793628215789795, -0.03329168260097504]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.04658525809645653, 0.19475071132183075, 0.7931413650512695, 0.09596283733844757, 0.5304712057113647, 0.1936696171760559, -0.769289493560791, -0.3590540289878845]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.3208722174167633, -0.005868331994861364, 0.07400587946176529, 0.3720073103904724]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.4856124818325043, 0.32140636444091797, -0.36212342977523804, -0.3598131537437439],
        [0.2552179992198944, -0.09555363655090332, 0.5524453520774841, -0.48476120829582214],
        [0.28624099493026733, 0.3768715262413025, 0.06508781760931015, -0.5240411758422852],
        [0.13131645321846008, 0.2961307168006897, 0.6918494701385498, 0.3586469292640686],
        [0.627220630645752, 0.2460222989320755, 0.2787061333656311, -0.609135627746582],
        [-0.18439742922782898, -0.008864992298185825, -0.026968808844685555, -0.5638307332992554],
        [0.8725758790969849, -0.2888437807559967, 0.06409312039613724, 0.5787912607192993],
        [0.0492698960006237, -0.11134690046310425, 0.25074824690818787, -0.09627816081047058]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.4735690653324127]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.8874920606613159],
        [0.46658992767333984],
        [0.1805259734392166],
        [0.8596954941749573]
      ]
    >
  }
}
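
Taking this a step further, you can skip Axon.build/2 entirely and hand-write both functions with Nx.Defn. Below is a minimal sketch; the ScratchModel module, its parameter names, and the single linear layer are illustrative, sized for the {8, 1} batches above:

defmodule ScratchModel do
  import Nx.Defn

  # init_fn: build the parameter map from scratch. Lowered init
  # functions receive an input template and initial parameters;
  # this sketch ignores both.
  def init(_template, _params) do
    key = Nx.Random.key(0)
    {w, _key} = Nx.Random.normal(key, shape: {1, 1}, type: :f32)
    %{"w" => w, "b" => Nx.tensor([0.0])}
  end

  # predict_fn: a single hand-written linear layer
  defn predict(params, input) do
    input |> Nx.dot(params["w"]) |> Nx.add(params["b"])
  end
end

loop =
  Axon.Loop.trainer(
    {&ScratchModel.init/2, &ScratchModel.predict/2},
    :mean_squared_error,
    :sgd
  )

Because the parameters are a plain map of tensors, the training loop can differentiate through predict/2 just as it would a lowered Axon model.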

Custom Loss Functions In Training Loops

Axon.Loop.trainer/3 is flexible about loss functions as well as models. In most cases the built-in losses in Axon.Losses suffice; however, you can also pass a custom arity-2 function to Axon.Loop.trainer/3, which is useful for controlling parameters of a built-in loss, such as the batch-level reduction:

loss_fn = &Axon.Losses.mean_squared_error(&1, &2, reduction: :sum)
loop = Axon.Loop.trainer(model, loss_fn, :sgd)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.98177675/1 in Axon.Loop.log/3>,
       #Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
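
You can also write the loss entirely from scratch: any arity-2 function of (y_true, y_pred) that returns a scalar tensor will work. Here is a sketch, where the CustomLoss module and the particular blend of MAE and MSE are illustrative:

defmodule CustomLoss do
  import Nx.Defn

  # Blend mean absolute error with mean squared error into a
  # single scalar loss
  defn blended(y_true, y_pred) do
    err = Nx.subtract(y_true, y_pred)
    mae = err |> Nx.abs() |> Nx.mean()
    mse = err |> Nx.multiply(err) |> Nx.mean()
    Nx.add(mae, Nx.multiply(mse, 0.5))
  end
end

loop = Axon.Loop.trainer(model, &CustomLoss.blended/2, :sgd)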

Custom functions are also useful for constructing loss functions in multi-output scenarios.

For example, you can construct a custom loss function that is a weighted average of several losses computed on multiple outputs:

train_data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(42)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    y1 = Nx.sin(xs)
    y2 = Nx.cos(xs)
    {xs, {y1, y2}}
  end)

shared =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

y1 = shared |> Axon.dense(1)
y2 = shared |> Axon.dense(1)

model = Axon.container({y1, y2})

custom_loss_fn = fn {y_true1, y_true2}, {y_pred1, y_pred2} ->
  loss1 = Axon.Losses.mean_squared_error(y_true1, y_pred1, reduction: :mean)
  loss2 = Axon.Losses.mean_squared_error(y_true2, y_pred2, reduction: :mean)

  loss1
  |> Nx.multiply(0.4)
  |> Nx.add(Nx.multiply(loss2, 0.6))
end

model
|> Axon.Loop.trainer(custom_loss_fn, :sgd)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.2421960
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.034086067229509354, -0.02887660264968872, 0.06434519588947296, 0.08234142512083054, -0.01846095360815525, 0.018978577107191086, 0.029167955741286278, 0.03630779683589935]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.6758522987365723, -0.44307130575180054, 0.7554517984390259, 0.8388501405715942, -0.645117461681366, -0.32366904616355896, -0.5417147278785706, -0.49977046251296997]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, 0.252582311630249, 0.0, -0.007978496141731739]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.39820462465286255, -0.5602824687957764, 0.46009451150894165, 0.6020047068595886],
        [-0.22417116165161133, 0.4916485846042633, -0.509750247001648, -0.7039790749549866],
        [-0.683228075504303, 0.30494335293769836, -0.3050630986690521, -0.6625939607620239],
        [-0.5826095342636108, 0.37313663959503174, 0.19205337762832642, -0.024342477321624756],
        [-0.4203922152519226, 0.32724741101264954, -0.28685835003852844, -0.5794124603271484],
        [0.027150988578796387, -0.3362065553665161, -0.21735718846321106, 0.5498802661895752],
        [-0.5828167200088501, -0.44295307993888855, -0.06499296426773071, 0.1972251981496811],
        [0.38042229413986206, -0.45279115438461304, -0.2764470875263214, -0.6100720763206482]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.4143042266368866]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.09080958366394043],
        [1.3575913906097412],
        [1.003105878829956],
        [-0.6902222633361816]
      ]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.7182903289794922]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.02720022201538086],
        [-0.28363344073295593],
        [0.15775752067565918],
        [-0.7115132212638855]
      ]
    >
  }
}
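
Because the model is an Axon.container/1 wrapping a two-element tuple, predictions come back in the same shape. A quick inference sketch, assuming params is bound to the parameter map returned by the Axon.Loop.run/4 call above:

{_init_fn, predict_fn} = Axon.build(model)

# Predictions mirror the container structure: {sin_pred, cos_pred}
{xs, _key} = Nx.Random.normal(Nx.Random.key(0), shape: {8, 1}, type: :f32)
{sin_pred, cos_pred} = predict_fn.(params, xs)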

Custom Optimizers In Training Loops

It is possible to customize the optimizer passed to Axon.Loop.trainer/3.

Optimizers are represented as the tuple {init_fn, update_fn}:

  • init_fn initializes the optimizer state from the initial model parameters
  • update_fn transforms gradients into parameter updates using the optimizer state, the gradients, and the current model parameters

You are unlikely to need a fully custom optimizer; however, it is useful to know how to construct optimizers with different hyperparameters and how to apply different modifiers to customize their behavior.
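
For illustration, here is a minimal hand-rolled SGD optimizer expressed as an {init_fn, update_fn} tuple. This is a sketch: the CustomSGD module and its fixed 1.0e-2 learning rate are illustrative:

defmodule CustomSGD do
  import Nx.Defn

  @learning_rate 1.0e-2

  # init_fn: plain SGD carries no optimizer state
  def init(_params), do: %{}

  # update_fn: scale every gradient leaf by the negative learning
  # rate; the training loop applies the returned updates to the
  # model parameters
  defn update(gradients, state, _params) do
    updates =
      Nx.Defn.Composite.traverse(gradients, fn g ->
        Nx.multiply(g, -@learning_rate)
      end)

    {updates, state}
  end
end

custom_optimizer = {&CustomSGD.init/1, &CustomSGD.update/3}
loop = Axon.Loop.trainer(model, :mean_squared_error, custom_optimizer)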

Specifying an optimizer as an atom in Axon.Loop.trainer/3 maps directly to an optimizer declared in Axon.Optimizers. Alternatively, you can declare the optimizer directly, which is useful for controlling the learning rate and other optimizer hyperparameters:

train_data =
  Stream.repeatedly(fn ->
    {xs, _key} = Nx.Random.normal(Nx.Random.key(42), shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

optimizer = {_init_optimizer_fn, _update_optimizer_fn} = Axon.Optimizers.sgd(1.0e-3)

model
|> Axon.Loop.trainer(:mean_squared_error, optimizer)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.1864013
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.046244218945503235, -0.07283938676118851, 0.03872031718492508, 0.04726196452975273, 0.05153937265276909, -0.09562944620847702, 0.08030311018228531, -9.8411797080189e-4]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.44380950927734375, -0.09863922744989395, -0.09823662042617798, -0.06565620750188828, -0.5508302450180054, 0.6184558868408203, -0.6625171899795532, -0.43435895442962646]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.09076668322086334, 0.09131301194429398, 0.09503045678138733, -0.009358537383377552]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.21457846462726593, 0.05630738288164139, 0.23056626319885254, 0.45693546533584595],
        [-0.6530879735946655, -0.6910097599029541, 0.4409450590610504, -0.01907256990671158],
        [0.5752553343772888, -0.35199815034866333, 0.2332945168018341, 0.22411414980888367],
        [-0.12412989884614944, 0.568100094795227, 0.07800319045782089, 0.24700255692005157],
        [0.06662838906049728, 0.15959638357162476, 0.2596869170665741, -0.10141480714082718],
        [0.5309587121009827, -0.28541138768196106, -0.33925706148147583, -0.39941707253456116],
        [0.14302359521389008, 0.45056530833244324, 0.14198145270347595, -0.28399235010147095],
        [-0.25966623425483704, 0.39908555150032043, -0.16964033246040344, 0.10002800077199936]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.1894342601299286]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.42044392228126526],
        [-0.5125951170921326],
        [-0.48303771018981934],
        [0.038217440247535706]
      ]
    >
  }
}
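
The same direct-declaration approach exposes hyperparameters beyond the learning rate. For example, Adam's moment decay rates can be set via options (the values below are illustrative, not tuned recommendations):

# Adam with explicit first- and second-moment decay rates
optimizer = Axon.Optimizers.adam(1.0e-3, b1: 0.9, b2: 0.999)

model
|> Axon.Loop.trainer(:mean_squared_error, optimizer)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)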