Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Getting Deep with Axon

GettingDeepWithAxon.livemd

Getting Deep with Axon

Mix.install([
  {:axon, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:scidata, "~> 0.1"},
  {:kino, "~> 0.8"},
  {:table_rex, "~> 3.1.1"}
])

# Use EXLA as the default backend so tensors created below are allocated on
# EXLA.Backend (visible in the inspected tensor output).
Nx.default_backend(EXLA.Backend)

Using Nx to Create a Simple Neural Network

defmodule NeuralNetwork do
  @moduledoc """
  A hand-written two-layer neural network built from numerical definitions.

  Every layer is a dense transformation followed by a sigmoid activation.
  """

  import Nx.Defn

  @doc """
  Dense layer: the dot product of `input` and `weight`, plus `bias`.
  """
  defn dense(input, weight, bias) do
    Nx.add(Nx.dot(input, weight), bias)
  end

  @doc """
  Applies the sigmoid activation to `input`.
  """
  defn activation(input) do
    Nx.sigmoid(input)
  end

  @doc """
  Hidden layer: a dense transformation followed by the activation.
  """
  defn hidden(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  @doc """
  Output layer: a dense transformation followed by the activation.
  """
  defn output(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  @doc """
  Runs `input` through the hidden layer (`w1`, `b1`) and then through the
  output layer (`w2`, `b2`).
  """
  defn predict(input, w1, b1, w2, b2) do
    hidden_result = hidden(input, w1, b1)
    output(hidden_result, w2, b2)
  end
end
{:module, NeuralNetwork, <<70, 79, 82, 49, 0, 0, 18, ...>>, true}
# Seed the pseudo-random number generator. Each call to Nx.Random.uniform/1
# returns the sampled value together with a fresh key, which is passed on to
# the next call.
key = Nx.Random.key(42)
{w1, key_after_w1} = Nx.Random.uniform(key)
{b1, key_after_b1} = Nx.Random.uniform(key_after_w1)
{w2, key_after_w2} = Nx.Random.uniform(key_after_b1)
{b2, new_key} = Nx.Random.uniform(key_after_w2)
{#Nx.Tensor<
   f32
   EXLA.Backend
   0.6716941595077515
 >,
 #Nx.Tensor<
   u32[2]
   EXLA.Backend
   [4249898905, 2425127087]
 >}
# Draw one more random value to act as the network input.
{input, _new_key} = Nx.Random.uniform(new_key)

# Run the input through the hand-written network with the random parameters.
NeuralNetwork.predict(input, w1, b1, w2, b2)
#Nx.Tensor<
  f32
  EXLA.Backend
  0.6635995507240295
>

Working with the Data

# Download MNIST. Both elements are {binary, type, shape} triples, as the
# inspected result below shows.
{images, labels} = Scidata.MNIST.download()
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}},
 {<<5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8,
    6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, ...>>, {:u, 8}, {60000}}}
# Each Scidata result is a {binary, type, shape} triple.
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels

# Read the sample count from the reported shape instead of hard-coding 60000,
# so the flattening below keeps working if the dataset size differs.
image_count = elem(image_shape, 0)

# Decode the raw bytes, scale the pixel values from 0..255 down to 0.0..1.0
# and flatten every image into one row ({60000, 1, 28, 28} -> {60000, 784}).
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.divide(255)
  |> Nx.reshape({image_count, :auto})

# One-hot encode the labels: each label becomes a column entry that is
# compared (via broadcasting) against the row [0, 1, ..., 9], yielding a
# {60000, 10} tensor with a single 1 per row.
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
#Nx.Tensor<
  u8[60000][10]
  EXLA.Backend
  [
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    ...
  ]
>
# The first 50,000 samples are used for training; everything from that index
# to the end (-1 is the last element) is held out for testing.
split_index = 50_000

train_range = 0..(split_index - 1)//1
test_range = split_index..-1//1

train_images = images[train_range]
train_labels = labels[train_range]

test_images = images[test_range]
test_labels = labels[test_range]
#Nx.Tensor<
  u8[10000][10]
  EXLA.Backend
  [
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    ...
  ]
>
batch_size = 64

