Axon: Acceleration
Mix.install([
  {:axon, "~> 0.5.1"},
  {:exla, "~> 0.5.2"},
  {:nx, "~> 0.5.2"},
  {:benchee, github: "bencheeorg/benchee", branch: "main"},
  {:kino, "~> 0.9.0", override: true}
])
* Getting benchee (https://github.com/bencheeorg/benchee.git - origin/main)
remote: Enumerating objects: 7539, done.
remote: Counting objects: 100% (789/789), done.
remote: Compressing objects: 100% (352/352), done.
remote: Total 7539 (delta 423), reused 666 (delta 394), pack-reused 6750
Resolving Hex dependencies...
Dependency resolution completed:
New:
axon 0.5.1
complex 0.5.0
deep_merge 1.0.0
elixir_make 0.7.6
exla 0.5.2
kino 0.9.0
nx 0.5.2
statistex 1.0.0
table 0.1.2
telemetry 1.2.1
xla 0.4.4
* Getting axon (Hex package)
* Getting exla (Hex package)
* Getting nx (Hex package)
* Getting kino (Hex package)
* Getting table (Hex package)
* Getting deep_merge (Hex package)
* Getting statistex (Hex package)
* Getting complex (Hex package)
* Getting telemetry (Hex package)
* Getting elixir_make (Hex package)
* Getting xla (Hex package)
==> deep_merge
Compiling 2 files (.ex)
Generated deep_merge app
==> table
Compiling 5 files (.ex)
Generated table app
===> Analyzing applications...
===> Compiling telemetry
==> statistex
Compiling 3 files (.ex)
Generated statistex app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> kino
Compiling 39 files (.ex)
Generated kino app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> benchee
Compiling 44 files (.ex)
Generated benchee app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/charlie/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/charlie/Library/Caches/mix/installs/elixir-1.14.2-erts-13.0.4/d96b6d6021536e841b65d5bccffbb74b/deps/exla/cache
Using libexla.so from /Users/charlie/Library/Caches/xla/exla/elixir-1.14.2-erts-13.0.4-xla-0.4.4-exla-0.5.2-dsbqjpzlwbrtb4mbts7ynstq2m/libexla.so
Compiling 21 files (.ex)
Generated exla app
:ok
Using Nx Compilers in Axon
Axon is built entirely on top of Nx numerical definitions (defn). Functions declared with defn tell Nx to use just-in-time (JIT) compilation, executing numerical definitions with an available Nx compiler. Numerical definitions enable acceleration on CPU/GPU/TPU via pluggable compilers. At the time of writing, Nx has two officially supported compilers/backends on top of the default BinaryBackend:
- EXLA - Acceleration via Google's XLA project
- TorchX - Bindings to LibTorch
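To make that concrete, here is a minimal sketch of a numerical definition and its JIT-compiled form. The MyMath module and scale_and_sum function are illustrative names, not part of Axon or this notebook:

defmodule MyMath do
  import Nx.Defn

  # Functions declared with defn are traced by Nx and can be
  # JIT-compiled by a pluggable compiler such as EXLA.
  defn scale_and_sum(x, factor) do
    Nx.sum(x * factor)
  end
end

# JIT-compile and run the definition with EXLA (assumes EXLA is installed):
jit = Nx.Defn.jit(&MyMath.scale_and_sum/2, compiler: EXLA)
jit.(Nx.iota({4}, type: :f32), Nx.tensor(2.0))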
By default, Nx and Axon run all computations using the BinaryBackend, which is a pure Elixir implementation. The BinaryBackend is very slow, so developers should use one of the accelerated libraries whenever their platform supports it.
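Besides passing a compiler to defn, you can also swap the tensor backend itself. This sketch shows the backend-level switch for completeness; the rest of this notebook uses the compiler option instead:

# Make EXLA the default backend for all newly created tensors
Nx.global_default_backend(EXLA.Backend)

# Or only for the current process
Nx.default_backend(EXLA.Backend)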
We can compare the backend speeds with some benchmarks.
model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dense(1)
  |> Axon.softmax()
#Axon<
inputs: %{"data" => nil}
outputs: "softmax_0"
nodes: 5
>
Axon will respect the default defn compilation options. You can set compilation options globally, per process, or per build:
# Sets the global compilation options
Nx.Defn.global_default_options(compiler: EXLA)
# Sets the process-level compilation options
Nx.Defn.default_options(compiler: EXLA)
# Configure the model to be built with a specific compiler
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
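# Reset both levels back to the defaults, so the builds below can
# compare the plain Elixir and EXLA versions side by side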
Nx.Defn.global_default_options([])
Nx.Defn.default_options([])
[]
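# Create a reproducible PRNG key, then draw a {2, 18} batch of uniform inputs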
key = Nx.Random.key(1701)
{inputs, _key} = Nx.Random.uniform(key, shape: {2, 18}, type: :f32)
{#Nx.Tensor<
f32[2][18]
[
[0.27114367485046387, 0.20867478847503662, 0.8550657033920288, 0.9538747072219849, 0.5943576097488403, 0.6767340898513794, 0.7280173301696777, 0.8957382440567017, 0.49435389041900635, 0.8597968816757202, 0.9159774780273438, 0.41785871982574463, 0.5218784809112549, 0.32611608505249023, 0.6749540567398071, 0.3122551441192627, 0.9772067070007324, 0.17272377014160156],
[0.7656633853912354, 0.3271660804748535, 0.807614803314209, 0.6883835792541504, 0.41343605518341064, 0.9727122783660889, 0.27507996559143066, 0.28299689292907715, 0.5320568084716797, 0.6978424787521362, 0.525448203086853, 0.21462464332580566, 0.3494069576263428, 0.7044590711593628, 0.6598044633865356, 0.7532706260681152, 0.979314923286438, 0.8204587697982788]
]
>,
#Nx.Tensor<
u32[2]
[56197195, 1801093307]
>}
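# Build the model against the default backend (the pure Elixir BinaryBackend)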
{init_fn, predict_fn} = Axon.build(model)
{#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>,
#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>}
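# Initialize the model parameters, using the inputs as a shape template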
params = init_fn.(inputs, %{})
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[32]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[18][32]
[
[-0.3051754832267761, 0.25501424074172974, -0.05403277277946472, 0.09163141250610352, -0.1316809058189392, 0.1322190761566162, 0.1101188063621521, 0.14183634519577026, 0.28990280628204346, -0.11901968717575073, -0.024596571922302246, -0.043264031410217285, -0.007663100957870483, 0.1158941388130188, -0.29987192153930664, 0.056195229291915894, -0.31137290596961975, -0.20324155688285828, -0.27872779965400696, 0.22537386417388916, 0.336606502532959, -0.0597720742225647, -0.1954112946987152, -0.0635291039943695, 0.284249484539032, -0.3126755952835083, 0.25980281829833984, -0.24364405870437622, -0.06416201591491699, 0.31739819049835205, 0.15450125932693481, -0.07652541995048523],
[0.11587768793106079, -0.09446549415588379, -0.18164658546447754, 0.07257139682769775, 0.2741078734397888, -0.18522696197032928, 0.12015101313591003, -0.14292685687541962, -0.32305973768234253, 0.07889625430107117, -0.043657660484313965, 0.26575928926467896, 0.29837238788604736, 0.11755535006523132, 0.05597198009490967, ...],
...
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.0]
>,
"kernel" => #Nx.Tensor<
f32[32][1]
[
[-0.3466041684150696],
[-0.3015868067741394],
[-0.28144195675849915],
[0.3505516052246094],
[0.06784883141517639],
[-0.3983905017375946],
[-0.2863359749317169],
[0.3189737796783447],
[-0.15723109245300293],
[0.41920775175094604],
[-0.24454393982887268],
[0.08108031749725342],
[-0.257163941860199],
[0.16668391227722168],
[0.3904075026512146],
[0.2851768136024475],
[-0.3781568109989166],
[0.1311066746711731],
[-0.3213688135147095],
[0.3913983106613159],
[-0.3420376181602478],
[-0.32796016335487366],
[0.41802865266799927],
[-0.31250956654548645],
[-0.07970893383026123],
[-0.19813349843025208],
[0.350238561630249],
[-0.2048673778772354],
[0.3699566721916199],
[-0.16490444540977478],
[-0.14745160937309265],
[-0.3433813750743866]
]
>
}
}
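# Build the same model again, this time JIT-compiled with EXLA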
{exla_init_fn, exla_predict_fn} = Axon.build(model, compiler: EXLA)
{#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>,
#Function<135.4924062/2 in Nx.Defn.Compiler.fun/2>}
If you inspect the Nx.Tensor structs that are returned, you can see that they use the EXLA.Backend to run the computation.
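For example, a quick check with the functions built above:

# The EXLA-built predict function returns a tensor backed by EXLA
result = exla_predict_fn.(params, inputs)
# Inspecting it prints the backend, along the lines of:
#   #Nx.Tensor<
#     f32[2][1]
#     EXLA.Backend<host:0, ...>
#     ...
#   >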
Now, if we benchmark the two compiled functions, we can see how they compare in terms of runtime and memory usage. Here is sample output from a previous run, so the benchmark does not have to execute on every Livebook load.
Initialization Run
Benchmarking elixir init ...
Benchmarking exla init ...
Name                  ips        average  deviation         median         99th %
exla init         19.67 K      0.0508 ms    ±24.80%      0.0476 ms      0.0996 ms
elixir init        0.36 K        2.80 ms     ±5.30%        2.78 ms        3.17 ms

Comparison:
exla init         19.67 K
elixir init        0.36 K - 55.06x slower +2.75 ms

Memory usage statistics:

Name            average  deviation         median         99th %
exla init    0.00875 MB     ±0.00%     0.00875 MB     0.00875 MB
elixir init     3.91 MB     ±0.02%        3.91 MB        3.91 MB

Comparison:
exla init    0.00875 MB
elixir init     3.91 MB - 447.19x memory usage +3.90 MB
Prediction Run
Benchmarking elixir predict ...
Benchmarking exla predict ...
Name                     ips        average  deviation         median         99th %
exla predict         32.39 K       30.88 μs    ±55.48%       28.88 μs       90.17 μs
elixir predict        3.68 K      271.39 μs    ±12.27%      265.12 μs      359.12 μs

Comparison:
exla predict         32.39 K
elixir predict        3.68 K - 8.79x slower +240.51 μs

Memory usage statistics:

Name            Memory usage
exla predict        12.79 KB
elixir predict     890.94 KB - 69.66x memory usage +878.15 KB
**All measurements for memory usage were the same**
# Benchee.run(
#   %{
#     "elixir init" => fn -> init_fn.(inputs, %{}) end,
#     "exla init" => fn -> exla_init_fn.(inputs, %{}) end
#   },
#   time: 10,
#   memory_time: 5,
#   warmup: 2
# )

# Benchee.run(
#   %{
#     "elixir predict" => fn -> predict_fn.(params, inputs) end,
#     "exla predict" => fn -> exla_predict_fn.(params, inputs) end
#   },
#   time: 10,
#   memory_time: 5,
#   warmup: 2
# )
nil