Complex models
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.4.0", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:kino, "~> 0.7.0"}
])
Creating more complex models
Not all models you’d want to create fit cleanly in the sequential paradigm; some require a more flexible API. Fortunately, because Axon models are just Elixir data structures, you can manipulate them and decompose architectures as you would in any other Elixir program:
input = Axon.input("data")
x1 = input |> Axon.dense(32)
x2 = input |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)
out = Axon.add(x1, x2)
In the snippet above, your model branches input into x1 and x2. Each branch performs a different set of transformations, and at the end the branches are merged with Axon.add/3. Because both branches end in a 32-unit dense layer, their outputs have the same shape, which is what allows them to be added element-wise. You might sometimes see layers like Axon.add/3 called combinators. They’re really just layers that operate on multiple Axon models at once, typically to merge branches together.
out represents your final Axon model.
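Axon.add/3 is not the only combinator. The sketch below compares Axon.add/3, Axon.multiply/3, and Axon.concatenate/3 on two 32-unit branches. The branch names are illustrative, and the code assumes the setup cell above has already run. The element-wise combinators need branches of the same shape, while Axon.concatenate/3 joins features along the last axis by default:

```elixir
# Assumes the dependencies from the setup cell above are installed.
branch_input = Axon.input("data")

# Two branches with the same output shape: {batch, 32}
left = branch_input |> Axon.dense(32)
right = branch_input |> Axon.dense(32) |> Axon.relu()

# Element-wise combinators need inputs with matching shapes
summed = Axon.add(left, right)
product = Axon.multiply(left, right)

# Axon.concatenate/3 joins along the last axis by default (32 + 32 = 64)
joined = Axon.concatenate(left, right)

template = Nx.template({2, 8}, :f32)
data = Nx.iota({2, 8}, type: :f32)

# Build each model, run it once, and record the shape of its output
shapes =
  for model <- [summed, product, joined] do
    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(template, %{})
    model_output = predict_fn.(params, data)
    Nx.shape(model_output)
  end
```

Expect {2, 32} for the two element-wise merges and {2, 64} for the concatenation.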
If you visualize out, you can see the full effect of the branching:
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
And you can use Axon.build/2 on out as you would any other Axon model:
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
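For a quick forward pass, Axon.predict/4 builds and runs the model in a single call. The sketch below rebuilds the same branching model under illustrative names (branched and the other variables are not from the original) so the cell runs on its own. It then checks that the merged output has the shape you’d expect from two 32-unit branches and that both entry points agree:

```elixir
# Assumes the dependencies from the setup cell above are installed.
# Same architecture as `out` above, rebuilt under illustrative names.
branch_in = Axon.input("data")
short_branch = branch_in |> Axon.dense(32)
long_branch = branch_in |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)
branched = Axon.add(short_branch, long_branch)

input_template = Nx.template({2, 8}, :f32)
input_tensor = Nx.iota({2, 8}, type: :f32)

{branched_init, branched_predict} = Axon.build(branched)
branched_params = branched_init.(input_template, %{})

# Axon.predict/4 compiles and runs the model in a single call
via_predict = Axon.predict(branched, branched_params, input_tensor)
via_build = branched_predict.(branched_params, input_tensor)

# The merged output is {2, 32}, and both entry points produce the same values
{Nx.shape(via_predict), Nx.to_number(Nx.all_close(via_predict, via_build))}
```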
As your architectures grow in complexity, you might find yourself reaching for better abstractions to organize your model creation code. For example, PyTorch models are often organized into nn.Module subclasses. The equivalent of an nn.Module in Axon is a regular Elixir function that takes an Axon model and returns a new one. If you’re translating models from PyTorch to Axon, it’s natural to create one Elixir function per nn.Module.
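To make the mapping concrete, here is a sketch of how a small PyTorch feed-forward module might translate. The PyTorch class in the comment, the FeedForward module, and feed_forward/3 are hypothetical examples, not part of either library. Constructor arguments become function arguments, and the body of forward becomes a pipeline over the input model:

```elixir
# Assumes the dependencies from the setup cell above are installed.
#
# Hypothetical PyTorch module being translated:
#
#   class FeedForward(nn.Module):
#       def __init__(self, in_features, hidden, out):
#           super().__init__()
#           self.fc1 = nn.Linear(in_features, hidden)
#           self.fc2 = nn.Linear(hidden, out)
#
#       def forward(self, x):
#           return self.fc2(torch.relu(self.fc1(x)))
defmodule FeedForward do
  # Axon infers the input feature size from the input template,
  # so there is no `in_features` argument.
  def feed_forward(input, hidden, out) do
    input
    |> Axon.dense(hidden)
    |> Axon.relu()
    |> Axon.dense(out)
  end
end

ff_model = Axon.input("data") |> FeedForward.feed_forward(64, 10)

{ff_init, ff_predict} = Axon.build(ff_model)
ff_params = ff_init.(Nx.template({2, 8}, :f32), %{})
ff_output = ff_predict.(ff_params, Nx.iota({2, 8}, type: :f32))
Nx.shape(ff_output)
```

One difference from PyTorch: each call to feed_forward/3 adds new dense layers with their own parameters. Calling the same nn.Module instance twice, by contrast, reuses its weights.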
Write your models as you would write any other Elixir code; there are no framework-specific constructs to worry about:
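defmodule MyModel do
  def model() do
    Axon.input("data")
    |> conv_block()
    |> Axon.flatten()
    |> dense_block()
    |> dense_block()
    |> Axon.dense(1)
  end

  # Convolution with a residual connection, followed by downsampling.
  # The conv outputs 3 channels so it can be added to a 3-channel input.
  defp conv_block(input) do
    residual = input
    x = input |> Axon.conv(3, padding: :same) |> Axon.mish()

    x
    |> Axon.add(residual)
    |> Axon.max_pool(kernel_size: {2, 2})
  end

  # Fully connected layer followed by a ReLU activation
  defp dense_block(input) do
    input |> Axon.dense(32) |> Axon.relu()
  end
end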
model = MyModel.model()
template = Nx.template({1, 28, 28, 3}, :f32)
Axon.Display.as_graph(model, template)
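Because blocks are ordinary functions, you can parameterize them the same way you would a PyTorch module’s constructor. The sketch below is a hypothetical variant of MyModel: the MyConfigurableModel name and its :channels and :hidden options are illustrative. It builds the model and runs a forward pass on a constant input. Note that the residual Axon.add/3 only works when the convolution preserves the channel count, so :channels must match the input’s last dimension:

```elixir
# Assumes the dependencies from the setup cell above are installed.
defmodule MyConfigurableModel do
  # Hypothetical options: :channels for the conv block and
  # :hidden for the list of dense layer sizes
  def model(opts \\ []) do
    channels = Keyword.get(opts, :channels, 3)
    hidden = Keyword.get(opts, :hidden, [32, 32])

    x =
      Axon.input("data")
      |> conv_block(channels)
      |> Axon.flatten()

    # Stack one dense block per entry in `hidden`
    hidden
    |> Enum.reduce(x, fn units, acc -> dense_block(acc, units) end)
    |> Axon.dense(1)
  end

  # The residual add requires the conv to preserve the channel count
  defp conv_block(input, channels) do
    input
    |> Axon.conv(channels, padding: :same)
    |> Axon.mish()
    |> Axon.add(input)
    |> Axon.max_pool(kernel_size: {2, 2})
  end

  defp dense_block(input, units) do
    input |> Axon.dense(units) |> Axon.relu()
  end
end

configured = MyConfigurableModel.model(hidden: [64, 32, 16])
configured_template = Nx.template({1, 28, 28, 3}, :f32)

{configured_init, configured_predict} = Axon.build(configured)
configured_params = configured_init.(configured_template, %{})

# Constant input with the same shape as the template
configured_output =
  configured_predict.(configured_params, Nx.broadcast(0.5, {1, 28, 28, 3}))

Nx.shape(configured_output)
```

The final Axon.dense(1) produces one value per example, so the output shape is {1, 1} regardless of how many hidden layers you request.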