Getting Deep with Axon
# Install notebook dependencies: Axon (neural networks), Nx (tensors),
# EXLA (XLA compiler backend), Scidata (MNIST download), Kino and
# table_rex (notebook display helpers).
Mix.install([
{:axon, "~> 0.5"},
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:scidata, "~> 0.1"},
{:kino, "~> 0.8"},
{:table_rex, "~> 3.1.1"}
])
# Run all tensor operations on the EXLA (XLA-compiled) backend by default.
Nx.default_backend(EXLA.Backend)
Using Nx to Create a Simple Neural Network
defmodule NeuralNetwork do
  @moduledoc """
  A minimal two-layer neural network written directly with Nx numerical
  definitions (`defn`), without Axon's layer abstractions.
  """
  import Nx.Defn

  # Affine transform: input . weight + bias.
  defn dense(input, weight, bias) do
    Nx.add(Nx.dot(input, weight), bias)
  end

  # Elementwise sigmoid non-linearity.
  defn activation(input) do
    Nx.sigmoid(input)
  end

  # Hidden layer: affine transform followed by the activation.
  defn hidden(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  # Output layer: same computation shape as the hidden layer.
  defn output(input, weight, bias) do
    activation(dense(input, weight, bias))
  end

  # Full forward pass: hidden layer, then output layer.
  defn predict(input, w1, b1, w2, b2) do
    output(hidden(input, w1, b1), w2, b2)
  end
end
{:module, NeuralNetwork, <<70, 79, 82, 49, 0, 0, 18, ...>>, true}
# Seed the PRNG, then thread the key through each draw so every
# parameter gets a different random value.
key = Nx.Random.key(42)
# NOTE(review): with no :shape option these are scalar tensors; the toy
# network only works here via broadcasting — real layers need shaped weights.
{w1, new_key} = Nx.Random.uniform(key)
{b1, new_key} = Nx.Random.uniform(new_key)
{w2, new_key} = Nx.Random.uniform(new_key)
{b2, new_key} = Nx.Random.uniform(new_key)
{#Nx.Tensor<
f32
EXLA.Backend
0.6716941595077515
>,
#Nx.Tensor<
u32[2]
EXLA.Backend
[4249898905, 2425127087]
>}
# Draw a random scalar "input" and run it through the toy network.
{input, _new_key} = Nx.Random.uniform(new_key)

input
|> NeuralNetwork.predict(w1, b1, w2, b2)
#Nx.Tensor<
f32
EXLA.Backend
0.6635995507240295
>
Working with the Data
# Download MNIST; each side is a {binary_data, type, shape} tuple.
{images, labels} = Scidata.MNIST.download()
{{<<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>, {:u, 8}, {60000, 1, 28, 28}},
{<<5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8,
6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, ...>>, {:u, 8}, {60000}}}
{image_data, image_type, image_shape} = images
{label_data, label_type, label_shape} = labels

# Take the example count from the downloaded shape ({60000, 1, 28, 28})
# instead of hard-coding 60_000, so the code survives a dataset-size change.
{example_count, _channels, _height, _width} = image_shape

# Normalize pixel bytes to [0, 1] floats and flatten each 28x28 image
# into a 784-element row vector.
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.divide(255)
  |> Nx.reshape({example_count, :auto})

# One-hot encode the labels: compare each digit against the row 0..9.
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))
#Nx.Tensor<
u8[60000][10]
EXLA.Backend
[
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
...
]
>
# Hold out the final 10,000 examples for testing; train on the first 50,000.
train_count = 50_000

train_range = 0..(train_count - 1)//1
test_range = train_count..-1//1

train_images = images[train_range]
train_labels = labels[train_range]
test_images = images[test_range]
test_labels = labels[test_range]
#Nx.Tensor<
u8[10000][10]
EXLA.Backend
[
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
...
]
>
# Mini-batch size used for both training and evaluation.
batch_size = 64