# Pair every batch of images with the matching batch of labels, lazily.
train_data =
  Stream.zip(
    Nx.to_batched(train_images, batch_size),
    Nx.to_batched(train_labels, batch_size)
  )

test_data =
  Stream.zip(
    Nx.to_batched(test_images, batch_size),
    Nx.to_batched(test_labels, batch_size)
  )
#Function<76.38948127/2 in Stream.zip_with/2>

Building the Model

# Input: batches (any batch size) of flattened 784-pixel images.
image_input = Axon.input("images", shape: {nil, 784})

# Hidden layer: 128 units with ReLU activation.
hidden_layer = Axon.dense(image_input, 128, activation: :relu)

# Output layer: 10 units (one per digit) with softmax activation.
model = Axon.dense(hidden_layer, 10, activation: :softmax)
#Axon<
  inputs: %{"images" => {nil, 784}}
  outputs: "softmax_0"
  nodes: 5
>
# Render the model as a graph; the template stands in for a single
# {1, 784} f32 input so the layer output shapes can be shown.
Axon.Display.as_graph(model, Nx.template({1, 784}, :f32))
graph TD;
1[/"images (:input) {1, 784}"/];
2["dense_0 (:dense) {1, 128}"];
3["relu_0 (:relu) {1, 128}"];
4["dense_1 (:dense) {1, 10}"];
5["softmax_0 (:softmax) {1, 10}"];
4 --> 5;
3 --> 4;
2 --> 3;
1 --> 2;
# A template describes only the shape and type of an input.
template = Nx.template({1, 784}, :f32)

# Print a table of the layers with their shapes and parameters.
model
|> Axon.Display.as_table(template)
|> IO.puts()
warning: using map.field notation (without parentheses) to invoke function TableRex.Renderer.Text.default_options() is deprecated, you must add parentheses instead: remote.function()
  (table_rex 3.1.1) lib/table_rex/table.ex:274: TableRex.Table.render/2
  (table_rex 3.1.1) lib/table_rex/table.ex:292: TableRex.Table.render!/2
  (axon 0.6.1) lib/axon/display.ex:51: Axon.Display.as_table/2
  (elixir 1.17.2) src/elixir.erl:386: :elixir.eval_external_handler/3
  (stdlib 6.0) erl_eval.erl:904: :erl_eval.do_apply/7

