
Converting ONNX Models To Axon


Mix.install([
  {:axon, "~> 0.5.1"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:axon_onnx, "~> 0.4.0"},
  {:torchx, github: "elixir-nx/nx", sparse: "torchx"},
  {:stb_image, "~> 0.6.1"},
  {:kino, "~> 0.9.0"}
])
* Getting nx (https://github.com/elixir-nx/nx.git)
remote: Enumerating objects: 17480, done.        
remote: Counting objects: 100% (17480/17480), done.        
remote: Compressing objects: 100% (5431/5431), done.        
remote: Total 17480 (delta 11778), reused 17318 (delta 11662), pack-reused 0        
origin/HEAD set to main
* Getting torchx (https://github.com/elixir-nx/nx.git)
remote: Enumerating objects: 17480, done.        
remote: Counting objects: 100% (17480/17480), done.        
remote: Compressing objects: 100% (5424/5424), done.        
remote: Total 17480 (delta 11786), reused 17316 (delta 11669), pack-reused 0        
origin/HEAD set to main
Resolving Hex dependencies...
Dependency resolution completed:
New:
  axon 0.5.1
  axon_onnx 0.4.0
  castore 1.0.1
  cc_precompiler 0.1.7
  complex 0.5.0
  decimal 2.0.0
  dll_loader_helper 0.1.10
  elixir_make 0.7.6
  kino 0.9.0
  protox 1.6.10
  stb_image 0.6.1
  table 0.1.2
  telemetry 1.2.1
* Getting axon (Hex package)
* Getting axon_onnx (Hex package)
* Getting stb_image (Hex package)
* Getting kino (Hex package)
* Getting table (Hex package)
* Getting cc_precompiler (Hex package)
* Getting elixir_make (Hex package)
* Getting dll_loader_helper (Hex package)
* Getting castore (Hex package)
* Getting protox (Hex package)
* Getting decimal (Hex package)
* Getting complex (Hex package)
* Getting telemetry (Hex package)
==> decimal
Compiling 4 files (.ex)
Generated decimal app
==> table
Compiling 5 files (.ex)
Generated table app
===> Analyzing applications...
===> Compiling telemetry
==> protox
Compiling 53 files (.ex)
Generated protox app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> kino
Compiling 39 files (.ex)
Generated kino app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> axon_onnx
Compiling 28 files (.ex)
warning: Nx.power/2 is deprecated. Use pow/2 instead
  lib/axon_onnx/deserialize.ex:364: AxonOnnx.Deserialize

warning: Nx.power/2 is deprecated. Use pow/2 instead
Invalid call found at 4 locations:
  lib/axon_onnx/shared.ex:44: AxonOnnx.Shared."__defn:sumsquare__"/2
  lib/axon_onnx/shared.ex:52: AxonOnnx.Shared."__defn:l2_norm__"/2
  lib/axon_onnx/shared.ex:63: AxonOnnx.Shared."__defn:lrn__"/2
  lib/axon_onnx/shared.ex:65: AxonOnnx.Shared."__defn:lrn__"/2

warning: Nx.Defn.Kernel.transform/2 is deprecated. Use deftransform/2 or deftransformp/2 from Nx.Defn instead
Invalid call found at 2 locations:
  lib/axon_onnx/shared.ex:58: AxonOnnx.Shared."__defn:lrn__"/2
  lib/axon_onnx/shared.ex:94: AxonOnnx.Shared."__defn:numpy_matmul__"/3

warning: Nx.power/2 is deprecated. Use pow/2 instead
  lib/axon_onnx/deserialize.ex: AxonOnnx.Deserialize.recur_nodes/2

warning: Nx.random_normal/4 is deprecated. Use Nx.Random instead
Invalid call found at 4 locations:
  lib/axon_onnx/deserialize.ex:1978: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:2009: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:2015: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:2022: AxonOnnx.Deserialize.recur_nodes/2

warning: Nx.random_uniform/4 is deprecated. Use Nx.Random.uniform/2 instead
Invalid call found at 4 locations:
  lib/axon_onnx/deserialize.ex:1911: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:1942: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:1948: AxonOnnx.Deserialize.recur_nodes/2
  lib/axon_onnx/deserialize.ex:1955: AxonOnnx.Deserialize.recur_nodes/2

Generated axon_onnx app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> cc_precompiler
Compiling 3 files (.ex)
Generated cc_precompiler app
==> stb_image
Compiling 2 files (.ex)
Generated stb_image app
==> dll_loader_helper
Compiling 2 files (.ex)
Generated dll_loader_helper app
==> torchx
/bin/sh: cmake: command not found
make: *** [/Users/charlie/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.4/e370518b40bcce31b1ed7b41bb1c7f20/_build/dev/lib/torchx/priv/torchx.so] Error 127
could not compile dependency :torchx, "mix compile" failed. Errors may have been logged above. You can recompile this dependency with "mix deps.compile torchx", update it with "mix deps.update torchx" or clean it with "mix deps.clean torchx"

Converting An ONNX Model Into Axon

Using models that have been written in other languages and machine learning frameworks is a capability the ML community has been working towards for some time. It requires a common format that models can be converted to, and ONNX provides one. This allows software engineers, data scientists, and ML engineers to share models and use them in their preferred language or framework.

ONNX enables Elixir developers to use thousands (maybe millions some day) of public models from repositories such as Hugging Face directly in Axon. It would not make economic sense for Elixir developers to take the time to translate models written in PyTorch or TensorFlow unless there is a very strong case for doing so. Most developers will want to use Axon plus an existing model to accomplish a business use case for their product or team.

This notebook was developed on a MacBook with an M1 chip. At the moment, neither the PyTorch backend nor the XLA backend is supported on this machine. To run the notebook in an environment suited to machine learning tasks, you would deploy this Livebook to a hosting solution that provides hardware supporting these backends (such as Hugging Face Spaces). To do this, you need to configure your environment:

# Assumed default: "cpu". Change to "cuda111" for Nvidia GPU.
xla_target = "cpu"

Mix.install(
  [
    {:nx, "~> 0.5.2"},
    {:axon, "~> 0.5.1"},
    {:axon_onnx, "~> 0.4.0"},
    {:exla, "~> 0.5.2"},
    {:stb_image, "~> 0.6.1"},
    {:kino, "~> 0.9.0"}
  ],
  system_env: %{"XLA_TARGET" => xla_target}
)


Nx.default_backend(EXLA.Backend)

Here is an opinionated module presenting a simple API for loading in an ONNX file and saving the converted Axon model in the provided directory. This API allows us to save multiple models pretty quickly:

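defmodule OnnxToAxon do
  # Imports the ONNX file and writes the serialized Axon model to
  # `path_to_axon_dir`, named after the ONNX file with an .axon extension.
  def onnx_axon(path_to_onnx_file, path_to_axon_dir) do
    axon_name = axon_name_from_onnx_path(path_to_onnx_file)
    path_to_axon = Path.join(path_to_axon_dir, axon_name)

    # AxonOnnx.import/1 returns the Axon model together with its parameters
    {model, parameters} = AxonOnnx.import(path_to_onnx_file)
    model_bytes = Axon.serialize(model, parameters)
    File.write!(path_to_axon, model_bytes)
  end

  # "/some/dir/cats_v_dogs.onnx" becomes "cats_v_dogs.axon"
  defp axon_name_from_onnx_path(onnx_path) do
    model_root = onnx_path |> Path.basename() |> Path.rootname()
    "#{model_root}.axon"
  end
end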
{:module, OnnxToAxon, <<70, 79, 82, 49, 0, 0, 9, ...>>, {:axon_name_from_onnx_path, 1}}
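
Since the module is meant for saving several models quickly, a batch run over a directory might look like the sketch below. `ConversionPlan` and its `build/2` function are hypothetical names introduced here for illustration; they only pair each `.onnx` file with the `.axon` path that `OnnxToAxon.onnx_axon/2` would write, using the same naming rule as above.

```elixir
defmodule ConversionPlan do
  # Hypothetical helper, not part of the original notebook.
  # Returns a sorted list of {onnx_path, axon_path} tuples for every
  # .onnx file found directly inside `onnx_dir`.
  def build(onnx_dir, axon_dir) do
    onnx_dir
    |> Path.join("*.onnx")
    |> Path.wildcard()
    |> Enum.sort()
    |> Enum.map(fn onnx_path ->
      axon_name = (onnx_path |> Path.basename() |> Path.rootname()) <> ".axon"
      {onnx_path, Path.join(axon_dir, axon_name)}
    end)
  end
end

# Possible usage with the module defined above (left commented out because
# it needs real ONNX files on disk):
#
# for {onnx_path, _axon_path} <- ConversionPlan.build(onnx_dir, axon_dir) do
#   OnnxToAxon.onnx_axon(onnx_path, axon_dir)
# end
```

Building the list first makes it easy to inspect which files would be converted before any import work starts.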

ONNX Model

path_to_onnx_file = "/Users/charlie/ML/models/onnx/cats_v_dogs.onnx"
"/Users/charlie/ML/models/onnx/cats_v_dogs.onnx"
path_to_axon_dir = "/Users/charlie/ML/models/axon"
"/Users/charlie/ML/models/axon"
# Convert an ONNX model into Axon
OnnxToAxon.onnx_axon(path_to_onnx_file, path_to_axon_dir)

16:10:44.144 [warn] Attempting to serialize an Axon model. Serialization is discouraged and will be deprecated, then removed in future releases. You should keep your model definitions as code and serialize your parameters using `Nx.serialize/2`.
:ok
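
The warning above recommends serializing only the parameters with `Nx.serialize/2` rather than the whole model. A minimal sketch of that approach follows, assuming Nx is already installed by the notebook's setup cell. `ParamStore` is a hypothetical module name, not something the original notebook defines.

```elixir
defmodule ParamStore do
  # Hypothetical helper, not part of the original notebook.

  # Writes a container of tensors (for example, the parameters returned
  # by AxonOnnx.import/1) to `path`.
  def save!(params, path) do
    File.write!(path, Nx.serialize(params))
  end

  # Reads the file back and restores the container of tensors.
  def load!(path) do
    path |> File.read!() |> Nx.deserialize()
  end
end
```

With this approach only the parameters live on disk, so the model structure has to be rebuilt separately, for example by importing the ONNX file again or by defining the model in code.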

Inference On ONNX Derived Models

Load the Axon model

cats_v_dogs = File.read!("#{path_to_axon_dir}/cats_v_dogs.axon")
{cats_v_dogs_model, cats_v_dogs_params} = Axon.deserialize(cats_v_dogs)

16:12:07.574 [warn] Attempting to deserialize a serialized Axon model. Deserialization is discouraged and will be deprecated, then removed in future releases. You should keep your model definitions as code and serialize your parameters using `Nx.serialize/2`.
{#Axon<
   inputs: %{"input" => {10, 3, 224, 224}}
   outputs: "output"
   nodes: 65
 >,
 %{
   "input.104" => %{
     "bias" => #Nx.Tensor<
       f32[256]
       [-0.10255216807126999, 0.01242440938949585, -0.31021589040756226, 0.009901568293571472, 0.3029698431491852, 0.20485056936740875, -0.06305588781833649, 0.5202410221099854, 0.10102447867393494, 0.21567291021347046, -0.06667517125606537, 0.06317032873630524, 0.07280120253562927, 0.19022129476070404, 0.1731688678264618, -0.035899221897125244, 0.23101085424423218, 0.05788235366344452, 0.055577352643013, 0.26126232743263245, 0.0849180519580841, 0.23112308979034424, 0.02616637945175171, 0.005282506346702576, 0.07253815233707428, -0.020964205265045166, 2.6334822177886963e-4, -0.0638253390789032, 0.32014960050582886, 0.4504479467868805, 0.006706371903419495, 0.07622749358415604, 0.2972850799560547, -0.11482954770326614, 0.06434966623783112, 0.27411627769470215, 0.19332990050315857, 0.2405610978603363, -0.15616801381111145, 0.188066303730011, 0.6546944975852966, 0.12711232900619507, -0.01894807070493698, 0.06539285182952881, -0.6546027660369873, 0.37661534547805786, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[256][256][3][3]
       [
         [
           [
             [0.02980261668562889, 0.03022027388215065, 0.024108879268169403],
             [0.030693294480443, 0.03317975625395775, 0.03376408666372299],
             [0.01578144170343876, 0.008534814231097698, 0.01516900397837162]
           ],
           [
             [-0.0013935305178165436, -0.0011014307383447886, 0.0036392314359545708],
             [3.8647252949886024e-5, 0.0020625314209610224, 0.01162804290652275],
             [-0.011805879883468151, -0.012488552369177341, -0.0015498219290748239]
           ],
           [
             [-0.008269588463008404, -0.013258696533739567, -0.014901615679264069],
             [-0.00891254935413599, -0.014860459603369236, -0.015747573226690292],
             [0.0054487125016748905, 0.003112942446023226, -0.004765582270920277]
           ],
           [
             [-0.010309847071766853, -0.011159747838973999, -0.017865899950265884],
             [-0.002579166553914547, -0.002923264866694808, -0.01028983574360609],
             [-0.0016319465357810259, -5.981922731734812e-4, 8.313407306559384e-4]
           ],
           [
             [-0.006339760031551123, 0.0013753932435065508, 0.004977775737643242],
             [2.6102454285137355e-4, 9.971223771572113e-4, 0.0019811317324638367],
             [0.0030383302364498377, 0.005810011178255081, -0.0023707642685621977]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.120" => %{
     "bias" => #Nx.Tensor<
       f32[512]
       [-0.11012329161167145, -0.1334468424320221, -0.13490815460681915, 0.14369918406009674, 0.12017859518527985, 0.005583226680755615, 0.1439262330532074, -0.17790907621383667, 0.2666502892971039, 0.36823439598083496, -0.039247065782547, 0.12245479226112366, 0.10533519089221954, -0.13710498809814453, 0.10539458692073822, 0.12219296395778656, 0.041899293661117554, 0.062102049589157104, -0.08650007098913193, 0.0017465651035308838, 0.10680529475212097, 0.057311177253723145, 0.4606086015701294, 0.09543457627296448, 0.03240486979484558, 0.06992611289024353, -0.26842135190963745, 0.12656061351299286, 0.06620584428310394, 0.10129502415657043, 0.10617434978485107, 0.17909353971481323, 0.08288732171058655, 0.002831697463989258, -0.11379002034664154, 0.10889385640621185, 0.4372505247592926, -0.17590320110321045, 0.23379480838775635, -0.00852176547050476, 0.22691969573497772, 0.4407877027988434, 0.1594594419002533, 0.23647919297218323, 0.011865854263305664, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[512][256][3][3]
       [
         [
           [
             [-0.008114181458950043, -0.013126172125339508, -0.015620616264641285],
             [0.014821208082139492, 0.017004260793328285, 0.019961778074502945],
             [0.03015279397368431, 0.023433180525898933, 0.034742772579193115]
           ],
           [
             [-0.00544447498396039, 0.0025770326610654593, -0.0035608436446636915],
             [0.004550183191895485, 0.010284008458256721, -0.0020750206895172596],
             [-0.003910803701728582, 0.009475957602262497, 0.001209274516440928]
           ],
           [
             [0.012183013372123241, -0.007880083285272121, -0.010954849421977997],
             [0.005971798673272133, -7.429560646414757e-4, -0.00873740203678608],
             [-0.00375011982396245, -0.019148051738739014, -0.006404926534742117]
           ],
           [
             [-0.009763436391949654, -0.006884351838380098, -0.011995279230177402],
             [0.013272769749164581, 0.012995783239603043, -5.930152256041765e-4],
             [0.02203487791121006, 0.0375659316778183, 0.020397966727614403]
           ],
           [
             [0.009320005774497986, 0.00815019104629755, 0.006229656282812357],
             [0.007680999580770731, 0.0043557011522352695, -0.006194974761456251],
             [0.01412269938737154, 0.0018995723221451044, ...]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.140" => %{
     "bias" => #Nx.Tensor<
       f32[512]
       [0.2701534628868103, 0.3412540853023529, 0.059474095702171326, 0.19265660643577576, 0.34780120849609375, 0.15283502638339996, 0.18487772345542908, 0.23036494851112366, 0.082991823554039, 0.21394386887550354, 0.14001530408859253, 0.1661788374185562, 0.19085638225078583, 0.18932723999023438, 0.029606759548187256, 0.08522605895996094, 0.18700945377349854, 0.1641554832458496, 0.24583907425403595, 0.10366123914718628, 0.1369037628173828, 0.032656505703926086, 0.19965916872024536, 0.18968966603279114, 0.5044026970863342, 0.2143799215555191, 0.1118353009223938, 0.16536200046539307, 0.1161937415599823, 0.16214141249656677, 0.196969673037529, 0.29720214009284973, 0.3085396885871887, 0.2721194922924042, 0.23629915714263916, 0.2515372931957245, 0.21183982491493225, 0.23609426617622375, 0.21477048099040985, 0.13355383276939392, 0.05379834771156311, 0.2136279046535492, 0.22155271470546722, 0.14189766347408295, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[512][512][3][3]
       [
         [
           [
             [-0.004852789454162121, -0.0030833950731903315, 0.00693540507927537],
             [0.004548283759504557, -0.004725727252662182, 0.010022134520113468],
             [0.010843261145055294, 0.010541762225329876, 0.017730120569467545]
           ],
           [
             [-0.0014832474989816546, 0.008276263251900673, 0.00680690910667181],
             [-0.014394419267773628, -0.012837381102144718, -0.009244682267308235],
             [-0.008021551184356213, -0.0024558810982853174, -0.0027384727727621794]
           ],
           [
             [-0.008315932005643845, -0.009325191378593445, 9.747524745762348e-4],
             [-0.01343428622931242, -0.02077828161418438, -0.009686430916190147],
             [-0.010126739740371704, -0.020081182941794395, -0.017858095467090607]
           ],
           [
             [0.004692510701715946, -0.005120026413351297, 0.004012784920632839],
             [-0.0082739582285285, -0.03099539503455162, -0.002721106866374612],
             [0.003731094067916274, -0.012846999801695347, 0.001240275101736188]
           ],
           [
             [0.0012077062856405973, 0.007579203229397535, 0.00910201109945774],
             [-4.6364564332179725e-4, -0.005896805785596371, 0.006524364463984966],
             [-0.014250709675252438, ...]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.16" => %{
     "bias" => #Nx.Tensor<
       f32[64]
       [0.39058536291122437, 0.3297346830368042, 0.16846027970314026, 0.39849063754081726, 0.7723880410194397, 0.46614134311676025, 0.3427618145942688, 0.141500324010849, 0.5892147421836853, 0.311024934053421, 0.09674270451068878, 0.9240381717681885, 0.2005918323993683, 0.3342496156692505, 0.3639898896217346, 0.3643597662448883, 0.6803058981895447, 0.01658894121646881, 0.8311477899551392, 0.3224527835845947, 0.4626739025115967, 0.8104490041732788, 0.875454843044281, 0.38464146852493286, 0.10243505239486694, 0.1251845806837082, -0.08998915553092957, 0.8140488862991333, 0.010452806949615479, -0.2336074709892273, -0.07846075296401978, 0.9811797142028809, 0.5440393090248108, 1.2300128936767578, 0.9277883172035217, 0.6543285250663757, 0.4902767539024353, 0.10552435368299484, 1.2785937786102295, 0.92914217710495, 0.2011302411556244, 0.6319721341133118, -0.18300236761569977, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[64][64][3][3]
       [
         [
           [
             [0.028264105319976807, -0.04678752273321152, -0.01000113133341074],
             [-0.03653400018811226, -0.39260753989219666, -0.1045677587389946],
             [0.03224056214094162, -0.04744917154312134, -0.005983613431453705]
           ],
           [
             [-0.00354730780236423, 0.006985054817050695, 2.3401706130243838e-4],
             [0.0201820507645607, -0.07926938682794571, -0.011448322795331478],
             [0.0015579529572278261, 0.0034659868106245995, 0.03514765202999115]
           ],
           [
             [-1.1605503225098346e-9, -1.9288943775563894e-8, -1.6195008001318456e-8],
             [1.0677024597782747e-8, 4.091558647445481e-9, 6.160818077916019e-9],
             [5.5904547657803505e-9, 4.327163516393284e-9, 7.616167252422201e-9]
           ],
           [
             [5.147525225766003e-4, -0.002576477825641632, 0.0033656528685241938],
             [-0.010653612203896046, -0.03368113562464714, -0.007842055521905422],
             [0.012534352950751781, -0.007321042474359274, -0.0027268771082162857]
           ],
           [
             [-2.4242039486921385e-9, -2.617766670098831e-9, 1.4169434514599288e-9],
             [-1.4787721047682112e-9, -7.28144822215171e-10, -2.442494650978233e-9],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.32" => %{
     "bias" => #Nx.Tensor<
       f32[64]
       [0.20088541507720947, -0.9981099367141724, 0.5812627673149109, 0.12833914160728455, 0.28994590044021606, 0.8534805774688721, 0.19867593050003052, 0.021073587238788605, 0.25767847895622253, 0.3452529013156891, 0.3978758156299591, 0.223368838429451, 0.8649293780326843, 0.3268284797668457, -0.11873866617679596, -0.11314578354358673, 0.2288074791431427, 0.07647715508937836, 0.7165235280990601, 0.217575341463089, -0.23578426241874695, 0.8018866777420044, 0.7439314723014832, -0.05131138861179352, -0.18318139016628265, 0.7103950381278992, 0.8330986499786377, 0.46872079372406006, -0.4479110836982727, 0.04658251255750656, -0.4482766389846802, 1.1971343755722046, 0.46852537989616394, 0.0016401708126068115, 0.8745933175086975, 1.214421272277832, -0.013733558356761932, -0.08480067551136017, 0.34241291880607605, 0.6692671775817871, -0.8205735683441162, 0.12076211720705032, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[64][64][3][3]
       [
         [
           [
             [0.009143853560090065, -0.002450926462188363, -0.0017525143921375275],
             [-0.00913943350315094, -0.005763588473200798, -0.016344279050827026],
             [0.023512117564678192, 0.035115715116262436, 0.020116902887821198]
           ],
           [
             [0.006506951991468668, -0.004044913221150637, -0.004904876463115215],
             [-0.01975400373339653, -0.011037465184926987, -0.025265291333198547],
             [-7.050625281408429e-4, 0.021680807694792747, 0.02368488349020481]
           ],
           [
             [0.010004709474742413, 0.0018492299132049084, -0.004645614419132471],
             [-0.004279324784874916, -0.011845993809401989, -0.01324962917715311],
             [-0.01397590059787035, -0.02241162583231926, -0.012133411131799221]
           ],
           [
             [-0.002640221733599901, -0.021292105317115784, 0.03215683996677399],
             [0.015442206524312496, -0.006780137773603201, 0.02653222531080246],
             [0.01498928852379322, 0.015813639387488365, -0.027547387406229973]
           ],
           [
             [0.029619459062814713, 0.036391064524650574, 0.03163759782910347],
             [0.004423873033374548, 0.007228051777929068, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.4" => %{
     "bias" => #Nx.Tensor<
       f32[64]
       [0.2317274659872055, 0.2523372173309326, -1.05430115127092e-6, -0.666293203830719, -1.657055648252026e-8, 0.15954801440238953, 0.46721315383911133, -4.3019585405090766e-7, 0.3005814552307129, -8.004371920833364e-6, 0.3527475595474243, 0.31285855174064636, -0.2601037919521332, -3.478067446849309e-5, 0.10771488398313522, 0.22187425196170807, 0.38182276487350464, -0.5319052338600159, -0.6140309572219849, 0.5665425062179565, 0.2980884313583374, 0.5875009298324585, 0.47954607009887695, 0.32620003819465637, 0.19730432331562042, 0.1967487633228302, 0.15240077674388885, 0.10463902354240417, 0.49047064781188965, 0.009381033480167389, 0.16744095087051392, 0.33219435811042786, 0.26775941252708435, 0.45020750164985657, -0.2835230827331543, -0.03874418884515762, -2.450687759392167e-7, 0.32660406827926636, -4.9151783088063894e-8, 0.23783636093139648, 0.2338869571685791, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[64][3][7][7]
       [
         [
           [
             [-0.0028019791934639215, -0.0016413616249337792, -5.010636523365974e-4, 0.01972866989672184, 0.014920719899237156, 0.0044916123151779175, -0.0033793486654758453],
             [0.002864276757463813, 0.0024858538527041674, -0.02905162051320076, -0.07409906387329102, -0.07165075093507767, -0.034102387726306915, 9.688315913081169e-4],
             [-0.0019021540647372603, 0.015556327998638153, 0.07798361778259277, 0.15499864518642426, 0.13717548549175262, 0.06764435023069382, 0.016761379316449165],
             [0.007990336045622826, -0.017731331288814545, -0.07881280779838562, -0.11585383862257004, -0.07155387848615646, -1.9549347052816302e-4, 0.01518225483596325],
             [-0.007339056581258774, 0.004195648245513439, 0.019127463921904564, -0.014328746125102043, -0.0879356861114502, -0.11108877509832382, -0.0680975466966629],
             [0.008010330609977245, 0.010762837715446949, 0.01653868891298771, 0.0630345270037651, 0.10920489579439163, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.48" => %{
     "bias" => #Nx.Tensor<
       f32[128]
       [-0.12085895240306854, -0.3294113278388977, 0.019299542531371117, 0.018438413739204407, 0.12564125657081604, 0.23528410494327545, 0.030958974733948708, -0.04151897877454758, 0.058561503887176514, -0.1585184931755066, -0.19115597009658813, 0.20707836747169495, 0.0061426833271980286, 0.26531243324279785, 0.09738633036613464, -0.04977022111415863, -0.14273084700107574, 0.7192244529724121, -0.19620242714881897, 0.20352962613105774, 0.1065758615732193, 0.055553629994392395, -0.29206347465515137, 0.18743343651294708, -0.0920708179473877, 0.15387415885925293, -0.004067450761795044, -0.020898494869470596, 0.03300565481185913, -0.13152343034744263, -0.10637379437685013, -0.16178111732006073, 0.5594324469566345, -0.13452766835689545, -0.1640821099281311, 0.12561795115470886, -0.369643896818161, -0.04847084358334541, 0.16257557272911072, 0.16360332071781158, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[128][64][3][3]
       [
         [
           [
             [-0.028565170243382454, -0.04407990723848343, -0.05485813319683075],
             [0.02818482369184494, -0.005938543472439051, -0.04022299125790596],
             [0.04765123873949051, 0.034852493554353714, -0.003328080056235194]
           ],
           [
             [-0.009580712765455246, -0.0025346672628074884, 9.486497147008777e-4],
             [0.002450849860906601, 0.007511203642934561, 0.010084796696901321],
             [9.804379660636187e-4, -0.0015770676545798779, -0.004777619615197182]
           ],
           [
             [0.0024091172963380814, 0.0026440375950187445, -0.004750310443341732],
             [0.0031505757942795753, -0.005319045856595039, -0.016934450715780258],
             [0.006929904222488403, -0.005520435515791178, -0.01874990016222]
           ],
           [
             [-0.015278496779501438, -0.014780244790017605, 0.00415421137586236],
             [-0.006734006572514772, -0.019458729773759842, -0.0037485812790691853],
             [0.012252654880285263, 0.002779492875561118, 2.4828806635923684e-4]
           ],
           [
             [0.021199271082878113, 0.021622469648718834, 0.011315584182739258],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.68" => %{
     "bias" => #Nx.Tensor<
       f32[128]
       [0.05796706676483154, 0.1304682493209839, -0.48415225744247437, 0.2790919542312622, 0.32541871070861816, 0.23160740733146667, 0.128997340798378, 0.29403066635131836, -0.04000064730644226, 0.03472676873207092, 0.08677017688751221, -0.029069632291793823, -0.029288090765476227, 0.10490615665912628, 0.03905832767486572, 0.29994964599609375, -0.06109467148780823, 0.05183476209640503, -0.29549524188041687, -0.022138699889183044, -0.06486694514751434, 0.24632392823696136, 0.03461739420890808, -0.07615165412425995, 0.39280182123184204, -0.12164877355098724, 0.143124058842659, -0.1789160817861557, 0.12068241834640503, 0.01911863684654236, 0.16610179841518402, 0.3491804003715515, -0.3434971570968628, -0.11597320437431335, -0.1324601173400879, 0.14369553327560425, 0.28282684087753296, -0.031145095825195312, 0.007936030626296997, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[128][128][3][3]
       [
         [
           [
             [-5.262177437543869e-4, -0.00489784637466073, -0.005039934068918228],
             [0.0160769522190094, 0.001257465686649084, 0.0037846278864890337],
             [0.006471970118582249, -0.010600507259368896, 0.002448693849146366]
           ],
           [
             [-0.013703188858926296, -7.995142368599772e-4, -0.009209748357534409],
             [-0.010768691077828407, 0.006196790840476751, 0.004454707261174917],
             [0.0072552855126559734, -0.004512551240622997, -0.0072962394915521145]
           ],
           [
             [-6.745486753061414e-4, 0.019889498129487038, 0.014622725546360016],
             [-0.03940449655056, -0.015210512094199657, 0.0015319481026381254],
             [0.026304522529244423, -0.06687644124031067, -0.03881275653839111]
           ],
           [
             [0.0030681914649903774, 0.006286886055022478, 0.008150829933583736],
             [0.006075439043343067, 0.0066594285890460014, 0.012328804470598698],
             [-0.002715378999710083, -0.001705002156086266, 0.006692140828818083]
           ],
           [
             [-0.0031075719743967056, -0.02324421890079975, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "input.84" => %{
     "bias" => #Nx.Tensor<
       f32[256]
       [-8.177012205123901e-4, 0.0967610627412796, 0.1719234585762024, 0.03387782350182533, 0.24569299817085266, -0.1446833312511444, -0.11149755120277405, -0.11916022002696991, -0.14267459511756897, 0.058398544788360596, 0.2738577723503113, 0.20987257361412048, 0.5974380373954773, 0.2077210247516632, 0.5942175388336182, 0.21452926099300385, -0.3098541498184204, 0.3629045784473419, -0.28036072850227356, 0.10152695327997208, 0.08043119311332703, 0.36084413528442383, -0.17492465674877167, -0.07436101138591766, 0.2582421898841858, 0.4460412561893463, 0.20715858042240143, -0.46175655722618103, 0.04500553756952286, -0.20734336972236633, 0.3693496584892273, 0.22970852255821228, -0.36248117685317993, 0.026639118790626526, -0.044602327048778534, 0.29927152395248413, 0.2084745168685913, -0.05406811833381653, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[256][128][3][3]
       [
         [
           [
             [-0.009433602914214134, -0.009474102407693863, -0.009162158705294132],
             [-0.0030906018801033497, 0.0092063769698143, 0.005929871927946806],
             [-0.008593972772359848, -9.349756146548316e-5, -0.006583182141184807]
           ],
           [
             [-0.005912473425269127, -0.010336859151721, -0.005291664972901344],
             [0.006110117305070162, -0.005567251704633236, -0.003635238856077194],
             [0.0036871249321848154, -0.007904792204499245, -0.004320142325013876]
           ],
           [
             [-0.012245627120137215, -0.020721949636936188, -0.01886877417564392],
             [-0.01609964855015278, -0.03088284842669964, -0.021763723343610764],
             [-0.018368728458881378, -0.02057608775794506, -0.009225909598171711]
           ],
           [
             [-0.009139428846538067, 3.3259994233958423e-4, -0.003555502276867628],
             [-0.015805216506123543, -0.005032840650528669, -0.011540019884705544],
             [-0.006358750630170107, -0.012539410032331944, -0.007548004854470491]
           ],
           [
             [-0.01016774121671915, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_223" => %{
     "bias" => #Nx.Tensor<
       f32[64]
       [0.11700877547264099, -0.0016212984919548035, -1.2389702796936035, -0.3228069841861725, 0.0733313262462616, -0.13630887866020203, 0.02991156093776226, 0.01335197128355503, 0.9604557752609253, -0.056087426841259, 0.06572458148002625, 0.4833105206489563, 0.09255271404981613, -0.16148844361305237, -0.0062994882464408875, 0.13380266726016998, -0.0470370277762413, -0.15777486562728882, -0.27034610509872437, -0.3997953534126282, -0.5465368032455444, 0.10369472205638885, -0.3253932297229767, 0.12614542245864868, -0.03524066507816315, 0.6200454235076904, -0.03369453176856041, -0.11300484836101532, 0.2811250388622284, 0.2191983163356781, 0.483904093503952, 0.2681454122066498, -0.015199542045593262, 0.33972620964050293, 0.11541686952114105, -0.046400900930166245, -0.8086198568344116, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[64][64][3][3]
       [
         [
           [
             [0.01816454902291298, -0.07294163852930069, -0.0033414787612855434],
             [-0.05997346341609955, -0.23021353781223297, -0.0716257244348526],
             [-0.039963651448488235, -0.13307137787342072, -0.03819657117128372]
           ],
           [
             [-0.011564474552869797, 0.015219778753817081, -0.0013934375019744039],
             [-0.0020580885466188192, 0.031693849712610245, -0.00786368828266859],
             [-0.041423771530389786, 0.004704250954091549, 0.0021779616363346577]
           ],
           [
             [-0.00327970995567739, 0.013627355918288231, 0.009181257337331772],
             [0.009025373496115208, -0.005510837305337191, 0.008267567493021488],
             [0.009360985830426216, 0.00772381154820323, 0.00387995271012187]
           ],
           [
             [0.02269722707569599, -0.046998053789138794, -0.03570636734366417],
             [0.009844856336712837, -0.11309101432561874, -0.05261484533548355],
             [-0.0029174077790230513, -0.07110293209552765, -0.0348593071103096]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_229" => %{
     "bias" => #Nx.Tensor<
       f32[64]
       [-0.10703185945749283, 0.22948969900608063, -1.1757307052612305, 0.004097491502761841, -0.09718765318393707, -0.0441732257604599, -0.31656980514526367, -0.24595341086387634, 0.6863982081413269, 0.38147178292274475, -0.23885448276996613, 0.006021881476044655, -0.14258751273155212, 1.1101441383361816, -0.23581522703170776, 0.07058548182249069, 0.09472654014825821, 0.24212470650672913, -0.23426881432533264, -0.012344378978013992, -0.03343100845813751, -0.147905632853508, -0.04462171345949173, -0.061329133808612823, -0.23938453197479248, 0.10398685932159424, 0.07756980508565903, 0.20166830718517303, 1.0389947891235352, 0.0424036830663681, 0.09520983695983887, 0.19708308577537537, 0.7442616820335388, -0.04239851236343384, 0.6960866451263428, -0.1580812931060791, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[64][64][3][3]
       [
         [
           [
             [-0.02467520907521248, -0.005265281535685062, 0.0052703446708619595],
             [-0.009432479739189148, 0.04777374863624573, 0.026454273611307144],
             [-0.010470099747180939, 0.06531570106744766, 0.0340411551296711]
           ],
           [
             [0.06706727296113968, 0.0487847775220871, 0.05116129666566849],
             [0.025383656844496727, -0.014870687387883663, 0.008729268796741962],
             [0.051719244569540024, 0.034963689744472504, 0.043148841708898544]
           ],
           [
             [-0.017472876235842705, -0.044356465339660645, -0.05270271748304367],
             [-0.02634846419095993, 0.03228754922747612, 0.005401493515819311],
             [-0.023021327331662178, 0.015869393944740295, 0.02977365255355835]
           ],
           [
             [0.0015554878627881408, -0.09528835117816925, -0.05262262001633644],
             [-0.05457450821995735, -0.17271561920642853, -0.12160201370716095],
             [7.84864358138293e-4, -0.08539796620607376, ...]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_235" => %{
     "bias" => #Nx.Tensor<
       f32[128]
       [0.3241359293460846, 0.27947962284088135, 0.6030840277671814, 0.009196937084197998, 0.3452882766723633, 0.10920313000679016, 0.9687239527702332, 0.22556713223457336, 0.4261491894721985, 0.11880181729793549, 0.03846334293484688, 0.16517306864261627, -0.07173064351081848, 0.1811138093471527, 0.1614181399345398, 0.4611174464225769, 0.5462740659713745, 0.11766283214092255, -0.5408067107200623, -0.04210607334971428, 0.3623342514038086, 0.127481147646904, 0.31506994366645813, 0.046513281762599945, 0.043591104447841644, 0.3112694025039673, 0.2124355584383011, -0.03628380596637726, -0.06706595420837402, 0.4333105981349945, 0.3378453254699707, 0.21614670753479004, 0.3788752257823944, 0.13142448663711548, 0.4652714431285858, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[128][128][3][3]
       [
         [
           [
             [-0.004815979395061731, -0.006498327478766441, 0.0018362372647970915],
             [-0.006985980551689863, 0.017192905768752098, 0.030052555724978447],
             [-0.017971111461520195, 0.0035869635175913572, 0.008648558519780636]
           ],
           [
             [0.02352527529001236, 0.01662944257259369, 0.005346672143787146],
             [0.002713234396651387, 0.013244196772575378, 0.01815827563405037],
             [0.02344762161374092, 0.023479484021663666, 0.008839689195156097]
           ],
           [
             [0.012555832974612713, -0.014186849817633629, -0.02388766035437584],
             [-0.013072578236460686, -0.012629187665879726, -0.039460763335227966],
             [-0.014414372853934765, -0.005027364939451218, 0.002633575815707445]
           ],
           [
             [0.004963915329426527, 0.0029304982163012028, 0.009135729633271694],
             [-0.01190568134188652, 0.005517667159438133, 0.003892903681844473],
             [-0.006154944654554129, ...]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_238" => %{
     "bias" => #Nx.Tensor<
       f32[128]
       [0.18165035545825958, -6.747543811798096e-4, 0.1338605284690857, -0.21819636225700378, -0.011696472764015198, 0.06435933709144592, -0.28644225001335144, 0.1931188851594925, -0.08892102539539337, 0.16471855342388153, 0.2442406564950943, 0.2823583483695984, 0.33865126967430115, 0.12257973849773407, -0.0016330033540725708, 0.19978636503219604, 0.09226927161216736, -0.19506211578845978, -0.08055393397808075, 0.00711375568062067, 0.1217801421880722, -0.020144857466220856, 0.6818500757217407, 0.02597694844007492, 0.03954639285802841, -0.03921331465244293, 0.23783588409423828, 0.190708726644516, -0.151023730635643, 0.8105874061584473, 0.07611434161663055, 0.08958728611469269, 0.9507854580879211, 0.22652959823608398, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[128][64][1][1]
       [
         [
           [
             [0.012082752771675587]
           ],
           [
             [-0.23409134149551392]
           ],
           [
             [0.009416989982128143]
           ],
           [
             [0.006355063058435917]
           ],
           [
             [-0.021120594814419746]
           ],
           [
             [0.02530841715633869]
           ],
           [
             [0.11245305091142654]
           ],
           [
             [0.0010163293918594718]
           ],
           [
             [-0.011756040155887604]
           ],
           [
             [0.08959800750017166]
           ],
           [
             [-0.0038925453554838896]
           ],
           [
             [0.00902885477989912]
           ],
           [
             [0.014491967856884003]
           ],
           [
             [0.026994410902261734]
           ],
           [
             [0.04159475862979889]
           ],
           [
             [0.006760460324585438]
           ],
           [
             [0.10022340714931488]
           ],
           [
             [0.014805061742663383]
           ],
           [
             [-0.11315683275461197]
           ],
           [
             [-0.03744495287537575]
           ],
           [
             [0.005226226057857275]
           ],
           [
             [0.05470598489046097]
           ],
           [
             [0.0027231054846197367]
           ],
           [
             [0.12101384252309799]
           ],
           [
             [-0.003309142543002963]
           ],
           [
             [0.012003741227090359]
           ],
           [
             [0.011226937174797058]
           ],
           [
             [-0.22858673334121704]
           ],
           [
             [0.052908755838871]
           ],
           [
             [0.021630311384797096]
           ],
           [
             [0.014828636310994625]
           ],
           [
             [0.19380220770835876]
           ],
           [
             [-0.07843420654535294]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_244" => %{
     "bias" => #Nx.Tensor<
       f32[128]
       [-0.11422383040189743, -0.13024669885635376, -0.4605652987957001, -0.0677630752325058, 0.19903643429279327, -0.16608868539333344, 0.03557538986206055, 0.007463514804840088, 0.15729649364948273, -0.06694583594799042, -0.16176018118858337, -0.1355186551809311, -0.06562940031290054, 0.23673200607299805, -0.20291008055210114, -0.21557679772377014, -0.14289376139640808, -0.06948217749595642, -0.12922103703022003, 0.10408642888069153, -0.08025123178958893, -0.09908505529165268, -0.4549119472503662, 0.04387699067592621, -0.4576791524887085, 0.08881376683712006, -0.1485273689031601, 0.04479994624853134, -0.1848645955324173, -0.24494987726211548, 0.08287344872951508, -1.187548041343689e-4, -0.14890624582767487, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[128][128][3][3]
       [
         [
           [
             [-0.013609498739242554, 0.004379323683679104, -7.50118400901556e-4],
             [-0.007403417956084013, -0.016361938789486885, -0.020524226129055023],
             [0.005440857261419296, 0.008654438890516758, -0.011772758327424526]
           ],
           [
             [-0.00934152863919735, 0.0022971653379499912, 0.02006314881145954],
             [-0.01485093217343092, 0.016935527324676514, 0.05441441759467125],
             [0.005219102371484041, 0.06473482400178909, 0.08654006570577621]
           ],
           [
             [-0.006830218713730574, 0.005938380025327206, 0.021628646180033684],
             [-0.006811949890106916, 0.008022354915738106, 0.007146121002733707],
             [0.006655591540038586, 0.010422145947813988, 0.006165296770632267]
           ],
           [
             [0.0025338211562484503, -0.008463718928396702, 0.009631446562707424],
             [0.014575249515473843, -0.0034092748537659645, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_250" => %{
     "bias" => #Nx.Tensor<
       f32[256]
       [0.1761150062084198, 0.39926907420158386, 0.028157640248537064, -0.12261997163295746, 0.3145957589149475, 0.12004008889198303, 0.14092761278152466, 0.19924694299697876, 0.06612507253885269, 0.3014739453792572, 0.019873440265655518, 0.1952602118253708, 0.07448466122150421, 0.2927410900592804, 0.13069085776805878, -0.01366419717669487, 0.26465392112731934, -0.06156037747859955, 0.23040230572223663, 0.19039861857891083, 0.18763481080532074, 0.19512218236923218, 0.21992316842079163, 0.0853813886642456, 0.27244725823402405, -0.020649690181016922, 0.08896717429161072, 0.36672279238700867, 0.474377304315567, 0.025128506124019623, 0.2171832174062729, 0.19263955950737, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[256][256][3][3]
       [
         [
           [
             [-0.007905409671366215, -0.03158345818519592, -0.012421931140124798],
             [-0.02217613346874714, -0.07352714240550995, -0.04610810801386833],
             [-0.03931130841374397, -0.07363095134496689, -0.05946849286556244]
           ],
           [
             [-0.02569691836833954, 2.937749377451837e-4, -0.025232335552573204],
             [0.009392653591930866, 0.010325574316084385, -0.015850285068154335],
             [-0.008665809407830238, 0.003035004250705242, -0.010638286359608173]
           ],
           [
             [-0.013904825784265995, 0.02131189964711666, 0.0030691048596054316],
             [0.013467090204358101, 0.011554697528481483, -0.003430237527936697],
             [-0.01924220100045204, -0.014374679885804653, -0.014827674254775047]
           ],
           [
             [0.038365066051483154, 0.015062511898577213, -3.074728374485858e-5],
             [0.03567345440387726, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_253" => %{
     "bias" => #Nx.Tensor<
       f32[256]
       [0.03746028244495392, 0.15253399312496185, 0.01860833540558815, 0.14026623964309692, 0.040661536157131195, 0.05014963075518608, 0.013623617589473724, 0.19425176084041595, 0.026282358914613724, 0.02571592852473259, 0.058496396988630295, 0.006650492548942566, -0.0072492435574531555, 0.042366430163383484, -0.03262672573328018, 0.02036023512482643, 0.1306997835636139, 0.019062265753746033, -0.0644722655415535, 0.11889855563640594, 0.12074162065982819, -0.07424326986074448, 0.11369872093200684, -0.0034573227167129517, -0.018708739429712296, 0.09986792504787445, -0.1100538820028305, -0.11751233041286469, 0.018458224833011627, 0.03849996626377106, 0.013222794979810715, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[256][128][1][1]
       [
         [
           [
             [0.005829719826579094]
           ],
           [
             [-0.01145381573587656]
           ],
           [
             [-0.010477710515260696]
           ],
           [
             [0.0068201711401343346]
           ],
           [
             [-0.024594461545348167]
           ],
           [
             [-0.021350832656025887]
           ],
           [
             [-0.00776443537324667]
           ],
           [
             [0.019060691818594933]
           ],
           [
             [-0.03842966631054878]
           ],
           [
             [0.034842606633901596]
           ],
           [
             [-0.0059853363782167435]
           ],
           [
             [-0.0015779553214088082]
           ],
           [
             [-0.012468256987631321]
           ],
           [
             [-0.0011315977899357677]
           ],
           [
             [5.05295698530972e-4]
           ],
           [
             [-0.018445048481225967]
           ],
           [
             [0.006140608806163073]
           ],
           [
             [0.007622201461344957]
           ],
           [
             [-0.02006443776190281]
           ],
           [
             [-0.005497447215020657]
           ],
           [
             [0.00619297381490469]
           ],
           [
             [-8.262437768280506e-4]
           ],
           [
             [-0.040694985538721085]
           ],
           [
             [0.007896581664681435]
           ],
           [
             [-0.014212172478437424]
           ],
           [
             [-0.02121609076857567]
           ],
           [
             [-0.014214497990906239]
           ],
           [
             [-0.005996911786496639]
           ],
           [
             [-0.014336524531245232]
           ],
           [
             [-0.02432529255747795]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_259" => %{
     "bias" => #Nx.Tensor<
       f32[256]
       [0.07094134390354156, -0.15035247802734375, 0.06227637827396393, -0.16428369283676147, -0.09959621727466583, 0.043053410947322845, -0.09103527665138245, -0.07734604924917221, -0.12980104982852936, -0.10478585958480835, -0.04469982534646988, -0.13226377964019775, 0.061032649129629135, -0.13703691959381104, -0.11405056715011597, -0.18845991790294647, -0.1075410544872284, 0.09258771687746048, 0.2783028185367584, 0.03316599875688553, -0.10966856777667999, 0.9808527231216431, -0.028177879750728607, 0.1550987958908081, -0.08509790897369385, -0.207159623503685, 0.23364616930484772, 0.016637414693832397, -0.04072579741477966, 0.3584448993206024, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[256][256][3][3]
       [
         [
           [
             [-0.06132592633366585, -0.037631288170814514, -0.03136652335524559],
             [-0.026590103283524513, -0.013129682280123234, -0.012786894105374813],
             [-0.0076367175206542015, 0.03008956089615822, 0.018315639346837997]
           ],
           [
             [-0.0013258641120046377, -0.005209944676607847, -0.0058287689462304115],
             [0.0012730347225442529, 0.0205924641340971, 0.022014399990439415],
             [0.021948130801320076, 0.047904122620821, 0.07210175693035126]
           ],
           [
             [-0.004775743465870619, -0.0034271690528839827, 0.007579814177006483],
             [0.020255673676729202, 0.037051767110824585, 0.03843056783080101],
             [0.05860872194170952, 0.0745416060090065, 0.06986779719591141]
           ],
           [
             [0.013238975778222084, 0.01745939999818802, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_265" => %{
     "bias" => #Nx.Tensor<
       f32[512]
       [0.40755313634872437, 0.3319821357727051, 0.17957660555839539, 0.12300670146942139, 0.009577751159667969, 0.09629647433757782, 0.1291329562664032, 0.3427320122718811, 0.3218705952167511, 0.33580493927001953, 0.27390050888061523, 0.32407069206237793, 0.07290486991405487, 0.25875648856163025, 0.2390243411064148, 0.07725945115089417, 0.3376724421977997, 0.39922407269477844, 0.18773458898067474, 0.1342305690050125, 0.20712845027446747, 0.22410215437412262, 0.2661794424057007, 0.08534207940101624, 0.282318651676178, 0.41150587797164917, 0.3008878827095032, 0.2664359509944916, 0.13953012228012085, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[512][512][3][3]
       [
         [
           [
             [8.859830995788798e-5, -0.03906199336051941, -0.03706206753849983],
             [-0.029231246560811996, -0.08224329352378845, -0.09186898171901703],
             [0.06918302178382874, 0.04597153142094612, -0.0431375615298748]
           ],
           [
             [0.029286954551935196, 0.07850300520658493, 0.06079816818237305],
             [-0.027538800612092018, -0.07183054834604263, -0.03075205348432064],
             [-0.02088899351656437, -0.02169787883758545, -0.07956641912460327]
           ],
           [
             [-0.01493792049586773, -0.06454647332429886, 0.014568177983164787],
             [-0.031498562544584274, -0.02576454170048237, 0.013112889602780342],
             [-0.007586618419736624, 0.013772434554994106, 0.012050892226397991]
           ],
           [
             [-0.01611766219139099, ...],
             ...
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_268" => %{
     "bias" => #Nx.Tensor<
       f32[512]
       [-0.018965333700180054, -0.024715185165405273, -0.25872084498405457, -0.2675590515136719, -0.24086356163024902, -0.052715763449668884, 0.07360944151878357, -0.07235388457775116, -0.18274088203907013, -0.053331196308135986, -0.019401244819164276, -0.131577268242836, -0.2690041959285736, 0.04574929177761078, -0.09365198016166687, -0.01966966688632965, -0.15743330121040344, 0.35889920592308044, -0.018192023038864136, -0.09922407567501068, -0.17015980184078217, 0.16596762835979462, -0.21810759603977203, -0.032022833824157715, -0.2524210214614868, -0.21111689507961273, 0.13444575667381287, 0.005104571580886841, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[512][256][1][1]
       [
         [
           [
             [0.0064383745193481445]
           ],
           [
             [0.004077882971614599]
           ],
           [
             [0.02021564170718193]
           ],
           [
             [0.0074815526604652405]
           ],
           [
             [0.011978211812675]
           ],
           [
             [-0.01443507056683302]
           ],
           [
             [5.685305222868919e-4]
           ],
           [
             [-0.06652790307998657]
           ],
           [
             [-0.015590229071676731]
           ],
           [
             [-0.06727351248264313]
           ],
           [
             [0.009745856747031212]
           ],
           [
             [-0.03356105461716652]
           ],
           [
             [0.027014147490262985]
           ],
           [
             [-0.005242329090833664]
           ],
           [
             [0.0033388889860361814]
           ],
           [
             [-0.06587104499340057]
           ],
           [
             [0.07506178319454193]
           ],
           [
             [0.01637321524322033]
           ],
           [
             [-0.02287232130765915]
           ],
           [
             [0.038088902831077576]
           ],
           [
             [0.03745897114276886]
           ],
           [
             [-0.00243575987406075]
           ],
           [
             [-0.03880726546049118]
           ],
           [
             [0.04689910635352135]
           ],
           [
             [0.02028927020728588]
           ],
           [
             [0.028256075456738472]
           ],
           [
             [0.002175817731767893]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "onnx::Add_274" => %{
     "bias" => #Nx.Tensor<
       f32[512]
       [0.56229168176651, 0.6421835422515869, 1.0794943571090698, 0.8927594423294067, 0.7054401636123657, 0.7034736275672913, 0.5723349452018738, 0.6974419355392456, -0.08316245675086975, 0.4702988862991333, 0.4810026288032532, 0.4895618259906769, 1.005676507949829, -0.005102291703224182, 1.2158994674682617, 0.8985353708267212, 0.4169228672981262, 2.304893970489502, 0.21489477157592773, 0.5786733627319336, -0.2711006999015808, 0.9574836492538452, 0.5937850475311279, 0.819282054901123, 0.3589557111263275, -0.10509796440601349, 0.24478378891944885, ...]
     >,
     "kernel" => #Nx.Tensor<
       f32[512][512][3][3]
       [
         [
           [
             [0.01256233174353838, 0.0675450786948204, -0.03277582675218582],
             [-0.013346201740205288, 0.022245211526751518, -0.10776571929454803],
             [0.028993481770157814, 0.08994386345148087, -0.021521493792533875]
           ],
           [
             [-0.15701468288898468, -0.1851029247045517, -0.15006138384342194],
             [-0.0675821602344513, -0.06320945918560028, -0.04251612722873688],
             [-0.09874605387449265, -0.11068431288003922, -0.11542537808418274]
           ],
           [
             [-0.13964341580867767, -0.15133513510227203, -0.13275715708732605],
             [-0.14548303186893463, -0.15577726066112518, -0.14223594963550568],
             [-0.1410800963640213, -0.11203450709581375, ...]
           ],
           ...
         ],
         ...
       ]
     >
   },
   "output" => %{
     "kernel" => #Nx.Tensor<
       f32[512][2]
       [
         [0.02385139651596546, 0.13206183910369873],
         [-0.025229768827557564, -0.020289720967411995],
         [0.004341271240264177, -0.015273981727659702],
         [-0.03410986438393593, -0.00706594530493021],
         [0.0076062860898673534, -0.054196327924728394],
         [-0.026565847918391228, 0.03922131285071373],
         [0.035581961274147034, 0.030472010374069214],
         [-0.050863105803728104, 0.08771440386772156],
         [-0.05721820890903473, 0.015420638956129551],
         [0.0494314469397068, -0.042466070502996445],
         [0.08297151327133179, -0.06097378954291344],
         [0.05511021614074707, 0.02809895947575569],
         [0.023985465988516808, 0.04846307635307312],
         ...
       ]
     >
   },
   "ret" => %{
     "beta" => #Nx.Tensor<
       f32[1024]
       [0.02276429533958435, 0.0035531623288989067, -0.009528876282274723, 0.012452857568860054, -0.013315263204276562, 0.005103593692183495, -0.026911968365311623, -0.022401897236704826, -0.009875668212771416, -0.02863006293773651, 0.009720338508486748, 0.0240049846470356, 0.010473674163222313, 0.01085545215755701, 0.01869283616542816, -6.822053692303598e-4, 0.0010404232889413834, 0.019887443631887436, 0.0032704935874789953, 0.006654275115579367, 0.0028322204016149044, 0.013029149733483791, -0.01944638602435589, -0.03464600443840027, 0.017518293112516403, ...]
     >,
     "gamma" => #Nx.Tensor<
       f32[1024]
       [0.9832035899162292, 1.0122419595718384, 0.9789151549339294, 0.9967371821403503, 1.0040515661239624, 1.006706714630127, 0.9962093234062195, 1.004937767982483, 1.0138654708862305, 0.9895288348197937, 1.0114372968673706, 0.9923561811447144, 0.9850446581840515, 0.9976463913917542, 1.011671543121338, 1.0118069648742676, 0.9990651607513428, 0.9895028471946716, 1.0061311721801758, 1.0052860975265503, 0.9899915456771851, 0.9716607928276062, 1.013612151145935, 0.9790956974029541, ...]
     >,
     "mean" => #Nx.Tensor<
       f32[1024]
       [2.9728572368621826, 3.365450620651245, 3.1304736137390137, 3.306053876876831, 3.239210605621338, 3.1325502395629883, 3.42802357673645, 4.110240936279297, 3.020580768585205, 3.4392223358154297, 2.9270002841949463, 2.863265037536621, 2.9700276851654053, 3.247666358947754, 2.875291109085083, 3.0183677673339844, 3.3663153648376465, 4.336845397949219, 3.3378782272338867, 3.104308605194092, 2.9515864849090576, 3.708252429962158, 3.428995370864868, ...]
     >,
     "var" => #Nx.Tensor<
       f32[1024]
       [6.364206790924072, 5.38767671585083, 2.908144235610962, 3.4610931873321533, 9.129782676696777, 5.208425045013428, 5.667339324951172, 7.857580661773682, 3.5213165283203125, 7.485493183135986, 6.509019374847412, 5.7534050941467285, 6.643040657043457, 4.421581268310547, 3.6362311840057373, 4.655275821685791, 3.5555241107940674, 5.58651065826416, 4.338912010192871, 7.985411643981934, 6.697246074676514, 6.58963680267334, ...]
     >
   },
   "ret.3" => %{
     "kernel" => #Nx.Tensor<
       f32[1024][512]
       [
         [0.10695895552635193, 0.0266690906137228, -0.03277174010872841, 0.014272360131144524, -0.030955830588936806, -0.03370890021324158, 0.0441790372133255, -0.037014394998550415, -0.04975056275725365, 0.02571658417582512, -0.030932731926441193, -0.001060181064531207, -0.019138645380735397, -0.04478214308619499, -0.00473890732973814, 0.013899457640945911, -0.030774442479014397, 0.022511456161737442, -0.0010290517238900065, -0.020110638812184334, 0.0011450162855908275, 0.001466989517211914, 0.10536769777536392, -0.04909473657608032, ...],
         ...
       ]
     >
   },
   "ret.7" => %{
     "beta" => #Nx.Tensor<
       f32[512]
       [-0.041805699467659, 0.037042248994112015, 0.038585465401411057, -0.05309903994202614, 0.07353689521551132, -0.04016199707984924, 0.04907655715942383, -0.06351383775472641, -0.03901433199644089, 0.06074056401848793, 0.04307367652654648, 0.03867708891630173, -0.028187068179249763, 0.049504026770591736, 0.042205315083265305, -0.051512110978364944, 0.054953683167696, 0.05401424318552017, -0.03426476567983627, 0.0610915832221508, -0.050264328718185425, -0.051102202385663986, 0.0573478527367115, ...]
     >,
     "gamma" => #Nx.Tensor<
       f32[512]
       [0.992328405380249, 0.9798473715782166, 0.9838926196098328, 0.9932383894920349, 0.9948338270187378, 0.98109370470047, 0.9807700514793396, 1.009493350982666, 0.997860312461853, 0.9951248168945312, 0.9951954483985901, 0.98468017578125, 0.9887678623199463, 0.9658634662628174, 0.9902472496032715, 1.0008673667907715, 1.0128767490386963, 0.9547601342201233, 0.9907315969467163, 0.989186704158783, 0.9864494800567627, 0.9910844564437866, ...]
     >,
     "mean" => #Nx.Tensor<
       f32[512]
       [1.7624406814575195, 1.1456149816513062, 1.088881492614746, 1.1321114301681519, 1.0056617259979248, 1.7336934804916382, 1.1505461931228638, 1.9170235395431519, 1.5937530994415283, 1.0512514114379883, 1.1953390836715698, 0.8846156001091003, 1.6404842138290405, 1.1027706861495972, 0.8415035009384155, 1.6003998517990112, 1.059121012687683, 0.9843509197235107, 1.447360873222351, 1.0818588733673096, 1.0302938222885132, ...]
     >,
     "var" => #Nx.Tensor<
       f32[512]
       [8.289741516113281, 2.7083992958068848, 2.1886401176452637, 2.688572883605957, 2.9085545539855957, 8.149431228637695, 3.985870599746704, 10.366480827331543, 6.770506381988525, 3.070387363433838, 3.179903030395508, 2.2266383171081543, 6.598818302154541, 1.9711589813232422, 1.5258374214172363, 5.815791130065918, 3.7079122066497803, 2.057806968688965, 4.989217281341553, 2.464872360229492, ...]
     >
   }
 }}

Manipulate the images using the StbImage library:

image_set_dir = "/Users/charlie/ML/datasets/oxford-iiit-pet/images"
"/Users/charlie/ML/datasets/oxford-iiit-pet/images"
{:ok, img} = StbImage.read_file("#{image_set_dir}/havanese_71.jpg")
%StbImage{data: binary, shape: shape, type: type} = StbImage.resize(img, 224, 224)
%StbImage{
  data: <<92, 92, 86, 93, 93, 88, 97, 96, 90, 98, 96, 91, 98, 95, 90, 102, 101, 95, 101, 101, 94,
    98, 97, 91, 100, 99, 94, 101, 100, 95, 101, 100, 95, 97, 96, 91, 100, 99, 94, 101, 101, 96, 100,
    99, 94, 99, 97, 92, 100, ...>>,
  shape: {224, 224, 3},
  type: {:u, 8}
}
# List of images we want to use
file_names = [
  "havanese_71.jpg",
  "yorkshire_terrier_9.jpg",
  "Sphynx_206.jpg",
  "Siamese_95.jpg",
  "Egyptian_Mau_63.jpg",
  "keeshond_175.jpg",
  "samoyed_88.jpg",
  "British_Shorthair_122.jpg",
  "Russian_Blue_20.jpg",
  "boxer_99.jpg"
]
["havanese_71.jpg", "yorkshire_terrier_9.jpg", "Sphynx_206.jpg", "Siamese_95.jpg",
 "Egyptian_Mau_63.jpg", "keeshond_175.jpg", "samoyed_88.jpg", "British_Shorthair_122.jpg",
 "Russian_Blue_20.jpg", "boxer_99.jpg"]

Resizing the images:

resized_images =
  Enum.map(file_names, fn file_name ->
    "#{image_set_dir}/#{file_name}"
    |> IO.inspect(label: file_name)
    |> StbImage.read_file!()
    |> StbImage.resize(224, 224)
  end)
havanese_71.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/havanese_71.jpg"
yorkshire_terrier_9.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/yorkshire_terrier_9.jpg"
Sphynx_206.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/Sphynx_206.jpg"
Siamese_95.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/Siamese_95.jpg"
Egyptian_Mau_63.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/Egyptian_Mau_63.jpg"
keeshond_175.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/keeshond_175.jpg"
samoyed_88.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/samoyed_88.jpg"
British_Shorthair_122.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/British_Shorthair_122.jpg"
Russian_Blue_20.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/Russian_Blue_20.jpg"
boxer_99.jpg: "/Users/charlie/ML/datasets/oxford-iiit-pet/images/boxer_99.jpg"
[
  %StbImage{
    data: <<92, 92, 86, 93, 93, 88, 97, 96, 90, 98, 96, 91, 98, 95, 90, 102, 101, 95, 101, 101, 94,
      98, 97, 91, 100, 99, 94, 101, 100, 95, 101, 100, 95, 97, 96, 91, 100, 99, 94, 101, 101, 96,
      100, 99, 94, 99, 97, 92, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<199, 176, 160, 200, 179, 162, 200, 179, 161, 203, 181, 161, 204, 183, 165, 204, 183,
      168, 205, 184, 169, 205, 185, 168, 208, 186, 171, 208, 185, 172, 208, 186, 172, 208, 186, 172,
      209, 189, 174, 208, 189, 174, 208, 189, 174, 209, 191, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<3, 3, 2, 4, 4, 2, 5, 5, 3, 6, 5, 4, 7, 6, 4, 7, 6, 4, 6, 5, 4, 7, 5, 4, 8, 5, 4, 7, 5,
      4, 7, 4, 3, 8, 5, 4, 8, 5, 5, 7, 4, 4, 7, 4, 3, 8, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
      254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 255, 255, 255, 255, 255, 255, 255, 255, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<71, 87, 41, 79, 98, 44, 90, 109, 51, 70, 92, 44, 48, 70, 38, 35, 56, 29, 25, 41, 19, 20,
      30, 14, 36, 48, 25, 44, 61, 29, 36, 53, 25, 26, 38, 18, 21, 31, 15, 36, 48, 25, 50, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<71, 72, 69, 69, 70, 64, 37, 42, 32, 4, 5, 3, 5, 5, 3, 12, 12, 10, 21, 21, 19, 60, 57,
      56, 60, 60, 54, 49, 48, 43, 63, 59, 55, 66, 63, 58, 69, 70, 63, 51, 56, 48, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<177, 167, 178, 177, 167, 178, 177, 167, 178, 176, 167, 177, 176, 168, 177, 177, 168,
      177, 177, 167, 178, 177, 167, 178, 177, 167, 178, 177, 167, 178, 177, 167, 178, 177, 167, 178,
      177, 167, 178, 177, 167, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<60, 95, 57, 59, 92, 52, 54, 88, 47, 46, 84, 41, 43, 79, 38, 39, 76, 35, 36, 70, 28, 33,
      66, 25, 33, 61, 23, 27, 59, 25, 25, 57, 25, 24, 56, 22, 27, 54, 23, 26, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  },
  %StbImage{
    data: <<9, 9, 7, 6, 6, 4, 7, 7, 5, 6, 6, 4, 7, 7, 5, 9, 9, 7, 9, 9, 7, 10, 10, 8, 11, 11, 9, 11,
      11, 9, 12, 12, 10, 14, 14, 12, 13, 13, 11, ...>>,
    shape: {224, 224, 3},
    type: {:u, 8}
  }
]

Now convert the images to tensors using StbImage.to_nx/1. Each resulting tensor has axes named :height, :width and :channels.

The goal is to stack the tensors into a single batch, normalize the pixel values to the range 0.0 to 1.0, and transpose the axes into the order expected by the neural network:

img_tensors =
  resized_images
  |> Enum.map(&StbImage.to_nx/1)
  |> Nx.stack(name: :index)
  |> Nx.divide(255.0)
  |> Nx.transpose(axes: [:index, :channels, :height, :width])
#Nx.Tensor<
  f32[index: 10][channels: 3][height: 224][width: 224]
  [
    [
      [
        [0.3607843220233917, 0.364705890417099, 0.3803921639919281, 0.3843137323856354, 0.3843137323856354, 0.4000000059604645, 0.3960784375667572, 0.3843137323856354, 0.3921568691730499, 0.3960784375667572, 0.3960784375667572, 0.3803921639919281, 0.3921568691730499, 0.3960784375667572, 0.3921568691730499, 0.38823530077934265, 0.3921568691730499, 0.40392157435417175, 0.40392157435417175, 0.40784314274787903, 0.4117647111415863, 0.4000000059604645, 0.4000000059604645, 0.40784314274787903, 0.4000000059604645, 0.4000000059604645, 0.3803921639919281, 0.3803921639919281, 0.3921568691730499, 0.4117647111415863, 0.4274509847164154, 0.42352941632270813, 0.4117647111415863, 0.3960784375667572, 0.40784314274787903, 0.40784314274787903, 0.4117647111415863, 0.4156862795352936, 0.41960784792900085, 0.4431372582912445, 0.4470588266849518, 0.45490196347236633, 0.43921568989753723, 0.45490196347236633, 0.48627451062202454, 0.4627451002597809, 0.45098039507865906, 0.4274509847164154, 0.45098039507865906, 0.47058823704719543, ...],
        ...
      ],
      ...
    ],
    ...
  ]
>
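
To see what each step of this pipeline does, the same calls can be tried on a toy batch. This is a minimal sketch, assuming the Nx dependency from this notebook's setup cell; the two 2x2 `fake_images` are made-up stand-ins for real photos, built with the same axis names that StbImage.to_nx/1 produces.

```elixir
# Two made-up 2x2 RGB "images" (all black and all white), laid out as
# height x width x channels with u8 values, like StbImage.to_nx/1 output.
fake_images =
  for fill <- [0, 255] do
    Nx.broadcast(Nx.tensor(fill, type: {:u, 8}), {2, 2, 3},
      names: [:height, :width, :channels]
    )
  end

batch =
  fake_images
  # Add a leading :index axis so the images form one batch tensor
  |> Nx.stack(name: :index)
  # Dividing u8 values by a float scales them into 0.0..1.0 and yields f32
  |> Nx.divide(255.0)
  # Move the channels axis ahead of height and width
  |> Nx.transpose(axes: [:index, :channels, :height, :width])

IO.inspect(Nx.shape(batch), label: "shape")
IO.inspect(Nx.type(batch), label: "type")
```

The shape goes from `{2, 2, 2, 3}` after stacking to `{2, 3, 2, 2}` after the transpose, which mirrors how the real batch ends up as `[index: 10][channels: 3][height: 224][width: 224]`.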

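With the input data shaped as an Nx tensor, we can make predictions. The helper module below picks, for each image in the batch, the vocabulary label with the highest score: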

defmodule Predictions do
  def single_label_prediction(predictions_batch, vocabulary) do
    IO.inspect(Nx.shape(predictions_batch), label: "Predictions batch shape")

    # Nx.to_batched/2 needs an explicit batch size; 1 yields one prediction per image
    for prediction_tensor <- Nx.to_batched(predictions_batch, 1) do
      {_prediction_value, prediction_label} =
        prediction_tensor
        |> Nx.to_flat_list()
        |> Enum.zip(vocabulary)
        |> Enum.max()

      prediction_label
    end
  end
end
{:module, Predictions, <<70, 79, 82, 49, 0, 0, 9, ...>>, {:single_label_prediction, 2}}
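
The label selection inside `single_label_prediction/2` relies on how Elixir compares tuples: each score is paired with its label, and `Enum.max/1` compares the pairs element by element, so the pair with the highest score wins. A small sketch in plain Elixir, using made-up scores in place of real model output:

```elixir
# Hypothetical scores for one image, in the same order as the vocabulary
scores = [0.2, 0.8]
vocabulary = ["dog", "cat"]

{score, label} =
  scores
  # Pair each score with its label: [{0.2, "dog"}, {0.8, "cat"}]
  |> Enum.zip(vocabulary)
  # Tuples compare on their first element first, so the highest score wins
  |> Enum.max()

IO.inspect({score, label})
# prints {0.8, "cat"}
```

Because the comparison falls through to the label when two scores are exactly equal, ties are resolved by string ordering rather than by position in the vocabulary.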

Run a prediction. This takes a very long time with the default Nx.BinaryBackend, so you will want to run it in a CUDA-enabled environment:

tensor_of_predictions = Axon.predict(cats_v_dogs_model, cats_v_dogs_params, img_tensors)
dog_cat_vocabulary = ["dog", "cat"]
Predictions.single_label_prediction(tensor_of_predictions, dog_cat_vocabulary)
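
One way to move off the BinaryBackend is to make Torchx the default Nx backend. The following is a sketch only, assuming the `:torchx` dependency from the setup cell compiled successfully on your machine; it falls back to Torchx on the CPU when no CUDA device is reported.

```elixir
# Assumption: Torchx is installed via the setup cell of this notebook.
# Prefer the CUDA device when Torchx reports one, otherwise use Torchx on CPU.
backend =
  if Torchx.device_available?(:cuda) do
    {Torchx.Backend, device: :cuda}
  else
    Torchx.Backend
  end

# Tensors created after this call are allocated on the chosen backend
Nx.global_default_backend(backend)
```

Evaluating this early in the notebook, before the model is imported and before `img_tensors` is built, is the simplest way to have those tensors allocated on the chosen backend.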