# Lazily pair each image batch with its matching label batch.
train_data =
  Stream.zip(
    Nx.to_batched(train_images, batch_size),
    Nx.to_batched(train_labels, batch_size)
  )

test_data =
  Stream.zip(
    Nx.to_batched(test_images, batch_size),
    Nx.to_batched(test_labels, batch_size)
  )
#Function<76.38948127/2 in Stream.zip_with/2>
Building the Model
# A simple MLP: 784 input features -> 128 ReLU units -> 10-way softmax.
model =
Axon.input("images", shape: {nil, 784})
|> Axon.dense(128, activation: :relu)
|> Axon.dense(10, activation: :softmax)
#Axon<
inputs: %{"images" => {nil, 784}}
outputs: "softmax_0"
nodes: 5
>
# Render the model as a Mermaid graph, traced with a single {1, 784} input.
Axon.Display.as_graph(model, Nx.template({1, 784}, :f32))
graph TD;
1[/"images (:input) {1, 784}"/];
2["dense_0 (:dense) {1, 128}"];
3["relu_0 (:relu) {1, 128}"];
4["dense_1 (:dense) {1, 10}"];
5["softmax_0 (:softmax) {1, 10}"];
4 --> 5;
3 --> 4;
2 --> 3;
1 --> 2;
# Print the model summary table (layers, shapes, parameter counts).
template = Nx.template({1, 784}, :f32)