+-----------------------------------------------------------------------------------------------------------+
|                                                   Model                                                   |
+==================================+=============+==============+===================+=======================+
| Layer                            | Input Shape | Output Shape | Options           | Parameters            |
+==================================+=============+==============+===================+=======================+
| images ( input )                 | []          | {1, 784}     | shape: {nil, 784} |                       |
|                                  |             |              | optional: false   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["images"] )      | [{1, 784}]  | {1, 128}     |                   | kernel: f32[784][128] |
|                                  |             |              |                   | bias: f32[128]        |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| relu_0 ( relu["dense_0"] )       | [{1, 128}]  | {1, 128}     |                   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["relu_0"] )      | [{1, 128}]  | {1, 10}      |                   | kernel: f32[128][10]  |
|                                  |             |              |                   | bias: f32[10]         |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_1"] ) | [{1, 10}]   | {1, 10}      |                   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
Total Parameters: 101770
Total Parameters Memory: 407080 bytes
:ok
# With `structs: false` the model is printed as plain maps, exposing the
# underlying node graph instead of the compact #Axon<...> form.
IO.inspect(model, structs: false)
%{
  output: 5,
  nodes: %{
    1 => %{
      args: [],
      id: 1,
      name: #Function<69.29713129/2 in Axon.name/2>,
      parent: [],
      mode: :both,
      opts: [shape: {nil, 784}, optional: false],
      op: :input,
      stacktrace: [
        {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
        {:elixir, :eval_external_handler, 3,
         [file: ~c"src/elixir.erl", line: 386]},
        {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
        {:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]},
        {:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]}
      ],
      __struct__: Axon.Node,
      hooks: [],
      parameters: [],
      op_name: :input,
      policy: %{
        output: {:f, 32},
        params: {:f, 32},
        __struct__: Axon.MixedPrecision.Policy,
        compute: {:f, 32}
      }
    },
    2 => %{
      args: [:layer, :parameter, :parameter],
      id: 2,
      name: #Function<68.29713129/2 in Axon.name/2>,
      parent: [1],
      mode: :both,
      opts: [],
      op: :dense,
      stacktrace: [
        {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
        {Axon, :dense, 3, [file: ~c"lib/axon.ex", line: 816]},
        {:elixir, :eval_external_handler, 3,
         [file: ~c"src/elixir.erl", line: 386]},
        {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
        {:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]}
      ],
      __struct__: Axon.Node,
      hooks: [],
      parameters: [
        %{
          name: "kernel",
          type: {:f, 32},
          __struct__: Axon.Parameter,
          children: nil,
          shape: #Function<28.29713129/1 in Axon.dense/3>,
          initializer: #Function<3.47872556/3 in Axon.Initializers.glorot_uniform/1>,
          frozen: false
        },
        %{
          name: "bias",
          type: {:f, 32},
          __struct__: Axon.Parameter,
          children: nil,
          shape: #Function<29.29713129/1 in Axon.dense/3>,
          initializer: #Function<23.47872556/2 in Axon.Initializers.zeros/0>,
          frozen: false
        }
      ],
      op_name: :dense,
      policy: %{
        output: {:f, 32},
        params: {:f, 32},
        __struct__: Axon.MixedPrecision.Policy,
        compute: {:f, 32}
      }
    },
    3 => %{
      args: [:layer],
      id: 3,
      name: #Function<68.29713129/2 in Axon.name/2>,
      parent: [2],
      mode: :both,
      opts: [],
      op: :relu,
      stacktrace: [
        {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
        {:elixir, :eval_external_handler, 3,
         [file: ~c"src/elixir.erl", line: 386]},
        {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
        {:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]}
      ],
      __struct__: Axon.Node,
      hooks: [],
      parameters: [],
      op_name: :relu,
      policy: %{
        output: {:f, 32},
        params: {:f, 32},
        __struct__: Axon.MixedPrecision.Policy,
        compute: {:f, 32}
      }
    },
    4 => %{
      args: [:layer, :parameter, :parameter],
      id: 4,
      name: #Function<68.29713129/2 in Axon.name/2>,
      parent: [3],
      mode: :both,
      opts: [],
      op: :dense,
      stacktrace: [
        {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
        {Axon, :dense, 3, [file: ~c"lib/axon.ex", line: 816]},
        {:elixir, :eval_external_handler, 3,
         [file: ~c"src/elixir.erl", line: 386]},
        {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]},
        {:elixir, :eval_forms, 4, [file: ~c"src/elixir.erl", line: 364]}
      ],
      __struct__: Axon.Node,
      hooks: [],
      parameters: [
        %{
          name: "kernel",
          type: {:f, 32},
          __struct__: Axon.Parameter,
          children: nil,
          shape: #Function<28.29713129/1 in Axon.dense/3>,
          initializer: #Function<3.47872556/3 in Axon.Initializers.glorot_uniform/1>,
          frozen: false
        },
        %{
          name: "bias",
          type: {:f, 32},
          __struct__: Axon.Parameter,
          children: nil,
          shape: #Function<29.29713129/1 in Axon.dense/3>,
          initializer: #Function<23.47872556/2 in Axon.Initializers.zeros/0>,
          frozen: false
        }
      ],
      op_name: :dense,
      policy: %{
        output: {:f, 32},
        params: {:f, 32},
        __struct__: Axon.MixedPrecision.Policy,
        compute: {:f, 32}
      }
    },
    5 => %{
      args: [:layer],
      id: 5,
      name: #Function<68.29713129/2 in Axon.name/2>,
      parent: [4],
      mode: :both,
      opts: [],
      op: :softmax,
      stacktrace: [
        {Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
        {:elixir, :eval_external_handler, 3,
         [file: ~c"src/elixir.erl", line: 386]},
        {:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
        {:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]},
        {:elixir, :eval_forms, 4, [file: ~c"src/elixir.erl", line: 364]},
        {Module.ParallelChecker, :verify, 1,
         [file: ~c"lib/module/parallel_checker.ex", line: 112]}
      ],
      __struct__: Axon.Node,
      hooks: [],
      parameters: [],
      op_name: :softmax,
      policy: %{
        output: {:f, 32},
        params: {:f, 32},
        __struct__: Axon.MixedPrecision.Policy,
        compute: {:f, 32}
      }
    }
  },
  __struct__: Axon
}
#Axon<
  inputs: %{"images" => {nil, 784}}
  outputs: "softmax_0"
  nodes: 5
>

Training the Model

# Supervised training loop: categorical cross-entropy loss, SGD optimizer,
# with accuracy reported alongside the loss.
training_loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)

# Run for 10 epochs, starting from an empty map as the initial state and
# compiling with EXLA.
trained_model_state =
  Axon.Loop.run(training_loop, train_data, %{}, epochs: 10, compiler: EXLA)

16:43:31.979 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 750, accuracy: 0.7699941 loss: 0.9710366
Epoch: 1, Batch: 768, accuracy: 0.8787183 loss: 0.7062089
Epoch: 2, Batch: 736, accuracy: 0.8953105 loss: 0.6002396
Epoch: 3, Batch: 754, accuracy: 0.9048012 loss: 0.5338430
Epoch: 4, Batch: 772, accuracy: 0.9116874 loss: 0.4894916
Epoch: 5, Batch: 740, accuracy: 0.9164769 loss: 0.4588849
Epoch: 6, Batch: 758, accuracy: 0.9210311 loss: 0.4331421
Epoch: 7, Batch: 776, accuracy: 0.9243082 loss: 0.4124633
Epoch: 8, Batch: 744, accuracy: 0.9279363 loss: 0.3959939
Epoch: 9, Batch: 762, accuracy: 0.9306808 loss: 0.3807933
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[128]
      EXLA.Backend
      [-0.04557688534259796, 0.056884266436100006, -0.0010324494214728475, 0.07211018353700638, -0.023117363452911377, -0.010195941664278507, -0.09574069082736969, -0.014177899807691574, 0.040640439838171005, 4.5957943075336516e-4, 0.0774596557021141, 0.01725124940276146, -0.01311414036899805, -0.0601215660572052, 0.055294282734394073, -0.03480232506990433, -0.028218824416399002, -8.242334588430822e-4, -0.04341563209891319, 0.0864398404955864, 0.04911106452345848, 0.046631887555122375, -0.03658245503902435, -0.01592755876481533, -0.0050981719978153706, 0.014501607045531273, 0.013695931993424892, 0.02535100094974041, 0.0402829684317112, 0.0329509861767292, 0.0674162432551384, 0.10294120758771896, 0.03687436133623123, 0.09136985242366791, -0.050792284309864044, -0.016392890363931656, 0.040335074067115784, 0.018260926008224487, 0.019201647490262985, 0.002118457341566682, -0.012751685455441475, 0.05895159766077995, -0.01829197071492672, -0.01859128475189209, 0.028066368773579597, 0.0074777971021831036, -0.037080198526382446, -0.010457291267812252, ...]
    >,
    "kernel" => #Nx.Tensor<
      f32[784][128]
      EXLA.Backend
      [
        [0.016736775636672974, -0.021791711449623108, -0.018893131986260414, 0.055778056383132935, 0.037308219820261, 0.06802469491958618, 0.06937356293201447, 0.05562176555395126, 0.01878305897116661, -0.056212782859802246, 0.04226350784301758, -0.006629091687500477, -0.01020803116261959, -0.0809481143951416, 0.008785022422671318, -0.03413747623562813, -0.03490311652421951, -0.07281620055437088, -0.03968954086303711, 0.05899643525481224, 0.04544154182076454, -0.006005663890391588, 0.05451308190822601, 0.06520316004753113, -0.02622944489121437, -0.0057079121470451355, 0.05863101780414581, 0.018925368785858154, 0.07377712428569794, -0.010299346409738064, -0.026377711445093155, 0.05731591582298279, -0.039092276245355606, 0.025120161473751068, 0.038714949041604996, 0.04092622548341751, 0.0184006430208683, 0.002805600641295314, 0.022019265219569206, 0.08051205426454544, 0.06950380653142929, -0.07453489303588867, 0.07161813974380493, 0.061429694294929504, -0.06079089269042015, -0.0066459933295845985, 0.06742016226053238, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[10]
      EXLA.Backend
      [-0.06760766357183456, 0.09135901182889938, 0.042458560317754745, -0.047845494002103806, -0.004332108423113823, 0.21402733027935028, -0.039574798196554184, 0.08195860683917999, -0.22631974518299103, -0.0441216416656971]
    >,
    "kernel" => #Nx.Tensor<
      f32[128][10]
      EXLA.Backend
      [
        [-0.14514467120170593, -0.21997107565402985, 0.13562583923339844, 0.09718991816043854, -0.09510929137468338, 0.031382348388433456, 0.1758842170238495, -0.3963959217071533, 0.13938012719154358, 0.09729427844285965],
        [-0.0751059427857399, -0.05492069572210312, -0.12499954551458359, -0.3047340214252472, 0.23491673171520233, 0.3278830051422119, -0.09353523701429367, 0.14774689078330994, -0.08367066830396652, 0.04616687074303627],
        [-0.22146247327327728, -0.09685598313808441, -0.2944234311580658, 0.19600969552993774, 0.17125526070594788, 0.2786778509616852, 0.15523986518383026, -0.14579492807388306, 0.17253002524375916, -0.20731164515018463],
        [0.1843968778848648, -0.15169936418533325, -0.09471436589956284, -0.13887140154838562, 0.18205128610134125, -0.06520683318376541, -0.09642955660820007, 0.3404451906681061, -0.2211306095123291, 0.13326618075370789],
        [0.18364790081977844, -0.05975278094410896, -0.1784059703350067, -0.0626249834895134, 0.0054058595560491085, -0.16089509427547455, ...],
        ...
      ]
    >
  }
}

Evaluating the Model

# Evaluation loop that reports accuracy.
evaluation_loop =
  model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

# Score the trained parameters on the held-out test batches.
Axon.Loop.run(evaluation_loop, test_data, trained_model_state, compiler: EXLA)

16:43:40.272 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 156, accuracy: 0.9359077
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.9359076619148254
    >
  }
}

