Your first training loop

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Creating an Axon training loop

Axon generalizes the concept of training, evaluation, hyperparameter optimization, and more into the Axon.Loop API. Axon loops are instrumented reductions over Elixir Streams: you accumulate some state over an Elixir Stream, and you can hook into different points in the loop’s execution.
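To make "instrumented reduction" concrete, here is a minimal sketch in plain Elixir with no Axon involved. The `ToyLoop` module and its `on_step` callback are hypothetical names made up for illustration; they only mimic the general shape of the idea (fold a step function over a stream, and call a hook after each step) and are not how Axon is implemented.

```elixir
defmodule ToyLoop do
  # Hypothetical helper, not part of Axon: folds `step_fn` over the first
  # `iterations` items of `stream`, calling `on_step` after every step.
  def run(stream, init_state, step_fn, on_step, iterations) do
    stream
    |> Stream.take(iterations)
    |> Stream.with_index()
    |> Enum.reduce(init_state, fn {item, index}, state ->
      new_state = step_fn.(item, state)
      on_step.(index, new_state)
      new_state
    end)
  end
end

# An infinite stream of ones stands in for a stream of batches.
ones = Stream.repeatedly(fn -> 1 end)

# The "step function" accumulates a running sum as the loop state.
step_fn = fn x, acc -> acc + x end

# The "instrumentation" logs the state every 5 steps.
on_step = fn index, state ->
  if rem(index + 1, 5) == 0, do: IO.puts("step #{index + 1}: state=#{state}")
end

total = ToyLoop.run(ones, 0, step_fn, on_step, 10)
IO.inspect(total, label: "final state")
```

The state threaded through the reduction plays the role that the training state plays in a real loop, and the callback plays the role of a loop handler.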

With Axon, you’ll most commonly implement and work with supervised training loops. Because supervised training loops are so common in deep learning, Axon has a loop factory function which takes care of most of the boilerplate of creating a supervised training loop for you. In the beginning of your deep learning journey, you’ll almost exclusively use Axon’s loop factories to create and run loops.

Axon’s supervised training loop assumes you have an input stream of data with entries that look like:

{batch_inputs, batch_labels}

Each entry is a batch of input data with a corresponding batch of labels. You can simulate some real training data by constructing an Elixir stream:

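train_data =
  Stream.repeatedly(fn ->
    # Draw a fresh seed per batch; :rand replaces the deprecated :random module
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    # Labels are the sine of the inputs, so the model learns y = sin(x)
    ys = Nx.sin(xs)
    {xs, ys}
  end)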
#Function<51.6935098/2 in Stream.repeatedly/1>
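Because the stream is lazy and infinite, it can help to pull a single batch and check its shape before training. The following is a small sketch that assumes the Mix.install cell at the top of this notebook has already run, so that Nx is available as a dependency of Axon. It rebuilds the same kind of stream (using `:rand` for the seed) so the snippet stands on its own.

```elixir
# Assumes Nx is already installed (it is pulled in as a dependency of Axon).
train_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

# Take exactly one {batch_inputs, batch_labels} entry off the stream.
[{xs, ys}] = Enum.take(train_data, 1)

IO.inspect(Nx.shape(xs), label: "inputs shape")
IO.inspect(Nx.shape(ys), label: "labels shape")
```

Both tensors should have shape `{8, 1}`: a batch of 8 examples with one feature each, paired with one label per example.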

The most basic supervised training loop in Axon requires 3 things:

  1. An Axon model
  2. A loss function
  3. An optimizer

You can construct an Axon model using the knowledge you’ve gained from going through the model creation guides:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_2"
  nodes: 6
>

Axon comes with built-in loss functions and optimizers which you can use directly when constructing your training loop. To construct your training loop, you use Axon.Loop.trainer/3:

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
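The atom `:mean_squared_error` is a shortcut for a built-in loss. As a hedged sketch, the same loop can be built by passing an arity-2 function of `(y_true, y_pred)` instead, here wrapping `Axon.Losses.mean_squared_error/3` with a mean reduction. The smaller model and the `loss_fn` name are illustrative choices, and the snippet assumes the Mix.install cell at the top of this notebook has already run.

```elixir
# Assumes Axon is already installed via the Mix.install cell above.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# A loss takes the true labels and the model's predictions and returns a
# scalar tensor; `reduction: :mean` averages the per-example losses.
loss_fn = fn y_true, y_pred ->
  Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)
end

loop = Axon.Loop.trainer(model, loss_fn, :sgd)
```

Writing the loss as a function is mostly useful when you want to combine or rescale built-in losses; for standard cases the atom form is shorter.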

You’ll notice that Axon.Loop.trainer/3 returns an %Axon.Loop{} data structure. This data structure contains information which Axon uses to control the execution of the loop. In order to run the loop, you need to explicitly pass it to Axon.Loop.run/4:

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0563023
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.038592107594013214, 0.19925688207149506, -0.08018972724676132, -0.11267539858818054, 0.35166260600090027, -0.0794963389635086, 0.20298318564891815, 0.3049686849117279]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.06691190600395203, -0.32860732078552246, 0.22386932373046875, 0.16137443482875824, 0.23626506328582764, 0.2438151240348816, 0.2662005126476288, 0.32266947627067566]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.03138260543346405, 0.2621246576309204, 0.021843062713742256, -0.07498764991760254]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.541576087474823, 0.4923045039176941, 0.5933979749679565, -0.5083895921707153],
        [0.5120893120765686, -0.6925638318061829, 0.36635661125183105, -0.05748361349105835],
        [0.26158788800239563, -0.1788359135389328, -0.14064575731754303, -0.08323567360639572],
        [0.6685130596160889, -0.4880330264568329, 0.5104460120201111, -0.3399733006954193],
        [-0.6356683969497681, 0.770803689956665, -0.3876360058784485, -0.5178110599517822],
        [0.4476216733455658, -0.21042484045028687, -0.4300518333911896, -0.2693784534931183],
        [0.08789066225290298, 0.47043612599372864, 0.02871485985815525, 0.6908602714538574],
        [0.45776790380477905, 0.6735268235206604, 0.40828803181648254, 0.19558420777320862]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.748963475227356]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.22219088673591614],
        [1.1391150951385498],
        [-0.13221295177936554],
        [-0.27904900908470154]
      ]
    >
  }
}

Axon.Loop.run/4 expects a loop to execute, some data to loop over, and any initial state you explicitly want your loop to start with. Axon.Loop.run/4 will then iterate over your data, executing a step function on each batch, and accumulating some generic loop state. In the case of a supervised training loop, this generic loop state actually represents training state including your model’s trained parameters.
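Since the value returned by the training loop holds the trained parameters, it can be passed to `Axon.predict/3` to run the model on new inputs. The sketch below repeats the data and model definitions so it stands on its own, mirrors the calls used in this guide, and assumes the Mix.install cell at the top of this notebook has already run. The names `trained_state`, `inputs`, and `predictions` are illustrative.

```elixir
# Assumes Axon (and Nx, as its dependency) are already installed.
train_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# Train, keeping the returned state instead of discarding it.
trained_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.run(train_data, %{}, iterations: 1000)

# Run the trained model on a few hand-picked inputs of shape {3, 1}.
inputs = Nx.tensor([[0.0], [0.5], [1.0]])
predictions = Axon.predict(model, trained_state, inputs)

IO.inspect(predictions, label: "model output")
IO.inspect(Nx.sin(inputs), label: "sin(x) target")
```

With only 1000 iterations of plain SGD the outputs are a rough approximation of the targets, and the exact values vary from run to run because the training data is random.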

Axon.Loop.run/4 also accepts options which control the loop’s execution. These include :iterations, which controls the number of iterations per epoch, and :epochs, which controls the number of epochs the loop runs for:

Axon.Loop.run(loop, train_data, %{}, epochs: 3, iterations: 500)
Epoch: 0, Batch: 450, loss: 0.0935063
Epoch: 1, Batch: 450, loss: 0.0576384
Epoch: 2, Batch: 450, loss: 0.0428323
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.035534460097551346, 0.2604885697364807, -0.10573504120111465, -0.16461455821990967, 0.3610309064388275, -0.10921606421470642, 0.2061888873577118, 0.3162775933742523]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.05344606190919876, -0.3463115096092224, 0.23782028257846832, 0.20592278242111206, 0.2195105254650116, 0.2618684470653534, 0.2559347450733185, 0.3006669282913208]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.03086121939122677, 0.28601887822151184, 0.02634759061038494, -0.08197703212499619]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.5404174327850342, 0.49248307943344116, 0.5927202701568604, -0.5083895921707153],
        [0.5133915543556213, -0.7197086811065674, 0.3669036030769348, -0.057483553886413574],
        [0.26609811186790466, -0.20234307646751404, -0.14102067053318024, -0.08141336590051651],
        [0.673393964767456, -0.512398362159729, 0.5106634497642517, -0.3384905159473419],
        [-0.6347945928573608, 0.7695014476776123, -0.3877493143081665, -0.5186421275138855],
        [0.45236992835998535, -0.2351287305355072, -0.4305106997489929, -0.2674770951271057],
        [0.08871842920780182, 0.46521952748298645, 0.02729635499417782, 0.691332221031189],
        [0.4584391117095947, 0.6687410473823547, 0.4068295657634735, 0.19576647877693176]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.7425869703292847]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.24965399503707886],
        [1.1746525764465332],
        [-0.12984804809093475],
        [-0.2796761095523834]
      ]
    >
  }
}
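The run above performs 3 epochs of 500 iterations each. As a plain-Elixir sketch of that arithmetic (no Axon involved, and the stream of `:batch` atoms is a stand-in for real training data), each epoch takes `iterations` entries from the stream, so the total number of training steps is the product of the two options:

```elixir
epochs = 3
iterations = 500

# Stand-in for an infinite stream of training batches.
batches = Stream.repeatedly(fn -> :batch end)

# Each epoch consumes `iterations` entries from the stream.
steps_per_epoch =
  for _epoch <- 0..(epochs - 1) do
    batches |> Stream.take(iterations) |> Enum.count()
  end

total_steps = Enum.sum(steps_per_epoch)

IO.inspect(steps_per_epoch, label: "steps per epoch")
IO.inspect(total_steps, label: "total steps")
```

With an infinite stream such as the one in this guide, :iterations is what bounds each epoch.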

You may have noticed that by default Axon.Loop.trainer/3 configures your loop to log information about training progress every 50 iterations. You can control this when constructing your supervised training loop with the :log option:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 100)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 900, loss: 0.1492715
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.09267199039459229, 0.5775123834609985, -0.07691138982772827, 0.04283804073929787, -0.015639742836356163, -0.0725373700261116, -0.10598818212747574, 0.021243896335363388]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.07886508852243423, 0.826379120349884, 0.1022031158208847, -0.5164816975593567, 0.390212744474411, 0.2709604799747467, -0.05409134551882744, -0.6204537749290466]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.09577611088752747, 0.3303026556968689, -0.25102874636650085, -0.3312375247478485]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.5508446097373962, -0.03904113546013832, 0.382876992225647, -0.6273598670959473],
        [0.13289013504981995, 0.947068452835083, -0.27359727025032043, 0.4073275923728943],
        [-0.10011858493089676, -0.32976964116096497, -0.3160743713378906, -0.3586210012435913],
        [-0.628970205783844, -0.19567319750785828, -0.07241304218769073, -0.43270331621170044],
        [-0.6155693531036377, -0.020595157518982887, -0.3254905045032501, 0.18614870309829712],
        [-0.07561944425106049, -0.34477049112319946, -0.30149057507514954, -0.6603768467903137],
        [-0.17559891939163208, -0.2768605649471283, 0.5830116868019104, 0.11386138200759888],
        [-0.6376093626022339, -0.31125709414482117, 0.2749727964401245, -0.6777774691581726]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.767456591129303]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.3530634641647339],
        [0.9497018456459045],
        [0.31334763765335083],
        [-0.624195396900177]
      ]
    >
  }
}
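The batch numbers in the log lines above follow from the log interval. The sketch below is plain Elixir with a hypothetical helper, `logged_batches`, and it assumes a log line is emitted whenever the zero-based batch index is a multiple of the interval, which is consistent with the outputs shown in this guide (only the last logged line of each epoch appears in them).

```elixir
# Hypothetical helper: zero-based batch indices at which a log line would
# appear, assuming one is emitted whenever rem(index, log_every) == 0.
logged_batches = fn iterations, log_every ->
  Enum.filter(0..(iterations - 1), fn index -> rem(index, log_every) == 0 end)
end

# Default interval of 50 with 1000 iterations.
last_default = List.last(logged_batches.(1000, 50))

# Default interval of 50 with 500 iterations per epoch.
last_per_epoch = List.last(logged_batches.(500, 50))

# Custom interval of 100 (log: 100) with 1000 iterations.
last_custom = List.last(logged_batches.(1000, 100))

IO.inspect({last_default, last_per_epoch, last_custom})
# prints {950, 450, 900}
```

These match the `Batch: 950`, `Batch: 450`, and `Batch: 900` lines in the three runs.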