Chapter 2
Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:benchee, github: "bencheeorg/benchee", override: true}
])
Understanding Nx Tensors
Tensors have a type, shape, and data. They are immutable.
Type
Selecting a type with too little precision can cause underflow or overflow. Tensors must have a homogeneous type: every element shares it.
- signed integer: 8, 16, 32, 64
- unsigned integer: 8, 16, 32, 64
- float: 16, 32, 64
- brain float: 16
- complex: 64, 128
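A minimal sketch of how the type affects results, assuming the Mix.install setup at the top of this chapter has already run. The variable names are illustrative only, and types are written in Nx's tuple form (for example {:u, 8} for an unsigned 8-bit integer).

```elixir
# Assumes Nx is already installed by the setup cell at the top of the chapter.

# Mixing an integer and a float still yields one homogeneous type.
mixed = Nx.tensor([1, 2.0])
IO.inspect(Nx.type(mixed), label: "mixed type") # {:f, 32}

# Explicitly request unsigned 8-bit integers (valid range 0..255).
small = Nx.tensor([200, 250], type: {:u, 8})
IO.inspect(Nx.type(small), label: "small type") # {:u, 8}

# The true sums (400 and 500) exceed 255, and the result is still u8,
# so it cannot hold them: this is the overflow case.
IO.inspect(Nx.add(small, small), label: "u8 sum")

# Widening the type first leaves enough room for the true sums.
wide = Nx.as_type(small, {:u, 16})
IO.inspect(Nx.add(wide, wide), label: "u16 sum")
```

Picking the narrowest type that still covers the expected value range saves memory, but the range has to be checked up front.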
Shape
Tensor data is not stored as a list. A tensor’s shape is the size of each of its dimensions.
- 2: [1, 2]
- 2x2: [[1, 2], [3, 4]]
- 2x2x2: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
Another name for the number of dimensions is “rank”.
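Shape and rank can be checked directly. This is a short sketch, assuming the Mix.install setup at the top of this chapter has run. The variable names are illustrative.

```elixir
# Assumes Nx is already installed by the setup cell at the top of the chapter.

flat = Nx.tensor([1, 2])
IO.inspect(Nx.shape(flat), label: "flat shape") # {2}
IO.inspect(Nx.rank(flat), label: "flat rank")   # 1

# One extra level of nesting adds a dimension of size 1.
row = Nx.tensor([[1, 2]])
IO.inspect(Nx.shape(row), label: "row shape")   # {1, 2}

cube = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
IO.inspect(Nx.shape(cube), label: "cube shape") # {2, 2, 2}
IO.inspect(Nx.rank(cube), label: "cube rank")   # 3

# Reshaping changes only the shape; the 8 elements stay in the same order.
reshaped = Nx.reshape(cube, {4, 2})
IO.inspect(Nx.shape(reshaped), label: "reshaped") # {4, 2}
```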
Data
Tensor data is stored as a flat byte array. The bytes are interpreted as a nested list of values according to the tensor’s shape and type.
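A small sketch of the bytes behind a tensor, assuming the Mix.install setup at the top of this chapter has run. It uses 8-bit values so each element is exactly one byte. The variable names are illustrative.

```elixir
# Assumes Nx is already installed by the setup cell at the top of the chapter.

t = Nx.tensor([[1, 2], [3, 4]], type: {:u, 8})

# The nested structure is gone at the byte level: just four bytes in a row.
IO.inspect(Nx.to_binary(t), label: "bytes") # <<1, 2, 3, 4>>

same_bytes = <<1, 2, 3, 4>>

# Type u8 plus shape {2, 2} turns the bytes back into a 2x2 tensor.
same_bytes
|> Nx.from_binary({:u, 8})
|> Nx.reshape({2, 2})
|> IO.inspect(label: "as u8, 2x2")

# Read as u16, the same four bytes hold only two values.
# Their numeric value depends on the machine's byte order.
same_bytes
|> Nx.from_binary({:u, 16})
|> IO.inspect(label: "as u16")
```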
Nx operates at the byte level, so endianness matters.
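To see why byte order matters, here is a sketch in plain Elixir binary pattern matching. It does not use Nx at all; it only shows that identical bytes decode to different numbers depending on the byte order assumed.

```elixir
# Plain Elixir, no dependencies required.
bytes = <<1, 0, 0, 0>>

# Read the four bytes as one 32-bit unsigned integer, in both byte orders.
<<little::little-unsigned-integer-size(32)>> = bytes
<<big::big-unsigned-integer-size(32)>> = bytes

IO.inspect(little, label: "little-endian") # 1
IO.inspect(big, label: "big-endian")       # 16777216

# Read one byte at a time, where byte order plays no role.
as_u8 = for <<x::unsigned-integer-size(8) <- bytes>>, do: x
IO.inspect(as_u8, label: "as 8-bit values") # [1, 0, 0, 0]
```

Binary data produced on one machine and read on another needs an agreed byte order for any type wider than 8 bits.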
defmodule A do
  import Nx.Defn

  defn adds_one(x), do: x |> Nx.add(1) |> print_expr()
end

A.adds_one(1)
defmodule Softmax do
  import Nx.Defn

  # Computes Nx.exp(n) twice
  defn softmax(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n))

  # Computes Nx.exp(n) once and reuses it.
  # Always faster in Elixir, tiny improvement in EXLA
  defn softmax2(n) do
    m = Nx.exp(n)
    m / Nx.sum(m)
  end
end
# One million uniform random values from a fixed seed; the new key is discarded
{tensor, _} =
  42
  |> Nx.Random.key()
  |> Nx.Random.uniform(shape: {1_000_000})
Benchee.run(%{
  "JIT with EXLA v1" => fn -> apply(EXLA.jit(&Softmax.softmax/1), [tensor]) end,
  "JIT with EXLA v2" => fn -> apply(EXLA.jit(&Softmax.softmax2/1), [tensor]) end,
  "Regular Elixir v1" => fn -> Softmax.softmax(tensor) end,
  "Regular Elixir v2" => fn -> Softmax.softmax2(tensor) end
})
Setting EXLA as the default defn compiler with Nx.Defn.global_default_options(compiler: EXLA)
will make the “regular” Elixir versions just as fast as the explicitly JIT-compiled ones.
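A minimal sketch of the global compiler setting, assuming the Mix.install setup at the top of this chapter has run so that both Nx and EXLA are available. The module name CompilerDemo is hypothetical and only avoids clashing with the modules above.

```elixir
# Assumes Nx and EXLA are already installed by the setup cell.

# Hypothetical module, used only for this illustration.
defmodule CompilerDemo do
  import Nx.Defn

  defn softmax(n) do
    m = Nx.exp(n)
    m / Nx.sum(m)
  end
end

# From here on, every defn function is compiled by EXLA by default.
Nx.Defn.global_default_options(compiler: EXLA)

# A plain function call, with no explicit EXLA.jit wrapper.
result = CompilerDemo.softmax(Nx.tensor([1.0, 2.0, 3.0]))
IO.inspect(result, label: "softmax")
```

Calling this once near the top of a notebook keeps the call sites free of JIT wrappers.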
You can also configure the default backend with Nx.default_backend(EXLA.Backend)
Backends are slower, but allow for rapid prototyping and don’t require modules and numerical definitions.
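A sketch of the backend approach, assuming the Mix.install setup at the top of this chapter has run so that EXLA is available. No module or defn is involved; each Nx call runs eagerly on the backend that holds the tensor.

```elixir
# Assumes Nx and EXLA are already installed by the setup cell.

# Tensors created after this call in the current process live on the EXLA backend.
Nx.default_backend(EXLA.Backend)

t = Nx.tensor([1.0, 2.0, 3.0])

# Softmax written with ordinary Nx function calls, one operation at a time.
exp = Nx.exp(t)
softmax = Nx.divide(exp, Nx.sum(exp))

IO.inspect(softmax, label: "softmax on backend")
```

Each call here is dispatched separately rather than compiled as one unit, which fits the trade-off described above: convenient for experimenting, with compilation left for code that needs the speed.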