Chapter 2

ch-2.livemd

Mix.install([
  {:nx, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:benchee, github: "bencheeorg/benchee", override: true}
])

Understanding Nx Tensors

Tensors have a type, shape, and data. They are immutable.

Type

Selecting a type with too little precision or range can cause underflow or overflow; see the sketch after the list below. Tensors must have a homogeneous type: every element shares it.

  • signed integer
    • 8, 16, 32, 64
  • unsigned integer
    • 8, 16, 32, 64
  • float
    • 16, 32, 64
  • brain float
    • 16
  • complex
    • 64, 128
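
A minimal sketch of type behavior. Values are illustrative and the exact inspect output varies by Nx version, but the float saturation follows IEEE-754:

# Type selection and its failure modes
Nx.tensor([1, 2, 3])                  # type is inferred, e.g. {:s, 64}
Nx.tensor([1, 2, 3], type: {:u, 8})   # explicit unsigned 8-bit integers
Nx.tensor(1.0e40, type: {:f, 32})     # overflow: above f32's ~3.4e38 max, becomes Inf
Nx.tensor(1.0e-46, type: {:f, 32})    # underflow: below f32's smallest denormal, becomes 0.0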

Shape

Tensor data is not stored as a list. The shape records the size of each dimension as a tuple:

  • {2} - [1, 2]
  • {2, 2} - [[1, 2], [3, 4]]
  • {2, 2, 2} - [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

Another name for the number of dimensions is “rank”.
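
A quick check of both; Nx.shape/1 and Nx.rank/1 read them back:

t = Nx.tensor([[1, 2], [3, 4]])
Nx.shape(t)  # => {2, 2}
Nx.rank(t)   # => 2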

Data

Tensor data is stored as a flat byte array (an Elixir binary). The bytes are interpreted as a nested structure of values according to the tensor’s shape and type.

Nx operates at the byte level, so endianness matters.
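
Nx.to_binary/1 exposes the raw bytes. A small sketch, assuming a little-endian machine (the common case); on a big-endian machine the multi-byte values would be reversed:

Nx.to_binary(Nx.tensor([1, 2], type: {:u, 8}))
# => <<1, 2>> (one byte per element)

Nx.to_binary(Nx.tensor([1], type: {:s, 64}))
# => <<1, 0, 0, 0, 0, 0, 0, 0>> (eight bytes, least significant first)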

defmodule A do
  import Nx.Defn

  # print_expr/1 prints the expression graph defn builds for this
  # computation, which is useful for seeing how defn traces code
  defn adds_one(x) do
    x |> Nx.add(1) |> print_expr()
  end
end

A.adds_one(1)
defmodule Softmax do
  import Nx.Defn

  # Naive version: evaluates Nx.exp(n) twice
  defn softmax(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n))

  # Reuses the exponentials: always faster in pure Elixir,
  # only a tiny improvement in EXLA
  defn softmax2(n) do
    m = Nx.exp(n)
    m / Nx.sum(m)
  end
end
# Nx.Random.uniform/2 returns {tensor, new_key}; we only need the tensor here
{tensor, _} =
  42
  |> Nx.Random.key()
  |> Nx.Random.uniform(shape: {1_000_000})

Benchee.run(%{
  "JIT with EXLA v1" => fn -> apply(EXLA.jit(&Softmax.softmax/1), [tensor]) end,
  "JIT with EXLA v2" => fn -> apply(EXLA.jit(&Softmax.softmax2/1), [tensor]) end,
  "Regular Elixir v1" => fn -> Softmax.softmax(tensor) end,
  "Regular Elixir v2" => fn -> Softmax.softmax2(tensor) end
})

Setting the default compiler with Nx.Defn.global_default_options(compiler: EXLA) will make the “regular” Elixir versions just as fast, because every defn call is then JIT-compiled through EXLA.
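
A minimal sketch of that configuration; run it once before calling the defn functions:

Nx.Defn.global_default_options(compiler: EXLA)

# No explicit EXLA.jit/1 wrapper needed now: plain defn calls compile through EXLA
Softmax.softmax(tensor)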

You can also configure the default backend with Nx.default_backend(EXLA.Backend).

Backends are slower than compiled defn code, but they allow rapid prototyping and don’t require modules and numerical definitions.
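
A sketch of eager, backend-based usage, assuming EXLA is installed as above. Note that eager code uses Nx.divide/2, since the / operator on tensors only works inside defn:

Nx.default_backend(EXLA.Backend)

# Eager softmax: no module, no defn; each Nx call dispatches to the backend
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.divide(Nx.exp(t), Nx.sum(Nx.exp(t)))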