Custom Models, Loss Functions, Optimizers
Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.2"},
  {:kino, "~> 0.9.0"}
])
:ok
Using Custom Models In Training Loops
An example of an Axon model and training loop looks like this:
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
Some problems will require more flexibility than this example offers. If a model cannot be cleanly represented as an %Axon{} struct, you can define custom initialization and forward functions to pass to Axon.Loop.trainer/3. In the example above, Axon.Loop.trainer/3 is doing this under the hood - the ability to pass an %Axon{} struct directly is just a convenience:
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
lowered_model = {init_fn, predict_fn} = Axon.build(model)
loop = Axon.Loop.trainer(lowered_model, :mean_squared_error, :sgd)
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<9.98177675/2 in Axon.Loop.build_loss_fn/1>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Axon.Loop.trainer/3 handles the lowered_model as if it were an %Axon{} struct. With this in mind, you can build custom models entirely with Nx.Defn, or mix Axon models into custom workflows without worrying about compatibility with the Axon.Loop API:
defmodule CustomModel do
  import Nx.Defn

  defn custom_predict_fn(model_predict_fn, params, input) do
    %{prediction: pred} = out = model_predict_fn.(params, input)
    %{out | prediction: Nx.cos(pred)}
  end
end
{:module, CustomModel, <<70, 79, 82, 49, 0, 0, 10, ...>>, true}
train_data =
  Stream.repeatedly(fn ->
    {xs, _key} = Nx.Random.normal(Nx.Random.key(12), shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)
{init_fn, predict_fn} = Axon.build(model, mode: :train)
custom_predict_fn = &CustomModel.custom_predict_fn(predict_fn, &1, &2)
loop = Axon.Loop.trainer({init_fn, custom_predict_fn}, :mean_squared_error, :sgd)
Axon.Loop.run(loop, train_data, %{}, iterations: 500)
Epoch: 0, Batch: 450, loss: 0.1479121
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.02175050973892212, -0.024910716339945793, -0.027416784316301346, 0.033370327204465866, -0.034723542630672455, -0.004877300467342138, 0.3793628215789795, -0.03329168260097504]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[0.04658525809645653, 0.19475071132183075, 0.7931413650512695, 0.09596283733844757, 0.5304712057113647, 0.1936696171760559, -0.769289493560791, -0.3590540289878845]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.3208722174167633, -0.005868331994861364, 0.07400587946176529, 0.3720073103904724]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.4856124818325043, 0.32140636444091797, -0.36212342977523804, -0.3598131537437439],
[0.2552179992198944, -0.09555363655090332, 0.5524453520774841, -0.48476120829582214],
[0.28624099493026733, 0.3768715262413025, 0.06508781760931015, -0.5240411758422852],
[0.13131645321846008, 0.2961307168006897, 0.6918494701385498, 0.3586469292640686],
[0.627220630645752, 0.2460222989320755, 0.2787061333656311, -0.609135627746582],
[-0.18439742922782898, -0.008864992298185825, -0.026968808844685555, -0.5638307332992554],
[0.8725758790969849, -0.2888437807559967, 0.06409312039613724, 0.5787912607192993],
[0.0492698960006237, -0.11134690046310425, 0.25074824690818787, -0.09627816081047058]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.4735690653324127]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.8874920606613159],
[0.46658992767333984],
[0.1805259734392166],
[0.8596954941749573]
]
>
}
}
Custom Loss Functions In Training Loops
Axon.Loop.trainer/3 allows flexibility with models as well as with loss functions. In most cases you can use one of Axon's built-in loss functions; however, you can also pass a custom arity-2 function to Axon.Loop.trainer/3 to control some parameters of the loss function, such as batch-level reduction:
loss_fn = &Axon.Losses.mean_squared_error(&1, &2, reduction: :sum)
loop = Axon.Loop.trainer(model, loss_fn, :sgd)
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.71264915/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<6.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.98177675/1 in Axon.Loop.log/3>,
#Function<64.98177675/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
Custom functions are useful for constructing loss functions when dealing with multi-output scenarios. For example, you can construct a custom loss function which is a weighted average of several loss functions on multiple outputs:
train_data =
  Stream.repeatedly(fn ->
    key = Nx.Random.key(42)
    {xs, _key} = Nx.Random.normal(key, shape: {8, 1}, type: :f32)
    y1 = Nx.sin(xs)
    y2 = Nx.cos(xs)
    {xs, {y1, y2}}
  end)
shared =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

y1 = shared |> Axon.dense(1)
y2 = shared |> Axon.dense(1)

model = Axon.container({y1, y2})
custom_loss_fn = fn {y_true1, y_true2}, {y_pred1, y_pred2} ->
  loss1 = Axon.Losses.mean_squared_error(y_true1, y_pred1, reduction: :mean)
  loss2 = Axon.Losses.mean_squared_error(y_true2, y_pred2, reduction: :mean)

  loss1
  |> Nx.multiply(0.4)
  |> Nx.add(Nx.multiply(loss2, 0.6))
end
model
|> Axon.Loop.trainer(custom_loss_fn, :sgd)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.2421960
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.034086067229509354, -0.02887660264968872, 0.06434519588947296, 0.08234142512083054, -0.01846095360815525, 0.018978577107191086, 0.029167955741286278, 0.03630779683589935]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.6758522987365723, -0.44307130575180054, 0.7554517984390259, 0.8388501405715942, -0.645117461681366, -0.32366904616355896, -0.5417147278785706, -0.49977046251296997]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.0, 0.252582311630249, 0.0, -0.007978496141731739]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.39820462465286255, -0.5602824687957764, 0.46009451150894165, 0.6020047068595886],
[-0.22417116165161133, 0.4916485846042633, -0.509750247001648, -0.7039790749549866],
[-0.683228075504303, 0.30494335293769836, -0.3050630986690521, -0.6625939607620239],
[-0.5826095342636108, 0.37313663959503174, 0.19205337762832642, -0.024342477321624756],
[-0.4203922152519226, 0.32724741101264954, -0.28685835003852844, -0.5794124603271484],
[0.027150988578796387, -0.3362065553665161, -0.21735718846321106, 0.5498802661895752],
[-0.5828167200088501, -0.44295307993888855, -0.06499296426773071, 0.1972251981496811],
[0.38042229413986206, -0.45279115438461304, -0.2764470875263214, -0.6100720763206482]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.4143042266368866]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.09080958366394043],
[1.3575913906097412],
[1.003105878829956],
[-0.6902222633361816]
]
>
},
"dense_3" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.7182903289794922]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.02720022201538086],
[-0.28363344073295593],
[0.15775752067565918],
[-0.7115132212638855]
]
>
}
}
Custom Optimizers In Training Loops
It is also possible to customize the optimizer passed to Axon.Loop.trainer/3. Optimizers are represented as the tuple {init_fn, update_fn} where:

- init_fn initializes optimizer state from model state
- update_fn scales gradients using the optimizer state, the gradients, and the model state

It's unlikely you'll need to implement a custom optimizer; however, it is useful to know how to construct optimizers with different hyperparameters and how to apply different modifiers to different optimizers to customize the optimization process. The sketch below shows how the two functions fit together.
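To make the {init_fn, update_fn} contract concrete, here is a minimal sketch of a single optimizer step performed by hand. It assumes Axon 0.5's Axon.Optimizers and Axon.Updates APIs, in particular that update_fn returns {updates, new_optimizer_state} and that Axon.Updates.apply_updates/2 applies updates to parameters; the tiny parameter and gradient maps are made up for illustration rather than taken from a real model:

{optimizer_init_fn, optimizer_update_fn} = Axon.Optimizers.sgd(1.0e-2)

# Hypothetical parameters and gradients; in practice these come from the model
# and from differentiating the loss.
params = %{"w" => Nx.tensor([[1.0]]), "b" => Nx.tensor([0.0])}
gradients = %{"w" => Nx.tensor([[0.5]]), "b" => Nx.tensor([0.1])}

# init_fn builds the optimizer state from the model parameters
optimizer_state = optimizer_init_fn.(params)

# update_fn turns gradients (plus optimizer state and parameters) into scaled
# updates, which are then applied to the parameters
{updates, _new_optimizer_state} = optimizer_update_fn.(gradients, optimizer_state, params)
new_params = Axon.Updates.apply_updates(params, updates)

This is essentially the sequence the training loop performs for you on every iteration.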
Specifying an optimizer as an atom in Axon.Loop.trainer/3 maps directly to an optimizer declared in Axon.Optimizers. Instead of an atom, you can also declare an optimizer directly. This is useful for controlling things like the learning rate and other optimizer hyperparameters:
train_data =
  Stream.repeatedly(fn ->
    {xs, _key} = Nx.Random.normal(Nx.Random.key(42), shape: {8, 1}, type: :f32)
    ys = Nx.sin(xs)
    {xs, ys}
  end)
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
optimizer = {_init_optimizer_fn, _update_optimizer_fn} = Axon.Optimizers.sgd(1.0e-3)
model
|> Axon.Loop.trainer(:mean_squared_error, optimizer)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.1864013
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.046244218945503235, -0.07283938676118851, 0.03872031718492508, 0.04726196452975273, 0.05153937265276909, -0.09562944620847702, 0.08030311018228531, -9.8411797080189e-4]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.44380950927734375, -0.09863922744989395, -0.09823662042617798, -0.06565620750188828, -0.5508302450180054, 0.6184558868408203, -0.6625171899795532, -0.43435895442962646]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[-0.09076668322086334, 0.09131301194429398, 0.09503045678138733, -0.009358537383377552]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.21457846462726593, 0.05630738288164139, 0.23056626319885254, 0.45693546533584595],
[-0.6530879735946655, -0.6910097599029541, 0.4409450590610504, -0.01907256990671158],
[0.5752553343772888, -0.35199815034866333, 0.2332945168018341, 0.22411414980888367],
[-0.12412989884614944, 0.568100094795227, 0.07800319045782089, 0.24700255692005157],
[0.06662838906049728, 0.15959638357162476, 0.2596869170665741, -0.10141480714082718],
[0.5309587121009827, -0.28541138768196106, -0.33925706148147583, -0.39941707253456116],
[0.14302359521389008, 0.45056530833244324, 0.14198145270347595, -0.28399235010147095],
[-0.25966623425483704, 0.39908555150032043, -0.16964033246040344, 0.10002800077199936]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.1894342601299286]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.42044392228126526],
[-0.5125951170921326],
[-0.48303771018981934],
[0.038217440247535706]
]
>
}
}