model
|> Axon.Display.as_table(template)
|> IO.puts()
warning: using map.field notation (without parentheses) to invoke function TableRex.Renderer.Text.default_options() is deprecated, you must add parentheses instead: remote.function()
(table_rex 3.1.1) lib/table_rex/table.ex:274: TableRex.Table.render/2
(table_rex 3.1.1) lib/table_rex/table.ex:292: TableRex.Table.render!/2
(axon 0.6.1) lib/axon/display.ex:51: Axon.Display.as_table/2
(elixir 1.17.2) src/elixir.erl:386: :elixir.eval_external_handler/3
(stdlib 6.0) erl_eval.erl:904: :erl_eval.do_apply/7
+-----------------------------------------------------------------------------------------------------------+
| Model |
+==================================+=============+==============+===================+=======================+
| Layer | Input Shape | Output Shape | Options | Parameters |
+==================================+=============+==============+===================+=======================+
| images ( input ) | [] | {1, 784} | shape: {nil, 784} | |
| | | | optional: false | |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["images"] ) | [{1, 784}] | {1, 128} | | kernel: f32[784][128] |
| | | | | bias: f32[128] |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| relu_0 ( relu["dense_0"] ) | [{1, 128}] | {1, 128} | | |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["relu_0"] ) | [{1, 128}] | {1, 10} | | kernel: f32[128][10] |
| | | | | bias: f32[10] |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_1"] ) | [{1, 10}] | {1, 10} | | |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
Total Parameters: 101770
Total Parameters Memory: 407080 bytes
:ok
# Dump the raw Axon struct; structs: false exposes the underlying node maps.
IO.inspect(model, structs: false)
%{
output: 5,
nodes: %{
1 => %{
args: [],
id: 1,
name: #Function<69.29713129/2 in Axon.name/2>,
parent: [],
mode: :both,
opts: [shape: {nil, 784}, optional: false],
op: :input,
stacktrace: [
{Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
{:elixir, :eval_external_handler, 3,
[file: ~c"src/elixir.erl", line: 386]},
{:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
{:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]},
{:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]}
],
__struct__: Axon.Node,
hooks: [],
parameters: [],
op_name: :input,
policy: %{
output: {:f, 32},
params: {:f, 32},
__struct__: Axon.MixedPrecision.Policy,
compute: {:f, 32}
}
},
2 => %{
args: [:layer, :parameter, :parameter],
id: 2,
name: #Function<68.29713129/2 in Axon.name/2>,
parent: [1],
mode: :both,
opts: [],
op: :dense,
stacktrace: [
{Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
{Axon, :dense, 3, [file: ~c"lib/axon.ex", line: 816]},
{:elixir, :eval_external_handler, 3,
[file: ~c"src/elixir.erl", line: 386]},
{:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
{:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]}
],
__struct__: Axon.Node,
hooks: [],
parameters: [
%{
name: "kernel",
type: {:f, 32},
__struct__: Axon.Parameter,
children: nil,
shape: #Function<28.29713129/1 in Axon.dense/3>,
initializer: #Function<3.47872556/3 in Axon.Initializers.glorot_uniform/1>,
frozen: false
},
%{
name: "bias",
type: {:f, 32},
__struct__: Axon.Parameter,
children: nil,
shape: #Function<29.29713129/1 in Axon.dense/3>,
initializer: #Function<23.47872556/2 in Axon.Initializers.zeros/0>,
frozen: false
}
],
op_name: :dense,
policy: %{
output: {:f, 32},
params: {:f, 32},
__struct__: Axon.MixedPrecision.Policy,
compute: {:f, 32}
}
},
3 => %{
args: [:layer],
id: 3,
name: #Function<68.29713129/2 in Axon.name/2>,
parent: [2],
mode: :both,
opts: [],
op: :relu,
stacktrace: [
{Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
{:elixir, :eval_external_handler, 3,
[file: ~c"src/elixir.erl", line: 386]},
{:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
{:erl_eval, :expr_list, 7, [file: ~c"erl_eval.erl", line: 1192]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 610]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]}
],
__struct__: Axon.Node,
hooks: [],
parameters: [],
op_name: :relu,
policy: %{
output: {:f, 32},
params: {:f, 32},
__struct__: Axon.MixedPrecision.Policy,
compute: {:f, 32}
}
},
4 => %{
args: [:layer, :parameter, :parameter],
id: 4,
name: #Function<68.29713129/2 in Axon.name/2>,
parent: [3],
mode: :both,
opts: [],
op: :dense,
stacktrace: [
{Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
{Axon, :dense, 3, [file: ~c"lib/axon.ex", line: 816]},
{:elixir, :eval_external_handler, 3,
[file: ~c"src/elixir.erl", line: 386]},
{:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]},
{:elixir, :eval_forms, 4, [file: ~c"src/elixir.erl", line: 364]}
],
__struct__: Axon.Node,
hooks: [],
parameters: [
%{
name: "kernel",
type: {:f, 32},
__struct__: Axon.Parameter,
children: nil,
shape: #Function<28.29713129/1 in Axon.dense/3>,
initializer: #Function<3.47872556/3 in Axon.Initializers.glorot_uniform/1>,
frozen: false
},
%{
name: "bias",
type: {:f, 32},
__struct__: Axon.Parameter,
children: nil,
shape: #Function<29.29713129/1 in Axon.dense/3>,
initializer: #Function<23.47872556/2 in Axon.Initializers.zeros/0>,
frozen: false
}
],
op_name: :dense,
policy: %{
output: {:f, 32},
params: {:f, 32},
__struct__: Axon.MixedPrecision.Policy,
compute: {:f, 32}
}
},
5 => %{
args: [:layer],
id: 5,
name: #Function<68.29713129/2 in Axon.name/2>,
parent: [4],
mode: :both,
opts: [],
op: :softmax,
stacktrace: [
{Axon, :layer, 3, [file: ~c"lib/axon.ex", line: 338]},
{:elixir, :eval_external_handler, 3,
[file: ~c"src/elixir.erl", line: 386]},
{:erl_eval, :do_apply, 7, [file: ~c"erl_eval.erl", line: 904]},
{:erl_eval, :expr, 6, [file: ~c"erl_eval.erl", line: 648]},
{:elixir, :eval_forms, 4, [file: ~c"src/elixir.erl", line: 364]},
{Module.ParallelChecker, :verify, 1,
[file: ~c"lib/module/parallel_checker.ex", line: 112]}
],
__struct__: Axon.Node,
hooks: [],
parameters: [],
op_name: :softmax,
policy: %{
output: {:f, 32},
params: {:f, 32},
__struct__: Axon.MixedPrecision.Policy,
compute: {:f, 32}
}
}
},
__struct__: Axon
}
#Axon<
inputs: %{"images" => {nil, 784}}
outputs: "softmax_0"
nodes: 5
>
Training the Model
# Train with categorical cross-entropy loss and plain SGD for 10 epochs,
# tracking accuracy; the training step is JIT-compiled with EXLA.
trained_model_state =
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)
16:43:31.979 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 750, accuracy: 0.7699941 loss: 0.9710366
Epoch: 1, Batch: 768, accuracy: 0.8787183 loss: 0.7062089
Epoch: 2, Batch: 736, accuracy: 0.8953105 loss: 0.6002396
Epoch: 3, Batch: 754, accuracy: 0.9048012 loss: 0.5338430
Epoch: 4, Batch: 772, accuracy: 0.9116874 loss: 0.4894916
Epoch: 5, Batch: 740, accuracy: 0.9164769 loss: 0.4588849
Epoch: 6, Batch: 758, accuracy: 0.9210311 loss: 0.4331421
Epoch: 7, Batch: 776, accuracy: 0.9243082 loss: 0.4124633
Epoch: 8, Batch: 744, accuracy: 0.9279363 loss: 0.3959939
Epoch: 9, Batch: 762, accuracy: 0.9306808 loss: 0.3807933
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[128]
EXLA.Backend
[-0.04557688534259796, 0.056884266436100006, -0.0010324494214728475, 0.07211018353700638, -0.023117363452911377, -0.010195941664278507, -0.09574069082736969, -0.014177899807691574, 0.040640439838171005, 4.5957943075336516e-4, 0.0774596557021141, 0.01725124940276146, -0.01311414036899805, -0.0601215660572052, 0.055294282734394073, -0.03480232506990433, -0.028218824416399002, -8.242334588430822e-4, -0.04341563209891319, 0.0864398404955864, 0.04911106452345848, 0.046631887555122375, -0.03658245503902435, -0.01592755876481533, -0.0050981719978153706, 0.014501607045531273, 0.013695931993424892, 0.02535100094974041, 0.0402829684317112, 0.0329509861767292, 0.0674162432551384, 0.10294120758771896, 0.03687436133623123, 0.09136985242366791, -0.050792284309864044, -0.016392890363931656, 0.040335074067115784, 0.018260926008224487, 0.019201647490262985, 0.002118457341566682, -0.012751685455441475, 0.05895159766077995, -0.01829197071492672, -0.01859128475189209, 0.028066368773579597, 0.0074777971021831036, -0.037080198526382446, -0.010457291267812252, ...]
>,
"kernel" => #Nx.Tensor<
f32[784][128]
EXLA.Backend
[
[0.016736775636672974, -0.021791711449623108, -0.018893131986260414, 0.055778056383132935, 0.037308219820261, 0.06802469491958618, 0.06937356293201447, 0.05562176555395126, 0.01878305897116661, -0.056212782859802246, 0.04226350784301758, -0.006629091687500477, -0.01020803116261959, -0.0809481143951416, 0.008785022422671318, -0.03413747623562813, -0.03490311652421951, -0.07281620055437088, -0.03968954086303711, 0.05899643525481224, 0.04544154182076454, -0.006005663890391588, 0.05451308190822601, 0.06520316004753113, -0.02622944489121437, -0.0057079121470451355, 0.05863101780414581, 0.018925368785858154, 0.07377712428569794, -0.010299346409738064, -0.026377711445093155, 0.05731591582298279, -0.039092276245355606, 0.025120161473751068, 0.038714949041604996, 0.04092622548341751, 0.0184006430208683, 0.002805600641295314, 0.022019265219569206, 0.08051205426454544, 0.06950380653142929, -0.07453489303588867, 0.07161813974380493, 0.061429694294929504, -0.06079089269042015, -0.0066459933295845985, 0.06742016226053238, ...],
...
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[10]
EXLA.Backend
[-0.06760766357183456, 0.09135901182889938, 0.042458560317754745, -0.047845494002103806, -0.004332108423113823, 0.21402733027935028, -0.039574798196554184, 0.08195860683917999, -0.22631974518299103, -0.0441216416656971]
>,
"kernel" => #Nx.Tensor<
f32[128][10]
EXLA.Backend
[
[-0.14514467120170593, -0.21997107565402985, 0.13562583923339844, 0.09718991816043854, -0.09510929137468338, 0.031382348388433456, 0.1758842170238495, -0.3963959217071533, 0.13938012719154358, 0.09729427844285965],
[-0.0751059427857399, -0.05492069572210312, -0.12499954551458359, -0.3047340214252472, 0.23491673171520233, 0.3278830051422119, -0.09353523701429367, 0.14774689078330994, -0.08367066830396652, 0.04616687074303627],
[-0.22146247327327728, -0.09685598313808441, -0.2944234311580658, 0.19600969552993774, 0.17125526070594788, 0.2786778509616852, 0.15523986518383026, -0.14579492807388306, 0.17253002524375916, -0.20731164515018463],
[0.1843968778848648, -0.15169936418533325, -0.09471436589956284, -0.13887140154838562, 0.18205128610134125, -0.06520683318376541, -0.09642955660820007, 0.3404451906681061, -0.2211306095123291, 0.13326618075370789],
[0.18364790081977844, -0.05975278094410896, -0.1784059703350067, -0.0626249834895134, 0.0054058595560491085, -0.16089509427547455, ...],
...
]
>
}
}
Evaluating the Model
# Evaluate the trained parameters on the held-out test batches.
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
16:43:40.272 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Batch: 156, accuracy: 0.9359077
%{
0 => %{
"accuracy" => #Nx.Tensor<
f32
EXLA.Backend
0.9359076619148254
>
}
}
Executing Models with Axon
# Grab the first {images, labels} batch from the test stream; discard labels.
[{test_batch, _}] = Enum.take(test_data, 1)
[
{#Nx.Tensor<
f32[64][784]
EXLA.Backend
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
...
]
>,
#Nx.Tensor<
u8[64][10]
EXLA.Backend
[
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, ...],
...
]
>}
]
# Pick the second image (a {784} vector) out of the 64-image batch.
test_image = test_batch[1]
#Nx.Tensor<
f32[784]
EXLA.Backend
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...]
>
# Restore the 28x28 layout and show the image as a terminal heatmap.
Nx.to_heatmap(Nx.reshape(test_image, {28, 28}))
#Nx.Heatmap<
f32[28][28]
>
# Build the model into {init_fn, predict_fn}; only the predictor is needed.
{_, predict_fn} = Axon.build(model, compiler: EXLA)
{#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>,
#Function<134.70434077/2 in Nx.Defn.Compiler.fun/2>}
# This raises if run: predict_fn expects a batched input matching the
# model's {nil, 784} input shape, but test_image is a bare {784} vector.
# The call was left uncommented by mistake — keep it commented out:
# predict_fn.(trained_model_state, test_image)
# Add a leading batch axis ({784} -> {1, 784}) so the model accepts the image.
batched_image = Nx.new_axis(test_image, 0)
probabilities = predict_fn.(trained_model_state, batched_image)
#Nx.Tensor<
f32[1][10]
EXLA.Backend
[
[9.18520163395442e-6, 1.9154771871399134e-4, 0.004306578543037176, 0.009376204572618008, 0.002378550823777914, 0.0052066161297261715, 3.823340739472769e-5, 3.9136728446464986e-5, 0.9782620668411255, 1.918784691952169e-4]
]
>
# Index of the highest probability — the predicted digit (8 for this image).
probabilities |> Nx.argmax()
#Nx.Tensor<
s64
EXLA.Backend
8
>