
Nx

Mix.install([
  {:nx, "~> 0.4.1"}
])
:ok

Introduction

Nx (Numerical Elixir) is a library for creating and manipulating multidimensional arrays. It is intended to serve as the core of numerical computing and data science in the Elixir ecosystem. Programming in Nx requires a somewhat different way of thinking. If you’re familiar with the Python ecosystem, Nx will remind you a lot of NumPy - but there are some key differences, mostly due to the difference in language constructs between Elixir and Python. As one example, Nx tensors are completely immutable.
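
Immutability is worth internalizing early. As a quick sketch (the variable name is just for illustration), every Nx operation returns a new tensor and leaves its input untouched:

t = Nx.tensor([1, 2, 3])
Nx.add(t, 1)
#Nx.Tensor<
  s64[3]
  [2, 3, 4]
>
t
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>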

At the core of Nx is the Nx.Tensor. The Nx.Tensor is analogous to the NumPy ndarray or the TensorFlow/PyTorch Tensor objects. It is the main data structure the Nx library is designed to manipulate. All of the Nx functionality such as gradient computations, just-in-time compilation, pluggable backends, etc. is built on top of implementations of the Nx.Tensor behaviour.

Installation

Nx is a regular Elixir library, so you can install it in the same way you would any other Elixir library.

Mix.install([
  {:nx, "~> 0.1.0"}
])

Lists vs. Tensors

When you first create and inspect a tensor, you’re probably inclined to think of it as a list or a nested list of numbers:

Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

That line of thinking is reasonable - after all, inspecting the values yields a nested list representation of the tensor! The truth, though, is that this visual representation is just a matter of convenience. Thinking of a tensor as a nested list is misleading and can make some of the fundamental concepts in Nx harder to grasp.

The Nx.Tensor is a data structure with four key fields:

  • :data
  • :shape
  • :type
  • :names
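
Because Nx.Tensor is an ordinary Elixir struct, you can peek at these fields directly (the :data field holds a backend struct wrapping the raw binary). A quick sketch for illustration - in practice you’d use the Nx functions shown in the following sections rather than raw field access:

t = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
{t.shape, t.type, t.names}
{{2, 3}, {:s, 64}, [:x, :y]}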

Tensors have data

In order to perform any computations at all, a tensor needs some underlying data that contains its values. The most common way to represent a tensor’s data is with a flat VM binary - essentially just an array of bytes. This is an important implementation detail; Nx mostly operates on the raw bytes which represent individual values in a tensor. Those values are stored in a flat container - Nx doesn’t operate on lists or nested lists.

Binaries are just C byte arrays, so we’re able to perform some very efficient operations on large tensors. While this gives us a nice performance boost, it also constrains us. Our tensor operations need to know what type the byte values represent in order to perform operations correctly. This means every value in a tensor must have the same type.

Finally, the choice of representing tensor data as a flat binary leads to some interesting (and annoying) scenarios to consider. At the very least, we need to be conscious of endianness - you can’t guarantee the raw byte values of a tensor will be interpreted the same way on different machines.

Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.to_binary()
<<1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5,
  0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0>>
Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.to_binary()
<<0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64>>
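
You can also go the other direction with Nx.from_binary/2, which reinterprets raw bytes according to the type you supply. Note the native endianness modifier in the bitstring - this sketch assumes you construct and read the binary on the same machine, for the reason described above:

<<1::64-signed-native, 2::64-signed-native, 3::64-signed-native>>
|> Nx.from_binary({:s, 64})
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>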

Tensors have shape

The “nested list” representation you see when inspecting a tensor is actually a manifestation of its shape. A tensor’s shape is best described as the size of each dimension. While two tensors might have the same underlying data, they can have different shapes, which fundamentally change the nature of the operations performed on them.

We describe a tensor’s shape with a tuple of integers: {size_d1, size_d2, …, size_dn}. For example, if a tensor has a shape {2, 1, 2}, it means the tensor’s first dimension has size 2, second dimension has size 1, and third dimension has size 2:

Nx.tensor([[[1, 2]], [[3, 4]]])
#Nx.Tensor<
  s64[2][1][2]
  [
    [
      [1, 2]
    ],
    [
      [3, 4]
    ]
  ]
>

We can also describe the number of dimensions in a tensor as its rank. As you start to work more in the scientific computing space, you’ll inevitably come across descriptions of shape which reference 0-D shapes as scalars:

Nx.tensor(1)
#Nx.Tensor<
  s64
  1
>

1-D shapes as vectors:

Nx.tensor([1, 2, 3])
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>

2-D shapes as matrices:

Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

A tensor’s shape tells us two things:

  • How to traverse and index a tensor
  • How to perform shape-dependent operations

Theoretically, we could write all of our operations to work on a flat binary, but that doesn’t map very well to the real world. We reason about things with dimensionality. Let’s consider the example of an image. A common representation of images in numerical computing is {color_channels, height, width}. A 32x32 RGB image will have shape {3, 32, 32}. Now imagine you were asked to access the green value of the pixel at height 5 and width 17. Without any understanding of the tensor’s shape, this would be an impossible task. Since you do know the shape, however, you just need a few calculations to efficiently access any value in the tensor.
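
Here’s that pixel lookup as a minimal sketch, assuming the row-major layout Nx uses. The shape tells us how many elements each step along a dimension skips:

# shape {3, 32, 32} - {color_channels, height, width}
height = 32
width = 32
# green is channel index 1; target pixel is at height 5, width 17
1 * height * width + 5 * width + 17
1201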

To access a tensor’s shape, you can use Nx.shape:

Nx.shape(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
{2, 3}

To access its rank, you can use Nx.rank:

Nx.rank(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
2

Tensors have names

As a consequence of working in multiple dimensions, you often want to perform operations only on certain dimensions of an input tensor. Some Nx functions give you the option to specify an axis or axes to reduce, permute, traverse, slice, etc. The norm is to access axes by their index in a tensor’s shape. For example, axis 1 in shape {2, 6, 3} is of size 6. Names give axes semantic meaning: instead of remembering which index you want, you can label each dimension with an atom and refer to it by name.

Nx.names(Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]))
[:x, :y]
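
Once a tensor has names, many Nx functions accept a name anywhere they accept an axis index. A small sketch using Nx.sum/2, collapsing the :y axis:

t = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
Nx.sum(t, axes: [:y])
#Nx.Tensor<
  s64[x: 2]
  [6, 15]
>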

Tensors have a type

As mentioned before, a consequence of operating on binaries is the need to have tensors with homogeneous types. In other words, every value in the tensor must be the same type. This is important for efficiency, which is why tensors exist - to support efficient, parallel computation. If we know that every value in a 1-D tensor is 16 bits long in memory and that the tensor is 128 bits long, we can quickly calculate that there are 8 values in it: 128 / 16 = 8. We can also easily grab individual values for parallel calculation because we know that there’s a new value every 16 bits. Imagine if this weren’t the case; that is, if the first value were 8 bits long, the second value 32 bits, and so on. To count the items or divide them into groups, we’d have to walk through the entire tensor every time (a waste of time), and each value would have to declare its length (a waste of space).

All tensors are instantiated with a datatype which describes their type and size. The type is represented as a tuple {type, size} - for example, {:s, 64}.

Valid types are:

  • :f - floating point types
  • :s - signed integer types
  • :u - unsigned integer types
  • :bf - brain floating point types

Valid sizes are:

  • 8, 16, 32, 64 for signed and unsigned integer types
  • 16, 32, 64 for floating point types
  • 16 for brain floating point types
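
Because every element has the same fixed width, a tensor’s byte size is just element count times element size - the same arithmetic as the 128 / 16 = 8 example above. A quick sketch:

Nx.tensor([1, 2, 3], type: {:u, 8}) |> Nx.to_binary() |> byte_size()
3
Nx.tensor([1, 2, 3], type: {:f, 64}) |> Nx.to_binary() |> byte_size()
24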

The size of the type is really a measure of its precision. While 64-bit types consume more memory and are slower to operate on, they are more precise than their 32-bit counterparts. The default integer type in Nx is {:s, 64}, and the default float type is {:f, 32}. When creating tensors with mixed values, Nx will promote the values to the “highest” type, preferring to (for example) waste some space by representing a 16-bit float in 32 bits rather than lose information by chopping a 32-bit float down to 16 bits. This is called type promotion. Type promotion is outside the scope of this post, but it’s something to be aware of.

You can get the type of a tensor with Nx.type:

Nx.type(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
{:s, 64}
Nx.type(Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
{:f, 32}
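
And a one-line illustration of the promotion rule mentioned above - mixing integers and floats promotes the whole tensor to the default float type:

Nx.type(Nx.tensor([1, 2.0, 3]))
{:f, 32}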

Creating Tensors

The most common way to create a tensor is with Nx.tensor/2, which accepts a number, a list, or a nested list:

Nx.tensor([[1, 2, 3]])
#Nx.Tensor<
  s64[1][3]
  [
    [1, 2, 3]
  ]
>

You can also specify the :type and :names of the tensor:

Nx.tensor([[1, 2, 3], [4, 5, 6]], type: {:f, 64}, names: [:x, :y])
#Nx.Tensor<
  f64[x: 2][y: 3]
  [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0]
  ]
>