Chapter 2: Get Comfortable with Nx

Mix.install([
  {:nx, "~> 0.7.2"},
  {:exla, "~> 0.7.2"},
  {:benchee, "~> 1.3"},
  {:kino_benchee, "~> 0.1.0"}
])

Understanding Nx Tensors

https://hexdocs.pm/nx/Nx.Tensor.html

The simplest way to create a tensor is with Nx.tensor/2.

Every Nx.Tensor shows three properties when you inspect its contents:

  1. Numeric Type
  2. Shape
  3. Data

In contrast to other ML libraries such as PyTorch or TensorFlow, an Nx.Tensor is immutable: operations return a new tensor and never change the original tensor’s underlying properties. Nx avoids the obvious performance cost of producing a new tensor on every operation through a programming model that enables operator fusion.

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
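
The three properties and the immutability claim can be checked directly. The following is a minimal sketch, assuming the setup cell at the top of this notebook has been run; the variable names are illustrative, and the s64 default integer type is what the Nx version pinned in this notebook produces.

```elixir
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# The three properties shown when inspecting a tensor
type = Nx.type(t)           # {:s, 64} with the Nx version pinned above
shape = Nx.shape(t)         # {2, 3}
data = Nx.to_flat_list(t)   # [1, 2, 3, 4, 5, 6]

# Operations return a new tensor; the original is left untouched
t_plus_one = Nx.add(t, 1)
after_add = Nx.to_flat_list(t)  # still [1, 2, 3, 4, 5, 6]
```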

Using Nx Operations

All of the Nx operators can be explored here: https://hexdocs.pm/nx/Nx.html

The four common types of operations to be comfortable with first are:

  1. Shape and type operations
  2. Element-wise unary operations
  3. Element-wise binary operations
  4. Reductions

a = Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>

Shape and Type Operations

Nx.reshape/2 and Nx.as_type/2 are often used to reshape data and convert its numeric type to match what an algorithm expects.

a
|> Nx.as_type({:f, 32})
|> Nx.reshape({1, 3, 1})
#Nx.Tensor<
  f32[1][3][1]
  [
    [
      [1.0],
      [2.0],
      [3.0]
    ]
  ]
>
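
As a further illustration, the sketch below shows that reshaping changes only the shape, that type conversion changes only the type, and that a reshape must preserve the number of elements. It assumes the notebook's setup cell has been run; the variable names are made up for this example.

```elixir
flat = Nx.tensor([1, 2, 3, 4, 5, 6])

# reshape keeps the data and type, and only changes the shape
matrix = Nx.reshape(flat, {2, 3})
matrix_shape = Nx.shape(matrix)  # {2, 3}

# as_type keeps the shape, and only changes the numeric type
floats = Nx.as_type(matrix, {:f, 32})
floats_type = Nx.type(floats)    # {:f, 32}

# The new shape must hold the same number of elements as the old one:
# 6 elements cannot fill a {4, 2} tensor, so Nx raises an ArgumentError
reshape_result =
  try do
    Nx.reshape(flat, {4, 2})
  rescue
    e in ArgumentError -> {:error, Exception.message(e)}
  end
```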

Element-wise Unary Operations

These operations apply a function to every element of a tensor independently, as if the data were flattened, and return a tensor with the same shape as the input.

a = Nx.tensor([[[-1, -2, -3], [-4, -5, -6]], [[1, 2, 3], [4, 5, 6]]])
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [-1, -2, -3],
      [-4, -5, -6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>
Nx.abs(a)
#Nx.Tensor<
  s64[2][2][3]
  [
    [
      [1, 2, 3],
      [4, 5, 6]
    ],
    [
      [1, 2, 3],
      [4, 5, 6]
    ]
  ]
>

Element-wise Binary Operations

Element-wise binary operations include addition, subtraction, multiplication, and division.

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
b = Nx.tensor([[6, 7, 8], [9, 10, 11]])
#Nx.Tensor<
  s64[2][3]
  [
    [6, 7, 8],
    [9, 10, 11]
  ]
>
Nx.add(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [7, 9, 11],
    [13, 15, 17]
  ]
>
Nx.multiply(a, b)
#Nx.Tensor<
  s64[2][3]
  [
    [6, 14, 24],
    [36, 50, 66]
  ]
>
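
Subtraction and division follow the same pattern. The sketch below uses illustrative tensors and assumes the notebook's setup cell has been run. One detail worth noting is that Nx.divide/2 returns a floating-point tensor even for integer inputs, while Nx.quotient/2 performs integer division.

```elixir
x = Nx.tensor([[6, 8], [10, 12]])
y = Nx.tensor([[3, 2], [5, 4]])

difference = Nx.subtract(x, y)    # [[3, 6], [5, 8]]

# Division of integer tensors produces a float tensor
ratio = Nx.divide(x, y)           # [[2.0, 4.0], [2.0, 3.0]]
ratio_type = Nx.type(ratio)       # {:f, 32}

# Integer division keeps the integer type
int_quotient = Nx.quotient(x, y)  # [[2, 4], [2, 3]]
```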

If you have tensors with two different shapes and you attempt to perform a binary operation on them, Nx will attempt to broadcast the tensors together. Broadcasting is the process of repeating an operation over the dimensions of two tensors to make their shapes compatible. Two shapes can be broadcast together only when one of the following conditions is met:

  1. One of the shapes is a scalar, or
  2. Corresponding dimensions have the same size, or one of them has size 1.

Nx.add(5, Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  [6, 7, 8]
>
Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[4, 5, 6], [7, 8, 9]]))
#Nx.Tensor<
  s64[2][3]
  [
    [5, 7, 9],
    [8, 10, 12]
  ]
>
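
The broadcasting rules can be made more concrete with a few cases. This is a minimal sketch with illustrative tensors, assuming the notebook's setup cell has been run: Nx.broadcast/2 performs the repetition explicitly, a size-1 dimension is stretched to match, and incompatible shapes raise an error.

```elixir
row = Nx.tensor([1, 2, 3])        # shape {3}
column = Nx.tensor([[10], [20]])  # shape {2, 1}

# Nx.broadcast/2 makes the repetition explicit
expanded = Nx.broadcast(row, {2, 3})  # [[1, 2, 3], [1, 2, 3]]

# {2, 1} with {3}: the size-1 dimension of column is stretched to 3,
# and row is repeated along a new leading dimension, giving {2, 3}
summed = Nx.add(column, row)  # [[11, 12, 13], [21, 22, 23]]

# {3} with {2}: the sizes differ and neither is 1, so broadcasting fails
incompatible =
  try do
    Nx.add(row, Nx.tensor([1, 2]))
  rescue
    e in ArgumentError -> {:error, Exception.message(e)}
  end
```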

Reductions

Nx offers a number of reductions that compute aggregates over an entire tensor or over specific axes.

revs = Nx.tensor([85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51])
#Nx.Tensor<
  s64[12]
  [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
>
Nx.sum(revs)
#Nx.Tensor<
  s64
  647
>
revs =
  Nx.tensor(
    [
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
      [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
    ],
    names: [:year, :month]
  )
#Nx.Tensor<
  s64[year: 4][month: 12]
  [
    [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
    [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
    [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51],
    [85, 76, 42, 34, 46, 23, 52, 99, 22, 32, 85, 51]
  ]
>
# Summing over the :year axis removes it: total revenue for each month across all years
Nx.sum(revs, axes: [:year])
#Nx.Tensor<
  s64[month: 12]
  [340, 304, 168, 136, 184, 92, 208, 396, 88, 128, 340, 204]
>
# Summing over the :month axis removes it: total revenue for each year
Nx.sum(revs, axes: [:month])
#Nx.Tensor<
  s64[year: 4]
  [647, 647, 647, 647]
>
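
Because every row of revs is identical, it can be hard to see which axis a reduction removes. The sketch below uses a smaller, made-up tensor with distinct rows (assuming the notebook's setup cell has been run) to show that reducing over an axis removes that axis and keeps the others.

```elixir
# 2 years x 3 months of illustrative data
sales = Nx.tensor([[1, 2, 3], [10, 20, 30]], names: [:year, :month])

# Reducing over :month leaves one value per year
per_year = Nx.sum(sales, axes: [:month])         # [6, 60]

# Reducing over :year leaves one value per month
per_month = Nx.sum(sales, axes: [:year])         # [11, 22, 33]

# Other reductions take the same :axes option
monthly_mean = Nx.mean(sales, axes: [:month])    # [2.0, 20.0]
best_month = Nx.reduce_max(sales, axes: [:month])  # [3, 30]
```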

Going from def to defn

An Nx numerical definition is an Elixir function, written with defn instead of def, that can be just-in-time (JIT) compiled by any valid Nx compiler to run on the CPU or GPU.

Numerical definitions make use of a multi-stage programming model. On function invocation, rather than executing the function, Nx computes an expression representation of your program, and then gives that expression to an Nx compiler such as EXLA. The compiler traverses the expression, and compiles an optimized program from the given expression, which can be executed on the CPU or GPU.

defmodule MyModule do
  import Nx.Defn

  defn adds_one(x) do
    Nx.add(x, 1) |> print_expr()
  end
end
{:module, MyModule, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
MyModule.adds_one(Nx.tensor([1, 2, 3]))
#Nx.Tensor<
  s64[3]
  
  Nx.Defn.Expr
  parameter a:0   s64[3]
  b = add 1, a    s64[3]
>
#Nx.Tensor<
  s64[3]
  [2, 3, 4]
>
defmodule Softmax do
  import Nx.Defn

  defn softmax(n) do
    Nx.exp(n) / Nx.sum(Nx.exp(n))
  end
end
{:module, Softmax, <<70, 79, 82, 49, 0, 0, 9, ...>>, true}
key = Nx.Random.key(42)
{tensor, _key} = Nx.Random.uniform(key, shape: {1_000_000})

Benchee.run(
  %{
    "JIT with EXLA" => fn ->
      apply(EXLA.jit(&Softmax.softmax/1), [tensor])
    end,
    "Regular Elixir" => fn ->
      Softmax.softmax(tensor)
    end
  },
  time: 10
)
Warning: the benchmark JIT with EXLA is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark Regular Elixir is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: 13th Gen Intel(R) Core(TM) i9-13900H
Number of Available Cores: 20
Available memory: 30.99 GB
Elixir 1.16.2
Erlang 26.2.4
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 24 s

Benchmarking JIT with EXLA ...
Benchmarking Regular Elixir ...
Calculating statistics...
Formatting results...

Name                     ips        average  deviation         median         99th %
JIT with EXLA         523.28        1.91 ms    ±35.58%        1.93 ms        4.00 ms
Regular Elixir          3.26      306.83 ms     ±1.08%      307.12 ms      313.06 ms

Comparison: 
JIT with EXLA         523.28
Regular Elixir          3.26 - 160.56x slower +304.92 ms

Backend or Compiler?

The relationship between backends and compilers is kind of like the relationship between interpreted programming languages and compiled programming languages.

Nx backends are implementations of the Nx library that eagerly evaluate Nx functions. The default backend is Nx.BinaryBackend, which manipulates tensors in pure Elixir. Eager backends are generally slower than compiled code, but they let you prototype more quickly because you don’t have to structure your code into modules and numerical definitions. You can set a default backend using Nx.default_backend/1 or in your application’s configuration:

config :nx, default_backend: EXLA.Backend

Compilers implement the multi-stage programming model. They are often more performant, but they require a stricter programming model. The compiler can be configured with:

Nx.Defn.global_default_options(compiler: EXLA)

or

config :nx, :default_defn_options, [compiler: EXLA]

There are some pitfalls to setting a default compiler for your application. To avoid them, it’s often recommended to set only a default backend, and then explicitly JIT-compile functions when you deem it necessary.
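
That recommendation might look like the following in practice. This is a sketch rather than a prescribed setup: it assumes the notebook's setup cell (which installs EXLA) has been run and that no default defn compiler has been configured. The ExplicitJit module is a hypothetical name that repeats the chapter's softmax so the cell stands alone.

```elixir
defmodule ExplicitJit do
  import Nx.Defn

  # Same softmax as earlier in the chapter
  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

# Set only the default backend (for the current process);
# eager operations now run on EXLA's backend
Nx.default_backend(EXLA.Backend)

input = Nx.tensor([1.0, 2.0, 3.0])

# With no default compiler configured (assumed here), a plain call
# is evaluated eagerly on the default backend
eager = ExplicitJit.softmax(input)

# Opt in to JIT compilation only where it is wanted
jitted_softmax = Nx.Defn.jit(&ExplicitJit.softmax/1, compiler: EXLA)
compiled = jitted_softmax.(input)
```

Both calls should produce the same probabilities; only the second goes through the EXLA compiler.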