
Axon: Acceleration

elixir/axon_nx/accelerating_axon.livemd

Mix.install([
  {:axon, "~> 0.5.1"},
  {:exla, "~> 0.5.2"},
  {:nx, "~> 0.5.2"},
  {:benchee, github: "bencheeorg/benchee", branch: "main"},
  {:kino, "~> 0.9.0", override: true}
])
* Getting benchee (https://github.com/bencheeorg/benchee.git - origin/main)
remote: Enumerating objects: 7539, done.        
remote: Counting objects: 100% (789/789), done.        
remote: Compressing objects: 100% (352/352), done.        
remote: Total 7539 (delta 423), reused 666 (delta 394), pack-reused 6750        
Resolving Hex dependencies...
Dependency resolution completed:
New:
  axon 0.5.1
  complex 0.5.0
  deep_merge 1.0.0
  elixir_make 0.7.6
  exla 0.5.2
  kino 0.9.0
  nx 0.5.2
  statistex 1.0.0
  table 0.1.2
  telemetry 1.2.1
  xla 0.4.4
* Getting axon (Hex package)
* Getting exla (Hex package)
* Getting nx (Hex package)
* Getting kino (Hex package)
* Getting table (Hex package)
* Getting deep_merge (Hex package)
* Getting statistex (Hex package)
* Getting complex (Hex package)
* Getting telemetry (Hex package)
* Getting elixir_make (Hex package)
* Getting xla (Hex package)
==> deep_merge
Compiling 2 files (.ex)
Generated deep_merge app
==> table
Compiling 5 files (.ex)
Generated table app
===> Analyzing applications...
===> Compiling telemetry
==> statistex
Compiling 3 files (.ex)
Generated statistex app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> kino
Compiling 39 files (.ex)
Generated kino app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> benchee
Compiling 44 files (.ex)
Generated benchee app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/charlie/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/charlie/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.4/d96b6d6021536e841b65d5bccffbb74b/deps/exla/cache
Using libexla.so from /Users/charlie/Library/Caches/xla/exla/elixir-1.14.2-erts-13.0.4-xla-0.4.4-exla-0.5.2-dsbqjpzlwbrtb4mbts7ynstq2m/libexla.so
Compiling 21 files (.ex)
Generated exla app
:ok

Using Nx Compilers in Axon

Axon is built entirely on top of Nx numerical definitions (defn). Functions declared with defn tell Nx to JIT-compile them and execute them with an available Nx compiler, which enables acceleration on CPU/GPU/TPU via pluggable compilers. On top of the default BinaryBackend, Nx currently has two officially supported compilers/backends:

  • EXLA - Acceleration via Google’s XLA project
  • TorchX - Bindings to LibTorch

By default, Nx and Axon run all computations using the BinaryBackend, a pure Elixir implementation. The BinaryBackend is very slow, so developers should use one of the accelerated libraries above for any real workload.
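Under the hood this is plain Nx.Defn machinery. Here is a minimal sketch of how a defn function gets JIT-compiled with EXLA (the module and function names are illustrative, not part of Axon):

defmodule Speedy do
  import Nx.Defn

  # Runs through whichever Nx compiler is configured.
  defn double_sum(t) do
    t |> Nx.multiply(2) |> Nx.sum()
  end
end

# Or JIT explicitly with EXLA, independent of the configured defaults:
jitted = Nx.Defn.jit(&Speedy.double_sum/1, compiler: EXLA)
jitted.(Nx.iota({1000}))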

We can compare backend speeds with some benchmarks. First, define a small model:

model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dense(1)
  |> Axon.softmax()
#Axon<
  inputs: %{"data" => nil}
  outputs: "softmax_0"
  nodes: 5
>

Axon respects the default defn compilation options. You can set compilation options globally, per process, or per build:

# Sets the global compilation options
Nx.Defn.global_default_options(compiler: EXLA)
# Sets the process-level compilation options
Nx.Defn.default_options(compiler: EXLA)
# Configure the model to be built with a specific compiler
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
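# Reset both defaults so the code below runs on the pure-Elixir
# BinaryBackend unless EXLA is requested explicitly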
Nx.Defn.global_default_options([])
Nx.Defn.default_options([])
[]
key = Nx.Random.key(1701)
{inputs, _key} = Nx.Random.uniform(key, shape: {2, 18}, type: :f32)
{#Nx.Tensor<
   f32[2][18]
   [
     [0.27114367485046387, 0.20867478847503662, 0.8550657033920288, 0.9538747072219849, 0.5943576097488403, 0.6767340898513794, 0.7280173301696777, 0.8957382440567017, 0.49435389041900635, 0.8597968816757202, 0.9159774780273438, 0.41785871982574463, 0.5218784809112549, 0.32611608505249023, 0.6749540567398071, 0.3122551441192627, 0.9772067070007324, 0.17272377014160156],
     [0.7656633853912354, 0.3271660804748535, 0.807614803314209, 0.6883835792541504, 0.41343605518341064, 0.9727122783660889, 0.27507996559143066, 0.28299689292907715, 0.5320568084716797, 0.6978424787521362, 0.525448203086853, 0.21462464332580566, 0.3494069576263428, 0.7044590711593628, 0.6598044633865356, 0.7532706260681152, 0.979314923286438, 0.8204587697982788]
   ]
 >,
 #Nx.Tensor<
   u32[2]
   [56197195, 1801093307]
 >}
{init_fn, predict_fn} = Axon.build(model)
{#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>,
 #Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>}
params = init_fn.(inputs, %{})
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[32]
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[18][32]
      [
        [-0.3051754832267761, 0.25501424074172974, -0.05403277277946472, 0.09163141250610352, -0.1316809058189392, 0.1322190761566162, 0.1101188063621521, 0.14183634519577026, 0.28990280628204346, -0.11901968717575073, -0.024596571922302246, -0.043264031410217285, -0.007663100957870483, 0.1158941388130188, -0.29987192153930664, 0.056195229291915894, -0.31137290596961975, -0.20324155688285828, -0.27872779965400696, 0.22537386417388916, 0.336606502532959, -0.0597720742225647, -0.1954112946987152, -0.0635291039943695, 0.284249484539032, -0.3126755952835083, 0.25980281829833984, -0.24364405870437622, -0.06416201591491699, 0.31739819049835205, 0.15450125932693481, -0.07652541995048523],
        [0.11587768793106079, -0.09446549415588379, -0.18164658546447754, 0.07257139682769775, 0.2741078734397888, -0.18522696197032928, 0.12015101313591003, -0.14292685687541962, -0.32305973768234253, 0.07889625430107117, -0.043657660484313965, 0.26575928926467896, 0.29837238788604736, 0.11755535006523132, 0.05597198009490967, ...],
        ...
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[32][1]
      [
        [-0.3466041684150696],
        [-0.3015868067741394],
        [-0.28144195675849915],
        [0.3505516052246094],
        [0.06784883141517639],
        [-0.3983905017375946],
        [-0.2863359749317169],
        [0.3189737796783447],
        [-0.15723109245300293],
        [0.41920775175094604],
        [-0.24454393982887268],
        [0.08108031749725342],
        [-0.257163941860199],
        [0.16668391227722168],
        [0.3904075026512146],
        [0.2851768136024475],
        [-0.3781568109989166],
        [0.1311066746711731],
        [-0.3213688135147095],
        [0.3913983106613159],
        [-0.3420376181602478],
        [-0.32796016335487366],
        [0.41802865266799927],
        [-0.31250956654548645],
        [-0.07970893383026123],
        [-0.19813349843025208],
        [0.350238561630249],
        [-0.2048673778772354],
        [0.3699566721916199],
        [-0.16490444540977478],
        [-0.14745160937309265],
        [-0.3433813750743866]
      ]
    >
  }
}
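With parameters in hand, a forward pass is just a call to the returned predict function. A quick sketch using the values above:

# Forward pass with the default (BinaryBackend) compiled function.
predict_fn.(params, inputs)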
{exla_init_fn, exla_predict_fn} = Axon.build(model, compiler: EXLA)
{#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>,
 #Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>}

If you inspect the Nx.Tensor structs that are returned, you can see that they use EXLA.Backend as the host to run the computation.
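One quick way to check is to look at the struct backing a tensor's data (a sketch; which backend you see depends on your configuration):

# The :data field of an Nx.Tensor holds the backend struct,
# e.g. EXLA.Backend or Nx.BinaryBackend.
inputs.data.__struct__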

Now, if we benchmark the different compiled functions, we can see how they compare in runtime and memory usage.

Here is sample output from a previous run, so the benchmark doesn't have to execute on every Livebook load (the actual Benchee.run calls are commented out below):

Initialization Run
Benchmarking elixir init ...
Benchmarking exla init ...

Name                  ips        average  deviation         median         99th %
exla init         19.67 K      0.0508 ms    ±24.80%      0.0476 ms      0.0996 ms
elixir init        0.36 K        2.80 ms     ±5.30%        2.78 ms        3.17 ms

Comparison: 
exla init         19.67 K
elixir init        0.36 K - 55.06x slower +2.75 ms

Memory usage statistics:

Name                average  deviation         median         99th %
exla init        0.00875 MB     ±0.00%     0.00875 MB     0.00875 MB
elixir init         3.91 MB     ±0.02%        3.91 MB        3.91 MB

Comparison: 
exla init        0.00875 MB
elixir init         3.91 MB - 447.19x memory usage +3.90 MB

Prediction Run
Benchmarking elixir predict ...
Benchmarking exla predict ...

Name                     ips        average  deviation         median         99th %
exla predict         32.39 K       30.88 μs    ±55.48%       28.88 μs       90.17 μs
elixir predict        3.68 K      271.39 μs    ±12.27%      265.12 μs      359.12 μs

Comparison: 
exla predict         32.39 K
elixir predict        3.68 K - 8.79x slower +240.51 μs

Memory usage statistics:

Name              Memory usage
exla predict          12.79 KB
elixir predict       890.94 KB - 69.66x memory usage +878.15 KB

**All measurements for memory usage were the same**
# Benchee.run(
#   %{
#     "elixir init" => fn -> init_fn.(inputs, %{}) end,
#     "exla init" => fn -> exla_init_fn.(inputs, %{}) end
#   },
#   time: 10,
#   memory_time: 5,
#   warmup: 2
# )

# Benchee.run(
#   %{
#     "elixir predict" => fn -> predict_fn.(params, inputs) end,
#     "exla predict" => fn -> exla_predict_fn.(params, inputs) end
#   },
#   time: 10,
#   memory_time: 5,
#   warmup: 2
# )
nil