Complex models
Mix.install([
  {:axon, ">= 0.5.0"},
  {:kino, ">= 0.9.0"}
])
:ok
Creating more complex models
Not all models you’d want to create fit cleanly in the sequential paradigm. Some models require a more flexible API. Fortunately, because Axon models are just Elixir data structures, you can manipulate them and decompose architectures as you would any other Elixir program:
input = Axon.input("data")
x1 = input |> Axon.dense(32)
x2 = input |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)
out = Axon.add(x1, x2)
#Axon<
inputs: %{"data" => nil}
outputs: "add_0"
nodes: 7
>
In the snippet above, your model branches input into x1 and x2. Each branch performs a different set of transformations; at the end, the branches are merged with Axon.add/3. You might sometimes see layers like Axon.add/3 called combinators. Really they’re just layers that operate on multiple Axon models at once, typically to merge some branches together. out represents your final Axon model.
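Other combinators follow the same pattern. As a minimal sketch, assuming the setup cell at the top has been run, the same two branches can be merged with Axon.concatenate/3 instead of Axon.add/3. The name concat_out is only illustrative and is not used elsewhere in this guide:

```elixir
# Assumes the Mix.install setup at the top of this notebook has been run.
input = Axon.input("data")
x1 = input |> Axon.dense(32)
x2 = input |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)

# Axon.concatenate/3 joins the branches along the last axis by default,
# so the feature dimensions are stacked side by side instead of summed.
concat_out = Axon.concatenate(x1, x2)

# Ask Axon for the output shape given a {2, 8} input template.
Axon.get_output_shape(concat_out, Nx.template({2, 8}, :f32))
```

With the default axis, the two {2, 32} branches are joined into a single {2, 64} output rather than summed element-wise. The rest of this guide continues with out, the Axon.add/3 version.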
If you visualize out, you can see the full effect of the branching:
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
4["dense_0 (:dense) {2, 32}"];
5["dense_1 (:dense) {2, 64}"];
6["relu_0 (:relu) {2, 64}"];
7["dense_2 (:dense) {2, 32}"];
8["container_0 (:container) {{2, 32}, {2, 32}}"];
9["add_0 (:add) {2, 32}"];
8 --> 9;
7 --> 8;
4 --> 8;
6 --> 7;
5 --> 6;
3 --> 5;
3 --> 4;
And you can use Axon.build/2 on out as you would any other Axon model:
{init_fn, predict_fn} = Axon.build(out)
{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>,
#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>}
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
#Nx.Tensor<
f32[2][32]
[
[-4.283246040344238, 1.8983498811721802, 3.697357654571533, -4.720174789428711, 4.1636152267456055, 1.001131534576416, -0.7027540802955627, -3.7821826934814453, 0.027841567993164062, 9.267499923706055, 3.33616304397583, -1.5465859174728394, 8.983413696289062, 3.7445120811462402, 2.2405576705932617, -3.61336350440979, -1.7320983409881592, 0.5740477442741394, -0.22006472945213318, -0.1806044578552246, 1.1092393398284912, -0.29313594102859497, -0.41948509216308594, 3.526411533355713, -0.9127179384231567, 1.8373844623565674, 1.1746022701263428, -0.6885149478912354, -1.4326229095458984, -1.3498257398605347, -5.803186416625977, 1.5204020738601685],
[-15.615742683410645, 6.555544853210449, 7.033155918121338, -12.33556842803955, 14.105436325073242, -4.230871200561523, 5.985136032104492, -8.445676803588867, 5.383096694946289, 23.413570404052734, 0.8907639980316162, -1.400709629058838, 19.19326400756836, 13.784171104431152, 9.641424179077148, -8.407038688659668, -5.688483238220215, 4.383636474609375, ...]
]
>
As your architectures grow in complexity, you might find yourself reaching for better abstractions to organize your model creation code. For example, PyTorch models are often organized into nn.Module. The equivalent of an nn.Module in Axon is a regular Elixir function. If you’re translating models from PyTorch to Axon, it’s natural to create one Elixir function per nn.Module.
You should write your models as you would write any other Elixir code. You don’t need to worry about any framework-specific constructs:
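defmodule MyModel do
  def model() do
    Axon.input("data")
    |> conv_block()
    |> Axon.flatten()
    |> dense_block()
    |> dense_block()
    |> Axon.dense(1)
  end

  # Convolution with a residual (skip) connection. Using 3 filters keeps the
  # channel count equal to the 3-channel input used below, so the add lines up.
  defp conv_block(input) do
    residual = input
    x = input |> Axon.conv(3, padding: :same) |> Axon.mish()

    x
    |> Axon.add(residual)
    |> Axon.max_pool(kernel_size: {2, 2})
  end

  # Dense layer followed by a ReLU activation.
  defp dense_block(input) do
    input |> Axon.dense(32) |> Axon.relu()
  end
end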
{:module, MyModel, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:dense_block, 1}}
model = MyModel.model()
#Axon<
inputs: %{"data" => nil}
outputs: "dense_2"
nodes: 12
>
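Because each block is a plain function, it can also take ordinary arguments. The sketch below is a hypothetical variation, not part of the original example: MyParamModel, hidden_sizes, and the units argument are illustrative names, and it assumes the setup cell at the top has been run. It builds a stack of dense blocks from a list of sizes with Enum.reduce/3:

```elixir
# Hypothetical module for illustration only; assumes Axon is installed
# through the Mix.install cell at the top of this notebook.
defmodule MyParamModel do
  # hidden_sizes is an illustrative option: one dense block per entry.
  def model(hidden_sizes \\ [64, 32]) do
    input = Axon.input("data")

    hidden_sizes
    |> Enum.reduce(input, fn units, acc -> dense_block(acc, units) end)
    |> Axon.dense(1)
  end

  # Same block as in MyModel, but with the layer width passed in.
  defp dense_block(input, units) do
    input |> Axon.dense(units) |> Axon.relu()
  end
end

param_model = MyParamModel.model([128, 64, 32])

# Check the output shape for a batch of 4 inputs with 8 features each.
Axon.get_output_shape(param_model, Nx.template({4, 8}, :f32))
```

This plays roughly the role that constructor arguments play for an nn.Module in PyTorch: the function arguments decide the architecture, and the returned value is still an ordinary Axon model.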
template = Nx.template({1, 28, 28, 3}, :f32)
Axon.Display.as_graph(model, template)
graph TD;
10[/"data (:input) {1, 28, 28, 3}"/];
11["conv_0 (:conv) {1, 28, 28, 3}"];
12["mish_0 (:mish) {1, 28, 28, 3}"];
13["container_0 (:container) {{1, 28, 28, 3}, {1, 28, 28, 3}}"];
14["add_0 (:add) {1, 28, 28, 3}"];
15["max_pool_0 (:max_pool) {1, 14, 14, 3}"];
16["flatten_0 (:flatten) {1, 588}"];
17["dense_0 (:dense) {1, 32}"];
18["relu_0 (:relu) {1, 32}"];
19["dense_1 (:dense) {1, 32}"];
20["relu_1 (:relu) {1, 32}"];
21["dense_2 (:dense) {1, 1}"];
20 --> 21;
19 --> 20;
18 --> 19;
17 --> 18;
16 --> 17;
15 --> 16;
14 --> 15;
13 --> 14;
10 --> 13;
12 --> 13;
11 --> 12;
10 --> 11;