
Ch 6: Go Deep With Axon

ch6_go_deep_with_axon.livemd

Mix.install([
  {:nx, "~> 0.7.2"},
  {:axon, "~> 0.6.1"},
  {:exla, "~> 0.7.2"},
  {:scidata, "~> 0.1.11"},
  {:kino, "~> 0.12.3"},
  {:table_rex, "~> 3.1.1"}
])

Understanding the need for Deep Learning

Domains that are rich in information, with seemingly complex, unstructured relationships, are often the ones that pose the greatest challenge. Images, text, and audio are prime examples. Researchers would spend hours or days writing code to extract features from these domains by hand, and the process was still imperfect: sensitive to distribution changes and lossy in nature.

A better way to extract meaningful representations of your data is to delegate that task to your ML algorithm. The ability to extract representations from high-dimensional inputs, and to learn to make predictions from those representations, is the draw of DL.

The Curse of Dimensionality

The complexity of an ML problem increases significantly as the dimensionality of the inputs increases. The curse of dimensionality arises when the dimensionality of the input space becomes so high that the quality of the model diminishes.
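
To get a feel for the curse, consider covering a unit hypercube at a fixed resolution of ten bins per axis: the number of cells, and therefore the amount of data needed to populate the input space, grows exponentially with the dimension. A quick sketch (not from the book):

# covering the unit hypercube at 10 bins per axis takes 10^d cells,
# so the data required to "fill" the input space explodes with dimension d
for d <- [1, 2, 3, 10] do
  cells = round(:math.pow(10, d))
  IO.puts("#{d} dimension(s) -> #{cells} cells")
end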

Cascading Representation

One theory of how DL overcomes the curse of dimensionality lies in how a neural network transforms input data. Neural networks transform inputs into hierarchical representations by composing linear and nonlinear transformations. A neural network does this through its layers: the pipeline of linear and nonlinear transformations that an input passes through before producing a prediction.

The theory of why DL works so well is that deep models are able to learn successive, hierarchical representations of input data. Through the composition of simple representations, a neural network is able to learn more complex representations and, consequently, to make predictions on complex inputs.
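
The need for the nonlinear part is easy to demonstrate: composing linear transformations alone collapses into a single linear transformation, so extra depth adds expressive power only when nonlinear activations are interleaved. A minimal sketch (not from the book):

# two stacked linear maps are equivalent to one linear map with the
# product of the weight matrices; no depth is gained without a nonlinearity
w1 = Nx.tensor([[2.0, 0.0], [0.0, 3.0]])
w2 = Nx.tensor([[1.0, 1.0], [0.0, 1.0]])
x = Nx.tensor([[1.0, 2.0]])

stacked = x |> Nx.dot(w1) |> Nx.dot(w2)
collapsed = Nx.dot(x, Nx.dot(w1, w2))
Nx.equal(stacked, collapsed)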

Breaking Down a Neural Network

Terminology

DL refers to a subset of ML algorithms that make use of deep models, or artificial neural networks. These models are considered “deep” due to their increased number of computational layers in comparison to other methods.

Multi-layer perceptrons (MLPs) are a class of DL models built from fully connected (or densely connected) layers; they are a specific architecture of neural network.

Neural Network Anatomy

In a neural network there are generally three classes of layers:

  • Input layers
    • Mostly placeholders for model inputs
  • Hidden layers
    • Commonly densely connected (also called fully connected, or simply dense) layers
    • Often followed by activation functions that apply a nonlinear function to the output, signaling the presence of certain input features
      • Common activation functions: ReLU, sigmoid and softmax
  • Output layers
    • Similar to hidden layers, but may apply a final transformation that signals an outcome suited to the domain of the problem (a probability, a predicted class, etc.)

Create a Simple Neural Network

defmodule NN do
  import Nx.Defn

  @doc """
  Dense layer that performs matrix multiplication
  of two dense matrices followed by a bias add
  """
  defn dense(input, weight, bias) do
    input
    |> Nx.dot(weight)
    |> Nx.add(bias)
  end

  @doc """
  Sigmoid activation
  """
  defn activation(input) do
    Nx.sigmoid(input)
  end

  @doc """
  Hidden layer that feeds parameters through
  dense layer and activation function
  """
  defn hidden(input, weight, bias) do
    input
    |> dense(weight, bias)
    |> activation()
  end

  @doc """
  Another dense layer followed by activation function
  """
  defn output(input, weight, bias) do
    input
    |> dense(weight, bias)
    |> activation()
  end

  @doc """
  Implements simple neural network

  Takes input, trainable parameters and produces an output
  """
  defn predict(input, w1, b1, w2, b2) do
    input
    |> hidden(w1, b1)
    |> output(w2, b2)
  end
end
{:module, NN, <<70, 79, 82, 49, 0, 0, 19, ...>>, true}
key = Nx.Random.key(42)
{w1, new_key} = Nx.Random.uniform(key)
{b1, new_key} = Nx.Random.uniform(new_key)
{w2, new_key} = Nx.Random.uniform(new_key)
{b2, new_key} = Nx.Random.uniform(new_key)
{#Nx.Tensor<
   f32
   0.6716941595077515
 >,
 #Nx.Tensor<
   u32[2]
   [4249898905, 2425127087]
 >}
Nx.Random.uniform_split(new_key, 0.0, 1.0, shape: {})
|> NN.predict(w1, b1, w2, b2)
#Nx.Tensor<
  f32
  0.6633665561676025
>
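
For readability, the parameters above are all scalars, so the "network" collapses to a chain of scalar operations. Here is a sketch (an assumption, not from the book) of the same predict/5 with properly shaped parameters: two input features, four hidden units, and one output.

key = Nx.Random.key(42)
# weight and bias shapes for a {batch, 2} input, 4 hidden units, 1 output
{w1, key} = Nx.Random.uniform(key, shape: {2, 4})
{b1, key} = Nx.Random.uniform(key, shape: {4})
{w2, key} = Nx.Random.uniform(key, shape: {4, 1})
{b2, key} = Nx.Random.uniform(key, shape: {1})
{input, _key} = Nx.Random.uniform(key, shape: {1, 2})
NN.predict(input, w1, b1, w2, b2)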

Neural Networks with Axon

Nx.default_backend(EXLA.Backend)
{Nx.BinaryBackend, []}

Working with Data

{images, labels} = Scidata.MNIST.download()
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}},
 {<<5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8,
    6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, ...>>, {:u, 8}, {60000}}}
{image_data, image_type, image_shape} = images
{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}}
{label_data, label_type, label_shape} = labels
{<<5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8,
   6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, 5, 9, ...>>, {:u, 8}, {60000}}

Data pre-processing

images =
  image_data
  |> Nx.from_binary(image_type)
  # rescale pixel values between 0 and 1
  |> Nx.divide(255)
  # reshape to vector representations
  |> Nx.reshape({60_000, :auto})
#Nx.Tensor<
  f32[60000][784]
  EXLA.Backend
  [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
    ...
  ]
>
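
A quick sanity check (not part of the original notebook): after dividing by 255, every pixel value should fall in [0, 1].

{Nx.reduce_min(images), Nx.reduce_max(images)}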
# one-hot encode labels by broadcasting and comparing each label to a vector of
# the numbers 0 through 9
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
#Nx.Tensor<
  u8[60000][10]
  EXLA.Backend
  [
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    ...
  ]
>
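
Another optional check (a sketch): taking the argmax along the last axis inverts the one-hot encoding and recovers the original integer labels.

# the first five labels should match the 5, 0, 4, 1, 9 seen in the raw label data
labels[0..4] |> Nx.argmax(axis: -1)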

Split data into training and test sets

train_range = 0..49_999//1
test_range = 50_000..-1//1
50000..-1//1
train_images = images[train_range]
train_labels = labels[train_range]
#Nx.Tensor<
  u8[50000][10]
  EXLA.Backend
  [
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    ...
  ]
>
test_images = images[test_range]
test_labels = labels[test_range]
#Nx.Tensor<
  u8[10000][10]
  EXLA.Backend
  [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    ...
  ]
>

Create mini batches

batch_size = 64
64
# Create stream of batches with training {inputs, targets}
train_data =
  Stream.zip(
    Nx.to_batched(train_images, batch_size),
    Nx.to_batched(train_labels, batch_size)
  )
#Function<73.53678557/2 in Stream.zip_with/2>
# Create stream of batches with testing {inputs, targets}
test_data =
  Stream.zip(
    Nx.to_batched(test_images, batch_size),
    Nx.to_batched(test_labels, batch_size)
  )
#Function<73.53678557/2 in Stream.zip_with/2>
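
A small check on the batching (this assumes the default :leftover behavior of Nx.to_batched/2, which wraps to fill the final partial batch): 50,000 training examples at a batch size of 64 should yield ceil(50_000 / 64) = 782 batches per epoch.

Enum.count(train_data)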

Building the Model

model =
  Axon.input("images", shape: {nil, 784})
  # 128 is arbitrary; typically, more hidden units == increased representation
  # capacity, but if the model is too wide or too deep, it can struggle to learn
  |> Axon.dense(128, activation: :relu)
  # 10 output units, one for each digit class an image can represent (0, 1, ..., 9)
  |> Axon.dense(10, activation: :softmax)
#Axon<
  inputs: %{"images" => {nil, 784}}
  outputs: "softmax_0"
  nodes: 5
>
template = Nx.template({1, 784}, :f32)
Axon.Display.as_graph(model, template)
graph TD;
9[/"images (:input) {1, 784}"/];
10["dense_0 (:dense) {1, 128}"];
11["relu_0 (:relu) {1, 128}"];
12["dense_1 (:dense) {1, 10}"];
13["softmax_0 (:softmax) {1, 10}"];
12 --> 13;
11 --> 12;
10 --> 11;
9 --> 10;
Axon.Display.as_table(model, template) |> IO.puts()
+-----------------------------------------------------------------------------------------------------------+
|                                                   Model                                                   |
+==================================+=============+==============+===================+=======================+
| Layer                            | Input Shape | Output Shape | Options           | Parameters            |
+==================================+=============+==============+===================+=======================+
| images ( input )                 | []          | {1, 784}     | shape: {nil, 784} |                       |
|                                  |             |              | optional: false   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["images"] )      | [{1, 784}]  | {1, 128}     |                   | kernel: f32[784][128] |
|                                  |             |              |                   | bias: f32[128]        |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| relu_0 ( relu["dense_0"] )       | [{1, 128}]  | {1, 128}     |                   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["relu_0"] )      | [{1, 128}]  | {1, 10}      |                   | kernel: f32[128][10]  |
|                                  |             |              |                   | bias: f32[10]         |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_1"] ) | [{1, 10}]   | {1, 10}      |                   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
Total Parameters: 101770
Total Parameters Memory: 407080 bytes
:ok
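
The parameter count in the table is easy to verify by hand: dense_0 contributes 784 × 128 weights plus 128 biases, and dense_1 contributes 128 × 10 weights plus 10 biases.

# 100_352 + 128 + 1_280 + 10 = 101_770, matching the table above
784 * 128 + 128 + 128 * 10 + 10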

Training

# Axon models are stateless representations of a neural network; they do not
# carry parameters or state internally, so Axon.Loop.run/4 returns the trained
# model state for us to capture here
trained_model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, %{}, epochs: 100, compiler: EXLA)

22:40:15.757 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 750, accuracy: 0.7507078 loss: 0.9997960
Epoch: 1, Batch: 768, accuracy: 0.8799579 loss: 0.7208588
Epoch: 2, Batch: 736, accuracy: 0.8964552 loss: 0.6089117
Epoch: 3, Batch: 754, accuracy: 0.9060845 loss: 0.5389771
Epoch: 4, Batch: 772, accuracy: 0.9123141 loss: 0.4924677
Epoch: 5, Batch: 740, accuracy: 0.9177209 loss: 0.4604695
Epoch: 6, Batch: 758, accuracy: 0.9222456 loss: 0.4335576
Epoch: 7, Batch: 776, accuracy: 0.9266409 loss: 0.4120093
Epoch: 8, Batch: 744, accuracy: 0.9297818 loss: 0.3949130
Epoch: 9, Batch: 762, accuracy: 0.9329948 loss: 0.3791651
Epoch: 10, Batch: 780, accuracy: 0.9352994 loss: 0.3656682
Epoch: 11, Batch: 748, accuracy: 0.9378338 loss: 0.3542692
Epoch: 12, Batch: 766, accuracy: 0.9400872 loss: 0.3433883
Epoch: 13, Batch: 734, accuracy: 0.9421769 loss: 0.3341691
Epoch: 14, Batch: 752, accuracy: 0.9442438 loss: 0.3251619
Epoch: 15, Batch: 770, accuracy: 0.9454240 loss: 0.3168933
Epoch: 16, Batch: 738, accuracy: 0.9471202 loss: 0.3097130
Epoch: 17, Batch: 756, accuracy: 0.9485428 loss: 0.3025968
Epoch: 18, Batch: 774, accuracy: 0.9496573 loss: 0.2959779
Epoch: 19, Batch: 742, accuracy: 0.9513375 loss: 0.2901465
Epoch: 20, Batch: 760, accuracy: 0.9530224 loss: 0.2842627
Epoch: 21, Batch: 778, accuracy: 0.9542683 loss: 0.2787842
Epoch: 22, Batch: 746, accuracy: 0.9555514 loss: 0.2738801
Epoch: 23, Batch: 764, accuracy: 0.9568627 loss: 0.2688854
Epoch: 24, Batch: 732, accuracy: 0.9582196 loss: 0.2644943
Epoch: 25, Batch: 750, accuracy: 0.9591586 loss: 0.2600115
Epoch: 26, Batch: 768, accuracy: 0.9599927 loss: 0.2557245
Epoch: 27, Batch: 736, accuracy: 0.9609270 loss: 0.2518716
Epoch: 28, Batch: 754, accuracy: 0.9616722 loss: 0.2479507
Epoch: 29, Batch: 772, accuracy: 0.9626052 loss: 0.2441924
Epoch: 30, Batch: 740, accuracy: 0.9632886 loss: 0.2408085
Epoch: 31, Batch: 758, accuracy: 0.9640151 loss: 0.2373267
Epoch: 32, Batch: 776, accuracy: 0.9645270 loss: 0.2340053
Epoch: 33, Batch: 744, accuracy: 0.9654572 loss: 0.2309759
Epoch: 34, Batch: 762, accuracy: 0.9662721 loss: 0.2278500
Epoch: 35, Batch: 780, accuracy: 0.9665493 loss: 0.2248756
Epoch: 36, Batch: 748, accuracy: 0.9675192 loss: 0.2221416
Epoch: 37, Batch: 766, accuracy: 0.9678129 loss: 0.2193327
Epoch: 38, Batch: 734, accuracy: 0.9685587 loss: 0.2167726
Epoch: 39, Batch: 752, accuracy: 0.9688330 loss: 0.2141434
Epoch: 40, Batch: 770, accuracy: 0.9693985 loss: 0.2115858
Epoch: 41, Batch: 738, accuracy: 0.9700186 loss: 0.2092509
Epoch: 42, Batch: 756, accuracy: 0.9707934 loss: 0.2068441
Epoch: 43, Batch: 774, accuracy: 0.9712097 loss: 0.2045066
Epoch: 44, Batch: 742, accuracy: 0.9717993 loss: 0.2023690
Epoch: 45, Batch: 760, accuracy: 0.9725690 loss: 0.2001456
Epoch: 46, Batch: 778, accuracy: 0.9727816 loss: 0.1979993
Epoch: 47, Batch: 746, accuracy: 0.9734145 loss: 0.1960264
Epoch: 48, Batch: 764, accuracy: 0.9741013 loss: 0.1939643
Epoch: 49, Batch: 732, accuracy: 0.9745694 loss: 0.1920958
Epoch: 50, Batch: 750, accuracy: 0.9745964 loss: 0.1901533
Epoch: 51, Batch: 768, accuracy: 0.9749472 loss: 0.1882502
Epoch: 52, Batch: 736, accuracy: 0.9754919 loss: 0.1865019
Epoch: 53, Batch: 754, accuracy: 0.9756415 loss: 0.1846913
Epoch: 54, Batch: 772, accuracy: 0.9759258 loss: 0.1829194
Epoch: 55, Batch: 740, accuracy: 0.9765308 loss: 0.1812944
Epoch: 56, Batch: 758, accuracy: 0.9768610 loss: 0.1795969
Epoch: 57, Batch: 776, accuracy: 0.9771557 loss: 0.1779473
Epoch: 58, Batch: 744, accuracy: 0.9777055 loss: 0.1764243
Epoch: 59, Batch: 762, accuracy: 0.9780676 loss: 0.1748292
Epoch: 60, Batch: 780, accuracy: 0.9780930 loss: 0.1732857
Epoch: 61, Batch: 748, accuracy: 0.9786382 loss: 0.1718533
Epoch: 62, Batch: 766, accuracy: 0.9790580 loss: 0.1703606
Epoch: 63, Batch: 734, accuracy: 0.9794643 loss: 0.1689844
Epoch: 64, Batch: 752, accuracy: 0.9795194 loss: 0.1675558
Epoch: 65, Batch: 770, accuracy: 0.9798355 loss: 0.1661485
Epoch: 66, Batch: 738, accuracy: 0.9801252 loss: 0.1648517
Epoch: 67, Batch: 756, accuracy: 0.9805565 loss: 0.1635002
Epoch: 68, Batch: 774, accuracy: 0.9807661 loss: 0.1621738
Epoch: 69, Batch: 742, accuracy: 0.9811155 loss: 0.1609511
Epoch: 70, Batch: 760, accuracy: 0.9815005 loss: 0.1596673
Epoch: 71, Batch: 778, accuracy: 0.9816672 loss: 0.1584153
Epoch: 72, Batch: 746, accuracy: 0.9819277 loss: 0.1572576
Epoch: 73, Batch: 764, accuracy: 0.9822713 loss: 0.1560377
Epoch: 74, Batch: 732, accuracy: 0.9825205 loss: 0.1549213
Epoch: 75, Batch: 750, accuracy: 0.9825649 loss: 0.1537535
Epoch: 76, Batch: 768, accuracy: 0.9829730 loss: 0.1526004
Epoch: 77, Batch: 736, accuracy: 0.9833574 loss: 0.1515341
Epoch: 78, Batch: 754, accuracy: 0.9834851 loss: 0.1504219
Epoch: 79, Batch: 772, accuracy: 0.9836878 loss: 0.1493252
Epoch: 80, Batch: 740, accuracy: 0.9838900 loss: 0.1483142
Epoch: 81, Batch: 758, accuracy: 0.9842309 loss: 0.1472510
Epoch: 82, Batch: 776, accuracy: 0.9844353 loss: 0.1462108
Epoch: 83, Batch: 744, accuracy: 0.9845008 loss: 0.1452462
Epoch: 84, Batch: 762, accuracy: 0.9848050 loss: 0.1442298
Epoch: 85, Batch: 780, accuracy: 0.9846950 loss: 0.1432384
Epoch: 86, Batch: 748, accuracy: 0.9849591 loss: 0.1423156
Epoch: 87, Batch: 766, accuracy: 0.9850676 loss: 0.1413477
Epoch: 88, Batch: 734, accuracy: 0.9853529 loss: 0.1404511
Epoch: 89, Batch: 752, accuracy: 0.9855163 loss: 0.1395154
Epoch: 90, Batch: 770, accuracy: 0.9856112 loss: 0.1385893
Epoch: 91, Batch: 738, accuracy: 0.9858550 loss: 0.1377320
Epoch: 92, Batch: 756, accuracy: 0.9861088 loss: 0.1368340
Epoch: 93, Batch: 774, accuracy: 0.9861693 loss: 0.1359484
Epoch: 94, Batch: 742, accuracy: 0.9863518 loss: 0.1351287
Epoch: 95, Batch: 760, accuracy: 0.9865925 loss: 0.1342648
Epoch: 96, Batch: 778, accuracy: 0.9866616 loss: 0.1334170
Epoch: 97, Batch: 746, accuracy: 0.9867805 loss: 0.1326308
Epoch: 98, Batch: 764, accuracy: 0.9870507 loss: 0.1317996
Epoch: 99, Batch: 732, accuracy: 0.9872954 loss: 0.1310349
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[128]
      EXLA.Backend
      [0.05652742460370064, 0.08177822828292847, -0.02138412743806839, -0.02390858717262745, -0.04634668678045273, -0.07943019270896912, -0.016626596450805664, -0.014321718364953995, -0.03653227165341377, 0.028186671435832977, 0.10801924765110016, 0.10395485907793045, 0.023988498374819756, 0.11345196515321732, -0.03242427483201027, -4.138844378758222e-4, 0.08783016353845596, 0.11951909959316254, 0.006448953412473202, 0.0258331336081028, 0.02526051551103592, 0.05611482262611389, 0.12466409057378769, 0.0274056326597929, 0.09361684322357178, 4.14769136114046e-4, 0.08059605956077576, 0.10137662291526794, 0.0781736969947815, -0.08695081621408463, 0.017613355070352554, 0.1364990621805191, -0.026523269712924957, 0.05165987089276314, -0.09639599174261093, -0.02722962200641632, 0.029649771749973297, 0.044920168817043304, 0.011372238397598267, 0.04830102622509003, -0.06037721037864685, -0.02536039985716343, 9.339514654129744e-4, -0.0505872406065464, -0.0211349930614233, 0.05413465201854706, -0.0653940960764885, 0.039978839457035065, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[784][128]
      EXLA.Backend
      [
        [-0.07978315651416779, -0.04933007061481476, -0.07117883116006851, 0.0046152397990226746, 0.006489121355116367, 0.06164148822426796, -0.07691133767366409, 0.07299374788999557, -0.07805027067661285, 0.0654342919588089, -0.007943245582282543, 0.005590663757175207, 0.06408333778381348, -0.012512286193668842, -0.01217877771705389, 0.0249907486140728, -0.03131278604269028, 0.008503668941557407, 0.05303604155778885, -0.07943467795848846, -0.023482846096158028, -0.060498885810375214, 0.04801912605762482, -0.06532831490039825, -0.04819829389452934, 0.041818514466285706, 0.04748879373073578, 0.07412124425172806, 0.06725671142339706, -0.037834879010915756, 0.05654172599315643, 0.031005367636680603, -0.07455082982778549, -0.057416882365942, 0.05726246535778046, -0.07380787283182144, -0.015585354529321194, 0.027122197672724724, 0.018614524975419044, 0.05705614387989044, -0.026843978092074394, -0.02302541770040989, -0.0020782870706170797, 0.02832561917603016, -0.024643316864967346, -0.032047487795352936, 0.010319806635379791, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[10]
      EXLA.Backend
      [-0.1862199306488037, 0.06070087477564812, -0.02205681800842285, -0.14590252935886383, 0.062407322227954865, 0.18080928921699524, 0.014985051937401295, 0.07054023444652557, -0.05126854404807091, 0.016002867370843887]
    >,
    "kernel" => #Nx.Tensor<
      f32[128][10]
      EXLA.Backend
      [
        [-0.17434200644493103, 0.3851516842842102, 0.8101922273635864, -0.6151928901672363, 0.13069160282611847, -0.5000512003898621, -0.4076564908027649, -0.23182162642478943, 0.40456047654151917, 0.47900182008743286],
        [0.469640851020813, -0.1989802122116089, -0.10294214636087418, -0.5463757514953613, 0.5386794209480286, -0.015668772161006927, 0.0581229142844677, 0.2237190157175064, 0.32730910181999207, -0.43035635352134705],
        [0.19022969901561737, -0.1554863601922989, -0.193527489900589, -0.092081218957901, -0.146250918507576, 0.1760367453098297, -0.14805874228477478, -0.5237136483192444, 0.4893774688243866, 0.18707652390003204],
        [-0.10430952906608582, 0.32146576046943665, 0.05426676571369171, 0.4475323557853699, -0.18521767854690552, -0.07551898062229156, 0.015840010717511177, -0.3172309398651123, 0.33430740237236023, -0.051861248910427094],
        [-0.06059794872999191, 0.31996384263038635, 0.13508550822734833, -0.3562682867050171, 0.2436303198337555, 0.2505604028701782, ...],
        ...
      ]
    >
  }
}

Evaluation

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)

22:40:58.122 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 156, accuracy: 0.9755175
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.9755175113677979
    >
  }
}

Executing The Model

{test_batch, _} = Enum.at(test_data, 0)
{#Nx.Tensor<
   f32[64][784]
   EXLA.Backend
   [
     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   u8[64][10]
   EXLA.Backend
   [
     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
     [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
     [0, 0, 0, 0, 0, 0, 1, 0, ...],
     ...
   ]
 >}
test_image = test_batch[0]
#Nx.Tensor<
  f32[784]
  EXLA.Backend
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
>
test_image
|> Nx.reshape({28, 28})
|> Nx.to_heatmap()
#Nx.Heatmap<
  f32[28][28]
  (heatmap output elided: the ANSI-colored terminal rendering of the digit
   does not survive the text export)
>
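
Since :kino is in the install list, an alternative (a sketch, assuming Kino.Image.new/1 accepts a u8 tensor of shape {height, width, channels}) is to render the digit as an image widget instead of a terminal heatmap:

test_image
|> Nx.multiply(255)
|> Nx.as_type(:u8)
|> Nx.reshape({28, 28, 1})
|> Kino.Image.new()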
{_, predict_fn} = Axon.build(model, compiler: EXLA)
{#Function<134.37423472/2 in Nx.Defn.Compiler.fun/2>,
 #Function<134.37423472/2 in Nx.Defn.Compiler.fun/2>}
# the model outputs class probabilities; Nx.argmax/1 picks the most likely class
predicted_class =
  test_image
  |> Nx.new_axis(0)
  |> then(&predict_fn.(trained_model_state, &1))
  |> Nx.argmax()
#Nx.Tensor<
  s64
  EXLA.Backend
  3
>
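
As a final cross-check (a sketch, not from the book), recover the true label for the same image; the first one-hot row of the test batch above encodes the digit 3, matching the prediction.

{_, label_batch} = Enum.at(test_data, 0)
Nx.argmax(label_batch[0])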