Multi-input / multi-output models

Mix.install([
  {:axon, ">= 0.5.0"},
  {:kino, ">= 0.9.0"}
])
:ok

Creating multi-input models

Some applications require a model that takes more than one input. To use multiple inputs in an Axon model, declare one Axon.input/2 node per input in your graph:

input_1 = Axon.input("input_1")
input_2 = Axon.input("input_2")

out = Axon.add(input_1, input_2)
#Axon<
  inputs: %{"input_1" => nil, "input_2" => nil}
  outputs: "add_0"
  nodes: 4
>

Notice that when you inspect the model, it lists the model's inputs up front. You can also get metadata about your model's inputs programmatically with Axon.get_inputs/1:

Axon.get_inputs(out)
%{"input_1" => nil, "input_2" => nil}
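The inputs above report nil because no shape was declared for them. As a minimal sketch, assuming the :shape option of Axon.input/2 (where nil marks a dynamic batch dimension), declaring shapes up front makes Axon.get_inputs/1 report them instead. It assumes Axon and Nx were installed by the setup cell at the top of this notebook:

```elixir
# Declare each input with an explicit shape; nil leaves the batch size dynamic
input_1 = Axon.input("input_1", shape: {nil, 8})
input_2 = Axon.input("input_2", shape: {nil, 8})

model = Axon.add(input_1, input_2)

# Returns a map of input name => declared shape
input_shapes = Axon.get_inputs(model)
```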

Each input is uniquely named, so you can pass inputs by name into inspection and execution functions using a map:

inputs = %{
  "input_1" => Nx.template({2, 8}, :f32),
  "input_2" => Nx.template({2, 8}, :f32)
}

Axon.Display.as_graph(out, inputs)
graph TD;
3[/"input_1 (:input) {2, 8}"/];
4[/"input_2 (:input) {2, 8}"/];
5["container_0 (:container) {{2, 8}, {2, 8}}"];
6["add_0 (:add) {2, 8}"];
5 --> 6;
4 --> 5;
3 --> 5;
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(inputs, %{})
%{}
inputs = %{
  "input_1" => Nx.iota({2, 8}, type: :f32),
  "input_2" => Nx.iota({2, 8}, type: :f32)
}

predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][8]
  [
    [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0],
    [16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0]
  ]
>
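Axon.add/2 combines its inputs element-wise, so their shapes must be compatible. When inputs carry different numbers of features, a common pattern is to concatenate them before a shared layer. The sketch below assumes the setup cell above; the input names image_features and text_features and their sizes are hypothetical:

```elixir
# Two inputs with different feature sizes (names and sizes are illustrative)
image_features = Axon.input("image_features", shape: {nil, 16})
text_features = Axon.input("text_features", shape: {nil, 4})

# Join the inputs along the feature axis, then project to a single score
model =
  image_features
  |> Axon.concatenate(text_features)
  |> Axon.dense(1)

templates = %{
  "image_features" => Nx.template({2, 16}, :f32),
  "text_features" => Nx.template({2, 4}, :f32)
}

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(templates, %{})

# Inputs are again passed by name in a map
scores =
  predict_fn.(params, %{
    "image_features" => Nx.iota({2, 16}, type: :f32),
    "text_features" => Nx.iota({2, 4}, type: :f32)
  })
```

Axon.concatenate joins along the last axis by default, so the dense layer here sees 20 features per example.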

If you forget a required input, Axon will raise:

predict_fn.(params, %{"input_1" => Nx.iota({2, 8}, type: :f32)})
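Because Axon.get_inputs/1 returns every declared input name, you can check for missing inputs before calling the model. The sketch below pairs a hypothetical pre-flight check with a try/rescue around the failing call; it rescues any exception rather than assuming a specific exception type:

```elixir
input_1 = Axon.input("input_1")
input_2 = Axon.input("input_2")
model = Axon.add(input_1, input_2)

templates = %{
  "input_1" => Nx.template({2, 8}, :f32),
  "input_2" => Nx.template({2, 8}, :f32)
}

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(templates, %{})

provided = %{"input_1" => Nx.iota({2, 8}, type: :f32)}

# Hypothetical pre-flight check: declared input names minus the provided keys
missing = Map.keys(Axon.get_inputs(model)) -- Map.keys(provided)

# Calling the model anyway raises; rescue the error to inspect its message
result =
  try do
    predict_fn.(params, provided)
  rescue
    error -> {:error, Exception.message(error)}
  end
```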

Creating multi-output models

Depending on your application, you might also want your model to have multiple outputs. You can achieve this by using Axon.container/2 to wrap multiple nodes into any supported Nx container, such as a tuple or a map:

inp = Axon.input("data")

x1 = inp |> Axon.dense(32) |> Axon.relu()
x2 = inp |> Axon.dense(64) |> Axon.relu()

out = Axon.container({x1, x2})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
7[/"data (:input) {2, 8}"/];
8["dense_0 (:dense) {2, 32}"];
9["relu_0 (:relu) {2, 32}"];
10["dense_1 (:dense) {2, 64}"];
11["relu_1 (:relu) {2, 64}"];
12["container_0 (:container) {{2, 32}, {2, 64}}"];
11 --> 12;
9 --> 12;
10 --> 11;
7 --> 10;
8 --> 9;
7 --> 8;

When executed, a container model returns a data structure that matches the structure of the container you built:

{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
{#Nx.Tensor<
   f32[2][32]
   [
     [0.4453479051589966, 1.7394963502883911, 0.8509911298751831, 0.35142624378204346, 0.0, 0.0, 0.0, 3.942654609680176, 0.0, 0.0, 0.0, 0.6140655279159546, 0.0, 5.719906330108643, 1.1410939693450928, 0.0, 2.6871578693389893, 3.373258352279663, 0.0, 0.0, 0.0, 0.3058185875415802, 0.0, 0.0, 1.3737146854400635, 2.2648088932037354, 1.3570061922073364, 0.0, 0.05746358633041382, 0.0, 2.046199321746826, 4.884631156921387],
     [0.0, 2.0598671436309814, 2.4343056678771973, 3.2341041564941406, 0.0, 1.905256748199463, 0.0, 12.712749481201172, 0.0, 0.0, 0.0, 4.559232711791992, 0.0, 12.027459144592285, 0.8423471450805664, 0.0, 8.888325691223145, ...]
   ]
 >,
 #Nx.Tensor<
   f32[2][64]
   [
     [2.211906909942627, 0.937014639377594, 0.017132893204689026, 0.0, 3.617021083831787, 1.3125507831573486, 1.1870051622390747, 0.0, 0.0, 1.245000958442688, 1.5268664360046387, 0.0, 2.16796612739563, 0.8091188669204712, 0.45314761996269226, 0.0, 0.05176612734794617, 0.0, 5.982738018035889, 1.58057701587677, 0.0, 0.0, 1.2986125946044922, 0.8577098250389099, 0.0, 1.1064631938934326, 1.1242716312408447, 1.8777625560760498, 3.4422712326049805, 0.13321448862552643, 2.753225088119507, 0.0, 0.45021766424179077, 0.5664225816726685, 0.0, 0.0, 0.0, 1.5448659658432007, 0.0, 0.7237715721130371, 0.1693495213985443, 0.0, 0.719341516494751, 0.0, 0.0, 4.644839763641357, 0.0, 3.597681760787964, ...],
     ...
   ]
 >}
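A common use of multiple outputs is a shared trunk feeding several task-specific heads. In the sketch below (layer names and sizes are illustrative, and the setup cell is assumed), the tuple returned by the container is destructured directly into its two outputs:

```elixir
inp = Axon.input("data", shape: {nil, 8})

# Shared trunk feeding two task-specific heads
trunk = inp |> Axon.dense(16, name: "trunk") |> Axon.relu()
class_head = trunk |> Axon.dense(3, name: "class_head") |> Axon.softmax()
value_head = trunk |> Axon.dense(1, name: "value_head")

model = Axon.container({class_head, value_head})

template = Nx.template({2, 8}, :f32)
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(template, %{})

# The result is a tuple, so it can be pattern-matched into its parts
{probs, value} = predict_fn.(params, Nx.iota({2, 8}, type: :f32))
```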

You can output maps as well:

out = Axon.container(%{x1: x1, x2: x2})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
%{
  x1: #Nx.Tensor<
    f32[2][32]
    [
      [1.4180752038955688, 1.8710994720458984, 0.0, 1.1198676824569702, 1.1357430219650269, 0.0, 0.0, 0.0, 2.907017469406128, 0.0, 0.3814663589000702, 0.0, 0.6225995421409607, 1.1952786445617676, 0.0, 3.6701409816741943, 3.581918716430664, 1.4750021696090698, 0.910987377166748, 0.0, 0.0, 0.0, 2.317782402038574, 0.8362345695495605, 0.0, 1.9256348609924316, 0.0, 0.0, 0.0, 1.8028252124786377, 1.448373556137085, 1.743951678276062],
      [3.7401936054229736, 2.494429349899292, 0.0, 0.9745509624481201, 8.416919708251953, 0.0, 0.6044515371322632, 0.0, 2.5829238891601562, 0.0, 3.592892646789551, 0.0, 0.0, 4.004939079284668, 0.0, 9.755555152893066, 5.3506879806518555, ...]
    ]
  >,
  x2: #Nx.Tensor<
    f32[2][64]
    [
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5240116119384766, 0.0, 1.6478428840637207, 0.0, 0.0, 0.0, 0.0, 2.1685361862182617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5010783672332764, 0.36673399806022644, 0.0, 0.0, 0.5610344409942627, 1.9324723482131958, 0.39768826961517334, 0.0, 0.0, 0.0, 0.0, 0.0, 0.054594263434410095, 0.6123883128166199, 0.15942004323005676, 0.7058550715446472, 0.0, 1.860019326210022, 0.2499483972787857, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03381317853927612, ...],
      ...
    ]
  >
}
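Multiple inputs and map outputs can be combined in one model. The following sketch assumes hypothetical inputs named features and context and hypothetical output keys score and embedding; the result is matched by key:

```elixir
# Hypothetical two-input, two-output model with named outputs
features = Axon.input("features", shape: {nil, 8})
context = Axon.input("context", shape: {nil, 4})

joint = Axon.concatenate(features, context) |> Axon.dense(16) |> Axon.relu()

model =
  Axon.container(%{
    score: Axon.dense(joint, 1),
    embedding: Axon.dense(joint, 8)
  })

templates = %{
  "features" => Nx.template({2, 8}, :f32),
  "context" => Nx.template({2, 4}, :f32)
}

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(templates, %{})

# Map outputs can be pattern-matched (or accessed) by key
%{score: score, embedding: embedding} =
  predict_fn.(params, %{
    "features" => Nx.iota({2, 8}, type: :f32),
    "context" => Nx.iota({2, 4}, type: :f32)
  })
```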

Containers even support arbitrary nesting:

out = Axon.container({%{x1: {x1, x2}, x2: %{x1: x1, x2: {x2}}}})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
{%{
   x1: {#Nx.Tensor<
      f32[2][32]
      [
        [1.7373675107955933, 0.0, 5.150482177734375, 0.544252336025238, 0.275376558303833, 0.0, 0.0, 0.0, 0.0, 1.7849855422973633, 0.7857151031494141, 0.2273893654346466, 0.2701767086982727, 2.321484327316284, 2.685051441192627, 0.0, 2.547382116317749, 0.0, 0.0, 0.0, 0.722919225692749, 2.3600289821624756, 1.4695687294006348, 0.0, 0.0, 0.0, 1.0015852451324463, 1.2762010097503662, 0.0, 0.07927703857421875, 0.0, 0.6216219663619995],
        [4.996878623962402, 0.0, 14.212154388427734, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.517582356929779, 0.0, 2.036062479019165, 2.907236337661743, 8.515787124633789, 7.998186111450195, ...]
      ]
    >,
    #Nx.Tensor<
      f32[2][64]
      [
        [1.2057430744171143, 0.0, 0.0, 0.8717040419578552, 1.7653638124465942, 0.0, 0.0, 0.0, 0.0, 0.9921279549598694, 0.0, 1.0860291719436646, 2.3648557662963867, 0.0, 0.0, 2.0518181324005127, 1.6323723793029785, 0.9113610982894897, 1.6805293560028076, 0.8101096749305725, 0.0, 0.0, 0.0, 2.2150073051452637, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2320713996887207, 0.0, 2.553570508956909, 0.28632092475891113, 0.0, 0.0, 0.020383253693580627, 0.0, 0.2926883101463318, 1.3561311960220337, 0.8884503245353699, 3.1455295085906982, 0.0, 0.0, 1.237722635269165, 0.0, 2.149625539779663, ...],
        ...
      ]
    >},
   x2: %{
     x1: #Nx.Tensor<
       f32[2][32]
       [
         [1.7373675107955933, 0.0, 5.150482177734375, 0.544252336025238, 0.275376558303833, 0.0, 0.0, 0.0, 0.0, 1.7849855422973633, 0.7857151031494141, 0.2273893654346466, 0.2701767086982727, 2.321484327316284, 2.685051441192627, 0.0, 2.547382116317749, 0.0, 0.0, 0.0, 0.722919225692749, 2.3600289821624756, 1.4695687294006348, 0.0, 0.0, 0.0, 1.0015852451324463, 1.2762010097503662, 0.0, 0.07927703857421875, 0.0, 0.6216219663619995],
         [4.996878623962402, 0.0, 14.212154388427734, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.517582356929779, 0.0, 2.036062479019165, 2.907236337661743, 8.515787124633789, ...]
       ]
     >,
     x2: {#Nx.Tensor<
        f32[2][64]
        [
          [1.2057430744171143, 0.0, 0.0, 0.8717040419578552, 1.7653638124465942, 0.0, 0.0, 0.0, 0.0, 0.9921279549598694, 0.0, 1.0860291719436646, 2.3648557662963867, 0.0, 0.0, 2.0518181324005127, 1.6323723793029785, 0.9113610982894897, 1.6805293560028076, 0.8101096749305725, 0.0, 0.0, 0.0, 2.2150073051452637, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2320713996887207, 0.0, 2.553570508956909, 0.28632092475891113, 0.0, 0.0, 0.020383253693580627, 0.0, 0.2926883101463318, 1.3561311960220337, 0.8884503245353699, 3.1455295085906982, 0.0, 0.0, 1.237722635269165, ...],
          ...
        ]
      >}
   }
 }}
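In the output above, both tensors under the x1 keys are identical, which reflects that the container references the same x1 node twice rather than adding a new layer (the model still reports 6 nodes); the same holds for x2. The sketch below, assuming the setup cell above, rebuilds that nested model and checks this by destructuring the result with a pattern that mirrors the container:

```elixir
inp = Axon.input("data", shape: {nil, 8})
x1 = inp |> Axon.dense(32) |> Axon.relu()
x2 = inp |> Axon.dense(64) |> Axon.relu()

# x1 and x2 each appear twice in this nested container
model = Axon.container({%{x1: {x1, x2}, x2: %{x1: x1, x2: {x2}}}})

template = Nx.template({2, 8}, :f32)
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(template, %{})

# One pattern that mirrors the nested container structure
{%{x1: {a1, a2}, x2: %{x1: b1, x2: {b2}}}} =
  predict_fn.(params, Nx.iota({2, 8}, type: :f32))

# Repeated references to the same node yield the same values
same_x1 = Nx.equal(a1, b1) |> Nx.all() |> Nx.to_number() == 1
same_x2 = Nx.equal(a2, b2) |> Nx.all() |> Nx.to_number() == 1
```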