Executing Models with Axon

# Pull the first {images, labels} batch from the test stream; only the
# images are kept.
[{test_batch, _}] = Enum.take(test_data, 1)
[
  {#Nx.Tensor<
     f32[64][784]
     EXLA.Backend
     [
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
       ...
     ]
   >,
   #Nx.Tensor<
     u8[64][10]
     EXLA.Backend
     [
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1, ...],
       ...
     ]
   >}
]
# Select the image at index 1 of the batch; the result is a flat
# 784-element tensor without a batch dimension.
test_image = test_batch[1]
#Nx.Tensor<
  f32[784]
  EXLA.Backend
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
>
# Restore the 28x28 layout of the flattened image and render it as a heatmap.
Nx.to_heatmap(Nx.reshape(test_image, {28, 28}))
#Nx.Heatmap<
  f32[28][28]
  
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
                              
>
# Axon.build/2 returns a pair of functions (see the output below); only the
# second one, used for prediction, is kept here.
{_, predict_fn} = Axon.build(model, compiler: EXLA)
{#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>,
 #Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>}
# Uncomment and run the line below and it will raise: the model is declared
# with input shape {nil, 784}, but test_image has shape {784} (no batch
# dimension).
# predict_fn.(trained_model_state, test_image)
# Add a leading batch dimension ({784} -> {1, 784}) so the input matches the
# model's {nil, 784} input shape, then run the prediction function with the
# trained parameters. The result holds one probability per digit.
probabilities =
  test_image
  |> Nx.new_axis(0)
  |> then(&predict_fn.(trained_model_state, &1))
#Nx.Tensor<
  f32[1][10]
  EXLA.Backend
  [
    [9.18520163395442e-6, 1.9154771871399134e-4, 0.004306578543037176, 0.009376204572618008, 0.002378550823777914, 0.0052066161297261715, 3.823340739472769e-5, 3.9136728446464986e-5, 0.9782620668411255, 1.918784691952169e-4]
  ]
>
# The index of the largest probability is the predicted digit.
Nx.argmax(probabilities)
#Nx.Tensor<
  s64
  EXLA.Backend
  8
>