Your first evaluation loop

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Creating an Axon evaluation loop

Once you have a trained model, you need to evaluate it on test data. Axon’s loop abstraction is general enough to handle both training and evaluation. Just as Axon provides the canned Axon.Loop.trainer/3 factory, it also provides a canned Axon.Loop.evaluator/1 factory.

Axon.Loop.evaluator/1 creates an evaluation loop which you can instrument with metrics to measure the performance of a trained model on test data. First, you need a trained model:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

train_loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

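data =
  Stream.repeatedly(fn ->
    # Seed a fresh PRNG key per batch, then draw 8 normally distributed inputs.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    # The regression target is sin(x).
    ys = Nx.sin(xs)
    {xs, ys}
  end)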

trained_model_state = Axon.Loop.run(train_loop, data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.1285532
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.06848274916410446, 0.037988610565662384, -0.199247345328331, 0.18008524179458618, 0.10976515710353851, -0.10479626059532166, 0.562850832939148, -0.030415315181016922]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.2839881181716919, 0.11133058369159698, -0.5213645100593567, -0.14406965672969818, 0.37532612681388855, -0.28965434432029724, -0.9048429131507874, -5.540614947676659e-4]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.2961483597755432, 0.3721822202205658, -0.1726730614900589, -0.20648165047168732]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.602420449256897, 0.46551579236984253, 0.3295630216598511, 0.484800785779953],
        [0.05755739286541939, -0.2412092238664627, 0.27874955534935, 0.13457047939300537],
        [-0.26997247338294983, -0.4479314386844635, 0.4976465106010437, -0.05715075880289078],
        [-0.7245721220970154, 0.1187945082783699, 0.14330074191093445, 0.3257679343223572],
        [-0.032964885234832764, -0.625235915184021, -0.05669135972857475, -0.7016372680664062],
        [-0.08433973789215088, -0.07334757596254349, 0.08273869007825851, 0.46893611550331116],
        [0.4123252332210541, 0.9876810312271118, -0.3525731563568115, 0.030163511633872986],
        [0.6962482333183289, 0.5394620299339294, 0.6907036304473877, -0.5448697209358215]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.7519291043281555]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.7839917540550232],
        [-0.8586246967315674],
        [0.8599083423614502],
        [0.29766184091567993]
      ]
    >
  }
}
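
As a quick spot check before building a full evaluation loop, the returned state can be passed to Axon.predict/3 together with the model. The sketch below assumes the `model` and `trained_model_state` bindings from the cells above; the input values are arbitrary and chosen only for illustration.

```elixir
# Assumes `model` and `trained_model_state` from the cells above.
# Three hypothetical inputs, shaped {batch, 1} like the training data.
xs = Nx.tensor([[0.0], [0.5], [1.0]])

# Run a forward pass with the trained parameters.
preds = Axon.predict(model, trained_model_state, xs)

# Absolute error against the target function used to generate the data.
preds
|> Nx.subtract(Nx.sin(xs))
|> Nx.abs()
```

This only inspects a handful of points; an evaluation loop aggregates the same kind of error over many batches.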

Running a loop created with Axon.Loop.trainer/3 returns a trained model state which you can use to evaluate your model. To construct an evaluation loop, call Axon.Loop.evaluator/1 with your model; the trained state is supplied later, when you run the loop:

test_loop = Axon.Loop.evaluator(model)
#Axon.Loop<
  metrics: %{},
  handlers: %{
    completed: [],
    epoch_completed: [],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Next, you’ll need to instrument your test loop with the metrics you’d like to aggregate:

test_loop = test_loop |> Axon.Loop.metric(:mean_absolute_error)
#Axon.Loop<
  metrics: %{
    "mean_absolute_error" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     :mean_absolute_error}
  },
  handlers: %{
    completed: [],
    epoch_completed: [],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
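
The metrics entry above pairs :mean_absolute_error with Axon.Metrics.running_average/1, so the loop reports a running average of the per-batch mean absolute error. The following is a minimal sketch of that arithmetic in plain Nx, not Axon's actual implementation; `batch_mae`, `batches`, and `running_mae` are hypothetical names, and the tensors are made-up values.

```elixir
# Assumes Nx is available (it is pulled in as a dependency of Axon).

# Mean absolute error for one batch: mean(|y_true - y_pred|).
batch_mae = fn y_true, y_pred ->
  y_true |> Nx.subtract(y_pred) |> Nx.abs() |> Nx.mean()
end

# Two illustrative {y_true, y_pred} batches.
batches = [
  # absolute errors 0.5 and 0.0, batch MAE 0.25
  {Nx.tensor([[0.0], [1.0]]), Nx.tensor([[0.5], [1.0]])},
  # absolute errors 1.0 and 0.5, batch MAE 0.75
  {Nx.tensor([[2.0], [3.0]]), Nx.tensor([[1.0], [3.5]])}
]

# Running average over batches: avg_new = (avg * i + obs) / (i + 1).
running_mae =
  batches
  |> Enum.with_index()
  |> Enum.reduce(Nx.tensor(0.0), fn {{y_true, y_pred}, i}, avg ->
    obs = batch_mae.(y_true, y_pred)
    avg |> Nx.multiply(i) |> Nx.add(obs) |> Nx.divide(i + 1)
  end)

Nx.to_number(running_mae)
# → 0.5
```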

Finally, you can run your loop on test data. Because you want to test your trained model, you need to pass the trained model state as the loop’s initial state. This example reuses the data stream from training, which generates batches on demand:

Axon.Loop.run(test_loop, data, trained_model_state, iterations: 1000)
Batch: 999, mean_absolute_error: 0.0856894
%{
  0 => %{
    "mean_absolute_error" => #Nx.Tensor<
      f32
      0.08568935841321945
    >
  }
}
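
The result is a map keyed by epoch, with each entry holding the aggregated metrics by name. As a small sketch of reading a value back out as a plain number, the snippet below rebuilds a map with the same shape as the output above; in the notebook you would instead bind the return value of Axon.Loop.run/4 to a variable such as the hypothetical `results`.

```elixir
# Stand-in for the value returned by Axon.Loop.run/4, copied from the output above.
results = %{
  0 => %{"mean_absolute_error" => Nx.tensor(0.08568935841321945)}
}

# Look up epoch 0, then the metric by name, and convert the scalar tensor to a float.
mae =
  results
  |> get_in([0, "mean_absolute_error"])
  |> Nx.to_number()
```

Converting to a number is convenient when you want to log the metric or compare it against a threshold.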