
Your first Axon model



Mix.install([
  {:axon, "~> 0.6"},
  {:nx, "~> 0.6", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:kino, "~> 0.11"}
])

Your first model

Axon is a library for creating and training neural networks in Elixir. Everything in Axon centers on the %Axon{} struct, which represents an instance of an Axon model.

Models are just graphs which represent the transformation and flow of input data to a desired output. Really, you can think of models as representing a single computation or function. An Axon model, when executed, takes data as input and returns transformed data as output.

All Axon models start with a declaration of input nodes. These are the root nodes of your computation graph, and correspond to the actual input data you want to send to Axon:

input = Axon.input("data")

Technically speaking, input is now a valid Axon model which you can inspect, execute, and initialize. You can visualize how data flows through the graph using Axon.Display.as_graph/2:

template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(input, template)

Notice the execution flow is just a single node, because your graph only consists of an input node! You pass data in and the model spits the same data back out, without any intermediate transformations.

You can see this in action by actually executing your model. You can build the %Axon{} struct into its initialization and forward functions by calling Axon.build/2. This pattern of “lowering” or transforming the %Axon{} data structure into other functions or representations is very common in Axon. By simply traversing the data structure, you can create useful functions, execution visualizations, and more!

{init_fn, predict_fn} = Axon.build(input)

Notice that Axon.build/2 returns a tuple of {init_fn, predict_fn}. init_fn has the signature:

init_fn.(template :: map(tensor) | tensor, initial_params :: map) :: map(tensor)

while predict_fn has the signature:

predict_fn.(params :: map(tensor), input :: map(tensor) | tensor)

init_fn returns all of your model’s trainable parameters and state. You need to pass a template of the expected inputs because the shape of certain model parameters often depends on the shape of model inputs. You also need to pass any initial parameters you want your model to start with. This is useful for things like transfer learning, which you can read about in another guide.

predict_fn applies your model to the given inputs using the given parameters and returns the transformed result.

params = init_fn.(Nx.template({1, 8}, :f32), %{})

In this example, you use Nx.template/2 to create a template tensor, which is a placeholder that does not actually consume any memory. Templates are useful for initialization because you don’t actually need to know anything about your inputs other than their shape and type.
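As a small sketch of that point, a template reports a shape and a type like any other tensor, which is all that initialization needs. This assumes the setup cell at the top of the notebook has already installed Nx; the variable name is only illustrative.

```elixir
# Assumes Nx is already installed by the notebook's setup cell.
template = Nx.template({1, 8}, :f32)

# A template carries only metadata: its shape and its type.
Nx.shape(template)
# => {1, 8}

Nx.type(template)
# => {:f, 32}
```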

Notice init_fn returned an empty map because your model does not have any trainable parameters. This should make sense because it’s just an input layer.

Now you can pass these trainable parameters to predict_fn along with some input to actually execute your model:

predict_fn.(params, Nx.iota({1, 8}, type: :f32))

And your model just returned the given input, as expected!
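To check that claim rather than eyeball it, a sketch like the following compares the output with the input element by element. It rebuilds the model so the cell stands on its own, and it assumes Axon 0.6, where init_fn returns a plain map. It also passes the input a second time as a map keyed by the input name, which the predict_fn signature above allows; the names data, out_tensor, and out_map are only illustrative.

```elixir
# Assumes Axon and Nx are already installed by the notebook's setup cell.
input = Axon.input("data")
{init_fn, predict_fn} = Axon.build(input)

params = init_fn.(Nx.template({1, 8}, :f32), %{})
data = Nx.iota({1, 8}, type: :f32)

# Pass the tensor directly, then as a map keyed by the input name.
out_tensor = predict_fn.(params, data)
out_map = predict_fn.(params, %{"data" => data})

# Each comparison reduces to 1 when every element matches the input.
same_tensor = Nx.to_number(Nx.all(Nx.equal(out_tensor, data)))
same_map = Nx.to_number(Nx.all(Nx.equal(out_map, data)))
{same_tensor, same_map}
```

If the model really is an identity function, both values in the final tuple should be 1.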