Chapter 2: Get Comfortable with Nx
Mix.install([
  {:nx, "~> 0.7.2"},
  {:exla, "~> 0.7.2"},
  {:benchee, "~> 1.3"},
  {:kino_benchee, "~> 0.1.0"}
])
Understanding Nx Tensors
https://hexdocs.pm/nx/Nx.Tensor.html
The simplest way to create a tensor with Nx
is with Nx.tensor/2
Every Nx.Tensor
has three properties, visible when you inspect its contents:
• Numeric type
• Shape
• Data
Unlike ML libraries that allow tensors to be modified in place (PyTorch, for example), an Nx.Tensor
is immutable: an operation returns a new tensor and never changes the type, shape, or data of its inputs. Nx
offsets the obvious performance cost of producing a new tensor on every operation with a programming model that enables
operator fusion.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
s64[2][3]
[
[1, 2, 3],
[4, 5, 6]
]
>
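The three properties and the immutability described above can be checked directly. The following is a minimal sketch, assuming the Nx version pinned in the setup cell (where integers default to s64); the variable names are only illustrative.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# The three properties of a tensor
type = Nx.type(a)      # numeric type, {:s, 64} with the Nx version pinned above
shape = Nx.shape(a)    # {2, 3}
data = Nx.to_list(a)   # the data as nested Elixir lists

# Operations return a new tensor; `a` itself is never modified.
b = Nx.add(a, 10)
after_add = Nx.to_list(a)
```

Here `b` holds the shifted values while `after_add` is still equal to `data`.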
Using Nx Operations
All of the Nx
operators can be explored here: https://hexdocs.pm/nx/Nx.html
The four common types of operations to be comfortable with first are:
• Shape and type operations
• Elementwise unary operations
• Elementwise binary operations
• Reductions
a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
s64[3]
[1, 2, 3]
>
Shape and Type Operations
Nx.reshape/2
and Nx.as_type/2
are often used to reshape data and convert numeric types before passing a tensor to an algorithm.
a
|> Nx.as_type({:f, 32})
|> Nx.reshape({1, 3, 1})
#Nx.Tensor<
f32[1][3][1]
[
[
[1.0],
[2.0],
[3.0]
]
]
>
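A reshape only succeeds when the new shape holds the same number of elements as the old one. The sketch below, which assumes the setup cell has run, also uses `:auto` to let Nx infer one dimension; the variable names are illustrative.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
t = Nx.iota({6})                      # six elements: 0, 1, 2, 3, 4, 5

# 6 elements fit a {2, 3} shape
matrix = Nx.reshape(t, {2, 3})

# :auto asks Nx to infer that dimension from the element count (here 3)
inferred = Nx.reshape(t, {:auto, 2})

# as_type changes the numeric type but leaves the shape alone
floats = Nx.as_type(matrix, {:f, 32})
```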
Elementwise Unary Operations
These operations apply a function to every element of the tensor independently, so the result always has the same shape as the input.
a = Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
#Nx.Tensor<
s64[2][2][3]
[
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
]
>
Nx.abs(a)
#Nx.Tensor<
s64[2][2][3]
[
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
]
>
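Because every value above is already positive, Nx.abs/1 returns the input unchanged. A small sketch with negative values, assuming the setup cell has run, makes the elementwise behavior easier to see; the tensors are made up for illustration.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
t = Nx.tensor([[-1, 2], [-3, 4]])

absolute = Nx.abs(t)      # each element replaced by its absolute value
negated = Nx.negate(t)    # each element has its sign flipped

# Some unary operations change the type: the square root of integers is a float tensor
roots = Nx.sqrt(Nx.tensor([[1, 4], [9, 16]]))
```

In every case the output shape matches the input shape, `{2, 2}`.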
Elementwise Binary Operations
Elementwise binary operations combine two tensors element by element. Common examples are addition, subtraction, multiplication, and division.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[6, 7, 8], [9, 10, 11]])
#Nx.Tensor<
s64[2][3]
[
[6, 7, 8],
[9, 10, 11]
]
>
Nx.add(a, b)
#Nx.Tensor<
s64[2][3]
[
[7, 9, 11],
[13, 15, 17]
]
>
Nx.multiply(a, b)
#Nx.Tensor<
s64[2][3]
[
[6, 14, 24],
[36, 50, 66]
]
>
If you perform a binary operation on two tensors with different shapes, Nx
will attempt to broadcast the tensors together. Broadcasting repeats values along the dimensions where the shapes differ so that the two shapes become compatible. Two shapes can be broadcast together only when one of the following conditions is met:
• One of the shapes is a scalar, or
• Each pair of corresponding dimensions (aligned from the last dimension) has the same size, or one of the two dimensions has size 1.
Nx.add(5, Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
[6, 7, 8]
>
Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[4, 5, 6], [7, 8, 9]]))
#Nx.Tensor<
s64[2][3]
[
[5, 7, 9],
[8, 10, 12]
]
>
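The broadcasting rules can be explored with a few small cases. This is a sketch assuming the setup cell has run; the tensors are invented for illustration, and the failing case assumes Nx reports incompatible shapes with an ArgumentError.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
col = Nx.tensor([[10], [20]])   # shape {2, 1}
row = Nx.tensor([1, 2, 3])      # shape {3}

# {2, 1} and {3} are compatible: the size-1 dimension is stretched to 3
sum = Nx.add(col, row)

# Nx.broadcast/2 performs the same expansion explicitly
expanded = Nx.broadcast(row, {2, 3})

# {3} and {2} are not compatible: the sizes differ and neither is 1
result =
  try do
    Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([1, 2]))
  rescue
    e in ArgumentError -> {:error, Exception.message(e)}
  end
```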
Reductions
Nx
offers a number of reductions that compute aggregates over an entire tensor or over specific axes.
revs = Nx.tensor([85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51])
#Nx.Tensor<
s64[12]
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
>
Nx.sum(revs)
#Nx.Tensor<
s64
647
>
revs =
  Nx.tensor(
    [
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
    ],
    names: [:year, :month]
  )
#Nx.Tensor<
s64[year: 4][month: 12]
[
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
[85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
]
>
# Total revenue per month (the :year axis is summed away)
Nx.sum(revs, axes: [:year])
#Nx.Tensor<
s64[month: 12]
[340, 304, 168, 136, 184, 92, 208, 396, 88, 128, 340, 204]
>
# Total revenue per year (the :month axis is summed away)
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
s64[year: 4]
[647, 647, 647, 647]
>
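The axis named in `axes:` is the one that disappears from the result. A smaller sketch, assuming the setup cell has run and using made-up numbers, shows this together with two other reductions, Nx.mean/2 and Nx.reduce_max/2.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
sales = Nx.tensor([[10, 20, 30], [40, 50, 60]], names: [:year, :month])

# Reducing over :month leaves one value per year
per_year = Nx.sum(sales, axes: [:month])

# Reducing over :year leaves one value per month
per_month = Nx.sum(sales, axes: [:year])

# Other reductions accept the same option
monthly_average_per_year = Nx.mean(sales, axes: [:month])

# With no axes given, the reduction covers the whole tensor
largest = Nx.reduce_max(sales)
```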
Going from def to defn
An Nx
numerical definition is an Elixir function that is just-in-time (JIT) compiled by a valid Nx
compiler. A function written with defn
is a numerical definition that can be JIT-compiled to run on the CPU or GPU.
Numerical definitions use a multi-stage programming model. When the function is invoked, rather than executing it directly, Nx
builds an expression representation of your program and hands that expression to an Nx
compiler such as EXLA. The compiler traverses the expression and compiles it into an optimized program that can be executed on the CPU or GPU.
defmodule MyModule do
  import Nx.Defn

  defn adds_one(x) do
    Nx.add(x, 1) |> print_expr()
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
s64[3]
Nx.Defn.Expr
parameter a:0 s64[3]
b = add 1, a s64[3]
>
#Nx.Tensor<
s64[3]
[2, 3, 4]
>
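Inside defn, Elixir's arithmetic operators are rewritten to their Nx equivalents, and any numerical definition can be wrapped with Nx.Defn.jit/2. The sketch below assumes the setup cell has run; the module `Scale` and its function are hypothetical names chosen for illustration, and the JIT call uses whichever compiler is currently the default.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.
defmodule Scale do
  import Nx.Defn

  # Within defn, `*` and `+` operate on tensors (they become Nx.multiply and Nx.add)
  defn scale_and_shift(x) do
    x * 2 + 1
  end
end

input = Nx.tensor([1, 2, 3])

# Calling the function directly
direct = Scale.scale_and_shift(input)

# Wrapping it with Nx.Defn.jit/2 returns an anonymous function with the same behavior
jitted = Nx.Defn.jit(&Scale.scale_and_shift/1)
via_jit = jitted.(input)
```

Both calls should produce the same tensor; only the compilation path differs.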
defmodule Softmax do
  import Nx.Defn

  defn softmax(n) do
    Nx.exp(n) / Nx.sum(Nx.exp(n))
  end
end
{:module, Softmax, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
key = Nx.Random.key(42)
{tensor, _key} = Nx.Random.uniform(key, shape: {1_000_000})
Benchee.run(
  %{
    "JIT with EXLA" => fn ->
      apply(EXLA.jit(&Softmax.softmax/1), [tensor])
    end,
    "Regular Elixir" => fn ->
      Softmax.softmax(tensor)
    end
  },
  time: 10
)
Warning: the benchmark JIT with EXLA is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Warning: the benchmark Regular Elixir is using an evaluated function.
Evaluated functions perform slower than compiled functions.
You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs
Operating System: Linux
CPU Information: 13th Gen Intel(R) Core(TM) i9-13900H
Number of Available Cores: 20
Available memory: 30.99 GB
Elixir 1.16.2
Erlang 26.2.4
JIT enabled: true
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s
Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...
Name                     ips        average  deviation         median         99th %
JIT with EXLA         523.28        1.91 ms    ±35.58%        1.93 ms        4.00 ms
Regular Elixir          3.26      306.83 ms     ±1.08%      307.12 ms      313.06 ms
Comparison:
JIT with EXLA         523.28
Regular Elixir          3.26 - 160.56x slower +304.92 ms
Backend or Compiler?
The relationship between backends and compilers is kind of like the relationship between interpreted programming languages and compiled programming languages.
Nx backends are implementations of the Nx library that eagerly evaluate Nx functions. The default backend is Nx.BinaryBackend, which uses pure Elixir to manipulate tensors. Backends are slower than compilers, but they let you prototype more rapidly because you don’t have to structure your code into modules and numerical definitions. You can set a default backend using Nx.default_backend/1
or in your application’s configuration:
config :nx, default_backend: EXLA.Backend
Compilers implement the multi-stage programming model. Compilers are often more performant; however, they require a stricter programming model. The compiler used can be configured with
Nx.Defn.global_default_options(compiler: EXLA)
or
config :nx, :default_defn_options, [compiler: EXLA]
There are some pitfalls to setting a default compiler for your application. To avoid them, it’s often recommended to set only a default backend, and then explicitly JIT-compile individual functions when you deem it necessary.
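One way to follow that recommendation is to pass the compiler explicitly to Nx.Defn.jit/2 for the functions that need it. The sketch below assumes the setup cell has run; the module `Stats` is a hypothetical example, and it uses Nx.BinaryBackend with Nx.Defn.Evaluator so that it runs without a native toolchain. In this notebook you would likely substitute EXLA.Backend and `compiler: EXLA`.

```elixir
# Assumes the Mix.install setup cell at the top of this notebook has already run.

# Set only a default backend (swap in EXLA.Backend when EXLA is available)
Nx.default_backend(Nx.BinaryBackend)

defmodule Stats do
  import Nx.Defn

  # Divides each element by the total so the result sums to 1
  defn normalize(x) do
    x / Nx.sum(x)
  end
end

# Choose the compiler per function instead of globally
# (Nx.Defn.Evaluator is Nx's built-in default; use `compiler: EXLA` for real JIT compilation)
normalize_jit = Nx.Defn.jit(&Stats.normalize/1, compiler: Nx.Defn.Evaluator)

normalized = normalize_jit.(Nx.tensor([1, 1, 2]))
```

Keeping the compiler choice local to each call means the rest of the application keeps using eager evaluation on the default backend.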