
Nx Tips

Mix.install([
  {:nx, "~> 0.5"},
  {:axon, "~> 0.5"},
  {:table_rex, "~> 3.1.1"},
  {:exla, "~> 0.5"},
])

Nx Tip of the Week #1 – Using transforms

Nx is a tensor manipulation or array programming library similar to NumPy, TensorFlow, or PyTorch.

Nx introduces a new type of function definition, defn, that is a subset of the Elixir programming language tailored specifically to numerical computations. When numerical definitions are invoked, they’re transformed into expressions (internally Nx.Defn.Expr) which represent the AST or computation graph of the numerical definition. These expressions are manipulated by compilers (like EXLA) to produce executables that run natively on accelerators.

In newer Nx versions, transform/2 has been removed, so it's better practice to keep the math inside Nx numerical definitions and handle the business logic in regular Elixir.

Use print_expr inside a defn to inspect the expression (computation graph) it builds.

defmodule NxTest do
  import Nx.Defn

  # Business logic (shape validation) stays in regular Elixir;
  # the math itself lives in a numerical definition.
  def cross_entropy_loss(y_true, y_pred) do
    check_shape(Nx.shape(y_true), Nx.shape(y_pred))
    nx_cross_entropy_loss(y_true, y_pred)
  end

  defp check_shape(same_shape, same_shape), do: nil

  defp check_shape(_y_true_shape, _y_pred_shape),
    do: raise(ArgumentError, "shapes do not match")

  defn tanh_power(a, b) do
    exp = Nx.tanh(a) + Nx.pow(b, 2)
    print_expr(exp)
  end

  # Standard cross-entropy: -mean(y_true * log(y_pred))
  defn nx_cross_entropy_loss(y_true, y_pred) do
    Nx.negate(Nx.mean(y_true * Nx.log(y_pred)))
  end
end
a = Nx.tensor(0)
b = Nx.tensor(2)
NxTest.tanh_power(a, b)
# Validate that `a` is an Nx.Tensor struct (Nx itself is not a struct).
{is_struct(a, Nx.Tensor), is_struct(a, Nx)}
y_true = Nx.tensor([10, 10])
y_pred = Nx.tensor([100, 100])
NxTest.cross_entropy_loss(y_true, y_pred)
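# The next call raises ArgumentError because the shapes differ.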
y_true = Nx.tensor([10, 10])
y_pred = Nx.tensor([100])
NxTest.cross_entropy_loss(y_true, y_pred)

Nx Tip of the Week #2 – Tensor Operations for Elixir Programmers

In Elixir, it’s common to manipulate data using the Enum module.

Element-wise unary functions

Nx contains a number of element-wise unary functions that are tensor aware.

# Enum way
a = [1, 2, 3]
Enum.map(a, fn x -> :math.exp(x) end)
# Nx way (some operations, like Nx.map/3, require you to specify the output type).
a = Nx.tensor([1, 2, 3], type: {:s, 32}, names: [:data])
Nx.map(a, [type: {:f, 32}], fn x -> Nx.exp(x) end)
# Nx.map is not compatible with most Nx compilers.
# Most Nx functions are tensor aware and therefore
# efficient with most Nx compilers.
a = Nx.tensor([1, 2, 3], type: {:f, 32}, names: [:data])
Nx.exp(a)
a = Nx.iota({2, 2, 1, 2, 1, 2}, type: {:f, 32})
Nx.exp(a)

Element-wise binary functions

# With Enum
a = [1, 2, 3]
b = [4, 5, 6]
a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x + y end)
# In Nx there's no need to zip; binary operations are element-wise.

a = Nx.tensor([1, 2, 3], type: {:f, 32})
b = Nx.tensor([4, 5, 6], type: {:f, 32})

Nx.add(a, b)
a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], type: {:f, 32})
b = Nx.tensor([[[2, 3, 4], [5, 6, 7], [8, 9, 10]]], type: {:f, 32})
Nx.add(a, b)
broadcast_scalar = fn x, list -> Enum.map(list, & &1*x) end
broadcast_scalar.(5, [1, 2, 3])
broadcast_scalar = &Nx.multiply(&1, &2)
broadcast_scalar.(5, Nx.tensor([1, 2, 3], type: {:f, 32}))
Nx.multiply(5, Nx.tensor([1, 2, 3], type: {:f, 32}))

Aggregate Operators

# With Enum
a = [1, 2, 3]
Enum.reduce(a, 0, fn x, acc -> x + acc end)
# Multi-dimensional
a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

Enum.reduce(a, 0, fn x, acc ->
  Enum.reduce(x, 0, fn y, inner_acc ->
    y + inner_acc
  end) + acc
end)
a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

Enum.reduce(a, [], fn x, acc ->
  [
    Enum.reduce(x, 0, fn y, inner_acc ->
      y + inner_acc
    end)
    | acc
  ]
end)
|> Enum.reverse()
# Nx has Nx.reduce/4 but, as with Nx.map/3, avoid it and prefer
# the built-in aggregate functions (Nx.sum/2, Nx.mean/2, etc.).

a = Nx.tensor([1, 2, 3], type: {:f, 32})

Nx.reduce(a, 0, fn x, acc -> Nx.add(x, acc) end)
# Multiple dimensions
a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
Nx.sum(a)
# Through a specific axis
a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
Nx.sum(a, axes: [1])
Nx.sum(a, axes: [0])
# Multiple axes
Nx.sum(a, axes: [0, 1])
# Named axis
a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:x, :y, :z], type: {:f, 32})
Nx.sum(a, axes: [:x])
Nx.sum(a, axes: [:y])
Nx.sum(a, axes: [:z])
Nx.sum(a, axes: [:z, :y]) == Nx.sum(a, axes: [:y, :z])

Notice how in each case the axis that disappears is the one provided in axes. axes also supports negative indexing; however, you should generally prefer using named axes over integer axes wherever possible.
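
For example, reusing the named tensor above, summing over the last axis can be written with a negative index or, preferably, with its name:

# axes: [-1] refers to the last axis, which here is named :z
Nx.sum(a, axes: [-1]) == Nx.sum(a, axes: [:z])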

Nx Tip of the Week #3 – Many Ways to Create Arrays (Tensors)

In Nx, the fundamental type is the Tensor. You can think of a tensor as a multi-dimensional array, like the numpy.ndarray.

# Nx.tensor/2 is the most tempting way to create tensors, but you should usually try to avoid it (see below).
# 0D Tensor
Nx.tensor(1)
# 1D Tensor
Nx.tensor([1.0, 2.0, 3.0])
# ND Tensor
Nx.tensor([[[[[[[[[[1,2]]]]]]]]]])
# By default, Nx.tensor/2 creates a tensor of type s64 when the inputs
# are all integers and f32 when any of the inputs is a float.
Nx.tensor([1.0, 2])
# You can also specify the input type and dimension names.
Nx.tensor([1, 2, 3], type: {:bf, 16}, names: [:data])

# You can also specify a backend, but you must add it to your dependencies first.
# Nx.tensor([1, 2, 3], backend: Torchx.Backend)

Using Nx.tensor/2 is convenient, but is generally less efficient than other methods. Nx.Tensor generally represents tensor data as binaries, so Nx.tensor/2 needs to iterate through the entire list and rewrite it to a binary. You should avoid this, if possible.

From Binaries

Instead of creating tensors from lists, you should try to create tensors from binaries.

Tensor data is generally stored in a binary, so creating a tensor directly from a binary skips the list-to-binary conversion and is usually more efficient than creating it from other data types.

Nx.from_binary(<<0::64-signed-native>>, {:s, 64})
Nx.from_binary(<<0::32-float-native>>, {:f, 32})

# Notice Nx.from_binary/2 requires the input type and always returns a flat (1-D) tensor.

# Note: You’ll likely want to brush up on binary pattern matching, 
# creation, and manipulation as you work with Nx.
Nx.from_binary(<<1::64-float-native>>, {:f, 64})
Nx.from_binary(<<1::64-float-native>>, {:f, 32})
Nx.from_binary(<<1::64-float-native>>, {:s, 64})
Nx.from_binary(<<1::64-float-native>>, {:u, 8})
t = Nx.from_binary(<<1::64-float-native>>, {:f, 64})
Nx.reshape(t, {1, 1, 1, 1})

# Nx.reshape/2 is actually just a meta operation. 
# The implementation doesn’t move the underlying bytes at all, 
# it just changes the shape property of the input tensor.

# You just need to know the shape of the input data.
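
# As a quick sanity check, the underlying binary is identical
# before and after a reshape:
t = Nx.from_binary(<<1::64-float-native, 2::64-float-native>>, {:f, 64})
Nx.to_binary(t) == Nx.to_binary(Nx.reshape(t, {2, 1}))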

Broadcasting

# np.full, np.full_like -> Nx.broadcast
zeros = Nx.broadcast(0, {2, 5})
ones_like_zeros = Nx.broadcast(1, zeros)
# If you want to dictate the output type

Nx.broadcast(Nx.tensor(0, type: {:bf, 16}), {2, 2})

Counting Up

# Nx.iota/2 is similar to np.arange: it fills the given shape with a running index along an axis.

Nx.iota({2, 5}, axis: 1)
Nx.iota({2, 5})
Nx.iota({1}, type: {:bf, 16}, names: [:data])
Nx.multiply(Nx.iota({2, 5}, axis: 1), 3)
# Identity (eye) matrix
Nx.equal(Nx.iota({3, 3}, axis: 0), Nx.iota({3, 3}, axis: 1))

Random Numbers

key = Nx.Random.key(103)

{uniform, _new_key} = Nx.Random.uniform(key, shape: {2, 2}, type: {:f, 32})
{normal, _new_key} = Nx.Random.normal(key, 0, 5, shape: {2,2})
# create a random mask

probability = 0.5
{uniform, _new_key} = Nx.Random.uniform(key, shape: {5, 5})
Nx.select(Nx.less_equal(uniform, probability), 0, 1)

Template

Nx also has a template creation method that defines a template for an expected future value. This is useful for things like ahead-of-time compilation.

# You can't perform any operation on this tensor. 
# It exists exclusively to define APIs that say a tensor 
# with a certain type, shape, and names is expected in the future.
t = Nx.template({4, 4, 4}, {:f, 32}, names: [:x, :y, :z])
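
For example (a sketch, assuming the EXLA compiler from the setup cell), a template can stand in for the real input when compiling a function ahead of its first real call with Nx.Defn.compile/3:

compiled_sum = Nx.Defn.compile(&Nx.sum/1, [Nx.template({4, 4, 4}, {:f, 32})], compiler: EXLA)
# The compiled function now only accepts f32 tensors of shape {4, 4, 4}.
compiled_sum.(Nx.iota({4, 4, 4}, type: {:f, 32}))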

Nx Tip of the Week #4 – Using Keywords

Numerical definitions can only accept tensors or numbers as positional arguments; however, you can work around this inflexibility with keyword lists. You can pass and use optional keyword arguments in your numerical definitions with the keyword!/2 function.

Parameter Initializers

biases = Nx.broadcast(0.0, {2, 2})
defmodule NxKeywordTest do
  import Nx.Defn

  defn zeros(opts \\ []) do
    # keyword!/2 validates the given options and fills in defaults.
    opts = keyword!(opts, [:shape, type: {:f, 64}])
    Nx.broadcast(Nx.tensor(0, type: opts[:type]), opts[:shape])
  end

  defn trigonometric_fn(x, opts \\ []) do
    opts = keyword!(opts, mode: :sin)

    exp =
      case opts[:mode] do
        :sin -> Nx.sin(x)
        :cos -> Nx.cos(x)
        :tan -> Nx.tan(x)
      end

    print_expr(exp)
  end

  defn try_to_add_one(x, opts \\ []) do
    opts = keyword!(opts, add_one?: true)

    exp = if opts[:add_one?], do: Nx.add(x, 1), else: x
    print_expr(exp)
  end

  defn my_function(opts \\ []) do
    opts = keyword!(opts, [:value])
    opts[:value]
  end
end
NxKeywordTest.zeros(shape: {2, 2})
NxKeywordTest.zeros(shape: {2, 2}, type: {:bf, 16})

Passing Flags

NxKeywordTest.trigonometric_fn(0, mode: :sin)
NxKeywordTest.trigonometric_fn(0, mode: :cos)
NxKeywordTest.trigonometric_fn(0, mode: :tan)
NxKeywordTest.try_to_add_one(102, add_one?: true)
NxKeywordTest.try_to_add_one(1, add_one?: false)

Performance Considerations

Numerical definitions are JIT compiled and cached based on argument shapes (at least with the EXLA compiler) to avoid unnecessary recompilations. Compilation can be expensive, so you’d like to reuse compiled computations as much as possible.

When using keywords, if you have a value that’s constantly changing, you will force recompilation with the new value every time.

NxKeywordTest.my_function(value: 0.3)
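
If a value changes often, pass it as a tensor argument instead of a keyword option. A minimal sketch (the NxKeywordPerf module and my_function_with_arg are hypothetical names for illustration):

defmodule NxKeywordPerf do
  import Nx.Defn

  # The value arrives as a tensor, so new values reuse the cached compilation
  # as long as the shape and type stay the same.
  defn my_function_with_arg(value), do: value
end

NxKeywordPerf.my_function_with_arg(Nx.tensor(0.3))
NxKeywordPerf.my_function_with_arg(Nx.tensor(0.7))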

Nx Tip of the Week #5 – Named Tensors

One of my biggest frustrations when working with NumPy and TensorFlow is dealing with axes. Named tensors fix this.

t1 = Nx.tensor([[1, 2, 3]], names: [:batch, :data])
t2 = Nx.tensor([[0, 1, 2]], names: [:batch, :data])
mse = fn x, y -> Nx.mean(Nx.subtract(x, y) |> Nx.pow(2), axes: [:data]) end
mse.(t1, t2)
defmodule NxNamedTest do
  import Nx.Defn

  defn flip_left_right(x) do
    x |> Nx.reverse(axes: [:width])
  end

  defn channels_first(x) do
    x |> Nx.transpose(axes: [:batch, :channels, :height, :width])
  end

  defn channels_last(x) do
    x |> Nx.transpose(axes: [:batch, :height, :width, :channels])
  end
end

Broadcasting

t1 = Nx.tensor([[1, 2, 3]], names: [:batch, :data])
t2 = Nx.tensor([[1, 2, 3]], names: [:x, :y])
# Since the axis names don't match, the operation will fail.
Nx.add(t1, t2)
# To perform a binary operation on named tensors, the names must align:
# for each axis, the names must either match or one of them must be nil
# (treated as a wildcard).

t1 = Nx.tensor([[1, 2, 3]], names: [:batch, nil])
t2 = Nx.tensor([[1, 2, 3]], names: [nil, :data])

dbg({t1, t2})

Nx.add(t1, t2)

# Most operations have specific name rules that validate the operation
# can be performed correctly on the given tensors.

Nx Tip of the Week #6 – Compiler or Backend?

If performance matters, benchmark and decide. If you need flexibility or want to prototype quickly and not sacrifice speed, backends are a good choice.

If you need AOT (Ahead of Time) compilation or your programs are very computationally intensive, compilers work better. Library writers should always write library functions in defn to leave this choice to the user.

The power of Nx is in its flexibility. While the library contains pure Elixir implementations of every function in the main API, Nx is designed to integrate with compilers and backends that provide highly-optimized, native tensor manipulation routines.

Implementations of the Nx API in the Nx module are actually meta-implementations that dispatch to third-party implementations of the same function at runtime. These meta-implementations act like contracts for functions in the Nx API.

So, you can arbitrarily switch backends and somewhat guarantee your Nx implementations will remain unchanged.
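
For example (a sketch, assuming EXLA from the setup cell), you can switch the default backend without touching any Nx call sites:

# Tensors created from here on are allocated on the EXLA backend.
Nx.default_backend(EXLA.Backend)
t = Nx.tensor([1.0, 2.0, 3.0])
# The same Nx code runs unchanged; tensors can also be transferred explicitly.
t = Nx.backend_transfer(t, Nx.BinaryBackend)
# Reset the default so later cells keep using the binary backend.
Nx.default_backend(Nx.BinaryBackend)
t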

defn adds another layer of indirection to program execution. While calls to Nx outside of defn dispatch to backends that eagerly evaluate the computation, calls to Nx within defn build an expression (Nx.Defn.Expr) instead.

One thing you’ll notice when working with numerical definitions is that the first invocation of a defn is much slower than subsequent calls.

That’s because the first invocation needs to do the work of compiling the program – depending on the size of the computation, this may require lots of time relative to program execution.

Subsequent calls are cached; however, if you change the shape or type of your inputs, the numerical definition needs to compile a new version of your program specific to the input shape and type. If your input shapes or types are constantly changing, your bottleneck will almost certainly be compilation time.
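
A rough way to observe this (a sketch; absolute numbers depend on your machine) is to time the first and second invocation of a jitted function on inputs of the same shape:

jitted = Nx.Defn.jit(fn x -> Nx.sum(Nx.exp(x)) end, compiler: EXLA)
input = Nx.iota({1000}, type: {:f, 32})

{first_us, _result} = :timer.tc(fn -> jitted.(input) end)
{second_us, _result} = :timer.tc(fn -> jitted.(input) end)
# The first call includes compilation time; the second hits the cache.
{first_us, second_us}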

An additional pitfall of the compiled approach is the strictness of the syntax within defn. All inputs must be tensors (unless using keywords), you must match on tuple inputs, and you can use a limited subset of Elixir.
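
For example, tuple inputs must be destructured with pattern matching (a small sketch; NxTupleTest is a hypothetical module for illustration):

defmodule NxTupleTest do
  import Nx.Defn

  # Tuple arguments are pattern matched in the function head; arbitrary
  # Elixir data structures are not allowed inside defn.
  defn add_pair({a, b}), do: a + b
end

NxTupleTest.add_pair({Nx.tensor(1), Nx.tensor(2)})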

You can only use grad from within defn.
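
A minimal sketch (NxGradTest is a hypothetical module for illustration):

defmodule NxGradTest do
  import Nx.Defn

  # grad/2 is only available inside defn; d/dx sin(x) = cos(x).
  defn grad_of_sin(x), do: grad(x, &Nx.sin/1)
end

# Returns cos(0.0) == 1.0
NxGradTest.grad_of_sin(Nx.tensor(0.0))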

Backends like Torchx offer flexibility and facilitate rapid prototyping, with pleasing performance gains over pure Elixir implementations. Compilers like EXLA unlock even more potential performance and memory optimizations, and open the door to things like ahead-of-time compilation. Generally, if you need the flexibility of intermixing Nx code with your regular Elixir code, using a backend is probably the more convenient option. Additionally, some Nx programs are so small that they won't benefit from compiler optimizations.

When dealing with compute intensive, purely numerical programs, compilers are usually the better option.

Axon: Deep Learning in Elixir

Axon consists of the following components:

  • Functional API – A low-level API of Elixir defn functions on which all other APIs build.

  • Model Creation API – A high-level model creation API which manages model initialization and application.

  • Optimization API – An API for creating and using first-order optimization techniques.

  • Training API – An API for quickly training models.

Functional API

At the lowest-level, Axon consists of a number of modules with functional implementations of common methods in deep learning:

  • Axon.Activations – Element-wise activation functions.
  • Axon.Initializers – Model parameter initialization functions.
  • Axon.Layers – Common deep learning layer implementations.
  • Axon.Losses – Common loss functions.
  • Axon.Metrics – Training metrics such as accuracy, absolute error, precision, etc.

Model Creation

The goal of the model creation API is to eliminate most of the boilerplate associated with creating, initializing, and applying models. Axon represents a model's computation graph with the following struct:

defstruct [:inputs, :outputs, :nodes]
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

dbg(model)

Axon.Display.as_table(model, Nx.template({1, 784}, :f32)) |> IO.puts
# you can use regular Elixir functions to represent model building blocks, and compose them.
defmodule MyModel do
  def residual(x, units) do
    x
    |> Axon.dense(units, activation: :relu)
    |> Axon.add(x)
  end
   
  def model() do
    Axon.input("input", shape: {nil, 784})
    |> Axon.dense(128, activation: :relu)
    |> residual(128)
    |> Axon.dense(10, activation: :softmax)
  end
end
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

params = init_fn.(Nx.template({1, 784}, :f32), %{})

Model Optimization

Axon’s model optimization API takes the same approach as that taken in DeepMind’s Optax. The goal of the API is to provide low-level constructs for creating advanced optimizers, and then to provide high-level optimizers built on top of that API.

Axon considers optimizers as the tuple: {init_fn/1, update_fn/2}. The init_fn/1 function accepts a model’s parameters and initializes the optimizer’s state. The update_fn/2 function accepts “updates” (most commonly gradients in gradient-based optimization) and an optimizer state and returns transformed updates and a new optimizer state.
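
To make that contract concrete, here is a hand-rolled SGD written in plain Elixir and Nx against the shape described above (a sketch for illustration, not Axon's actual implementation):

sgd = fn learning_rate ->
  init_fn = fn _params -> %{} end

  update_fn = fn gradients, state ->
    # Scale the gradients by the negative learning rate to produce updates.
    updates = Map.new(gradients, fn {k, g} -> {k, Nx.multiply(g, -learning_rate)} end)
    {updates, state}
  end

  {init_fn, update_fn}
end

{init_fn, update_fn} = sgd.(0.1)
params = %{w: Nx.tensor([1.0, 2.0, 3.0])}
grads = %{w: Nx.tensor([0.5, 0.5, 0.5])}

opt_state = init_fn.(params)
{updates, _opt_state} = update_fn.(grads, opt_state)
# Applying the updates is a simple element-wise addition.
Map.new(params, fn {k, p} -> {k, Nx.add(p, updates[k])} end)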

Training

The purpose of the training API is to provide conveniences and common routines for implementing training loops. Autograd simplifies the task of writing a machine learning algorithm down to the task of writing a differentiable objective function.

If you can write a parameterized, differentiable objective function, and pair that with data, you can make use of Axon’s training API.
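
A minimal sketch of the training API, assuming the model built earlier in this notebook and a toy in-memory stream of {input, label} batches (the labels are random, just to exercise the loop):

key = Nx.Random.key(42)
{inputs, new_key} = Nx.Random.uniform(key, shape: {8, 784})
{labels, _key} = Nx.Random.uniform(new_key, shape: {8, 10})
data = Stream.repeatedly(fn -> {inputs, labels} end)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.run(data, %{}, epochs: 1, iterations: 5)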

Nx Tip of the Week #7 – Using Nx.Defn.jit

There are two ways in Nx to accelerate your numerical definitions: setting a default compiler for defn (for example with Nx.Defn.default_options(compiler: EXLA)), or wrapping a function with Nx.Defn.jit/2.

Using Nx.Defn.jit/2 is very useful when integrating numerical definitions with regular Elixir code.

defmodule JIT do
  import Nx.Defn
 
  defn softmax(x) do
    max_val = Nx.reduce_max(x)
    Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
  end
end

key = Nx.Random.key(103)
{uniform_1, new_key} = Nx.Random.uniform(key, shape: {5})
{uniform_2, _new_key} = Nx.Random.uniform(new_key, shape: {5})

jit_fn = Nx.Defn.jit(&JIT.softmax/1, compiler: EXLA)
 
IO.inspect jit_fn.(uniform_1)
IO.inspect JIT.softmax(uniform_2)

Nx Tip of the Week #8 – Using Nx.Defn.aot/3

Note: Nx.Defn.aot/3 has since been removed from Nx, so the example below is kept for reference and will not run with the versions installed above.

Ahead-of-time compilation allows you to compile your numerical definitions into compact NIFs and execute them without needing the entire EXLA compiler and runtime.

Nx.Defn.aot/3 defines a module to interact with the NIF for you, so most of the work is out of your hands.

defmodule MyDefn do
  import Nx.Defn
  defn softmax(x) do
    max_val = Nx.reduce_max(x)
    Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
  end
end
# Nx.Defn.aot(MyModule, [{:softmax, &MyDefn.softmax/1, [Nx.template({100}, {:f, 32})]}], compiler: EXLA)

There are, of course, tradeoffs to using AOT compiled functions. First, you need to know the type and shape of your inputs ahead-of-time. Depending on your needs, this can be a pretty serious limitation. Second, AOT compilation is only supported for CPUs, so you can’t take advantage of GPU or TPU acceleration.

Overall, AOT compilation can be an excellent choice for deploying a model – especially at the edge. You can experiment on a more powerful machine before exporting the smaller compiled module to an edge device. I hope this gives you the tools you need to experiment with AOT compilation.

Nx Tip of the Week #9 – Window Functions

Note: the body of a while loop always executes sequentially and will round-trip data back to the CPU even when running on a GPU, so window functions are often a better fit for this kind of computation.

defmodule NxWindowsTest do
  def cumsum(tensor, opts \\ []) do
    axis = opts[:axis]

    {padding_config, strides, window_shape} = create_windows(tensor, axis)

    Nx.window_sum(tensor, window_shape, strides: strides, padding: padding_config)
  end

  def create_windows(tensor, axis) do
    n = elem(Nx.shape(tensor), axis)

    padding_config =
      for i <- 0..(Nx.rank(tensor) - 1) do
        if i == axis, do: {n - 1, 0}, else: {0, 0}
      end

    strides = List.duplicate(1, Nx.rank(tensor))

    window_shape =
      List.duplicate(1, Nx.rank(tensor))
      |> List.to_tuple()
      |> put_elem(axis, n)

    {padding_config, strides, window_shape}
  end
end
revenue_by_month = Nx.tensor([[100, 200, 100, 50, 150]])
NxWindowsTest.cumsum(revenue_by_month, axis: 1)
tensor = Nx.tensor([
  [1, 2, 3, 4],
  [5, 6, 7, 8],
  [9, 10, 11, 12],
  [13, 14, 15, 16]
])

dbg(tensor)

# Apply a 3x3 max window with strides of [2, 1]
windowed_max = Nx.window_max(tensor, {3, 3}, strides: [2, 1])

Nx Tip of the Week #10 – Using Nx.select

You can solve a problem that you think requires boolean indexing with Nx.select. Nx.select builds a tensor from 3 tensors:

  • A mask
  • A true tensor
  • A false tensor

Nx.select will choose values from the true tensor when corresponding values in the mask are true and values from the false tensor when they are not.

a = Nx.tensor([[-1, 2, -3], [4, -5, 6]])
greater_than_a = Nx.greater(a, 0)
dbg({a, greater_than_a})
non_negative_a = Nx.select(greater_than_a, a, 0)

Nx Tip of the Week #11 – While Loops

Note: try to avoid while loops when a built-in, vectorized operation can do the job.

Nx has a while construct which is implemented as an Elixir macro. The while construct takes an initial state, a condition, and a body which returns a shape which is the same as the initial state. It’s essentially a reduce_while, which aggregates state while some condition is satisfied.

defmodule NxWhileTest do
  import Nx.Defn

  defn count_to_ten() do
    # Similar to a C-style loop: while(initial_state, condition, do: body)
    while(i = 0, i < 10, do: i + 1)
  end

  defn count_to_ten_twice() do
    while({i = 0, j = 0}, i < 10, do: {i + 1, j + 1})
  end

  # The shape of the body of the while loop must match 
  # the shape of the initial condition.

  defn build_a_vector() do
    # Create a "filler" tensor
    initial_tensor = Nx.broadcast(0.0, {12})
    key = Nx.Random.key(103)

    {_, final_tensor, _key} =
      while {i = 0, initial_tensor, key}, i < 12 do
        {val, new_key} = Nx.Random.uniform(key, shape: {1})
        # Update filler tensor "in-place"
        {i + 1, Nx.put_slice(initial_tensor, [i], val), new_key}
      end

    final_tensor
  end
end
NxWhileTest.count_to_ten()
NxWhileTest.count_to_ten_twice()
NxWhileTest.build_a_vector()

# The native defn loops you implement will be much more efficient than pure
# Elixir loops performing the same operations, but they'll still be slow
# relative to other tensor operations, because loops are difficult for
# JIT compilers to optimize.

Nx Tip of the Week #12 – Nx.to_heatmap

Sometimes you want to quickly visualize the contents of a tensor. A quick way to visualize images across a single color channel is with Nx.to_heatmap.

When inspecting the result of Nx.to_heatmap, you’ll get a nice console representation of a heatmap printed out. This is especially useful when you’re quickly debugging and don’t want to bring in any additional dependencies such as VegaLite.

tensor = Nx.tensor([
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9]
])

heatmap = Nx.to_heatmap(tensor)

Nx Tip of the Week #13 – Hooks

Part of the restrictiveness of defn is the inability to debug in the same way you would debug a normal Elixir function.

defmodule NxHooksTest do
  import Nx.Defn
  require Logger

  # Calling IO.inspect/1 (or any other plain Elixir function) is not
  # supported inside defn, so a definition like this fails to compile:
  #
  # defn my_impossible_fn(x) do
  #   x |> Nx.exp() |> IO.inspect()
  # end

  # Like IO.inspect, print_value will return the value 
  # it’s passed and inspect the input contents to the console. 
  # Note: inspect_value has been removed.
  defn my_possible_fn(x) do
    x |> Nx.exp() |> print_value()
  end

  # Nx.Defn.Kernel.hook allows you to perform side-effecting operations within defn. 
  defn my_hooked_function(x) do
    x |> Nx.exp() |> hook(&Logger.info/1)
  end

  defn my_defined_hooked_function(x) do
    x |> Nx.exp() |> hook(:my_hook)
  end
end
tensor = Nx.tensor([1])
# NxHooksTest.my_impossible_fn(tensor) would fail to compile, as noted above.
tensor = Nx.tensor([1])
NxHooksTest.my_possible_fn(tensor)
NxHooksTest.my_hooked_function(tensor)
hooks = %{my_hook: &IO.inspect/1}
jit_fn = Nx.Defn.jit(&NxHooksTest.my_defined_hooked_function/1, hooks: hooks)
# A hooked value must flow into the value returned by defn,
# otherwise the hook will never execute.
# Hooks are implemented on top of the Nx.Stream API.
jit_fn.(tensor)

Nx Tip of the Week #14 – Slicing and Indexing

Nx offers a few different slicing and indexing routines which allow you to accomplish most of what you would want to do. Slicing can be a bit tricky given static shape requirements, but you usually can work around limitations.

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
dbg(a) 
Nx.slice(a, [0, 1], [1, 2])
tensor = Nx.tensor([
  [1, 2, 3, 4],
  [5, 6, 7, 8],
  [9, 10, 11, 12],
  [13, 14, 15, 16]
])
sub_tensor = Nx.slice(tensor, [0, 0], [2, 2])
# You may specify dynamic or tensor values for the start index
# as long as they have a scalar shape:
Nx.slice(a, [Nx.tensor(0), Nx.tensor(1)], [1, 2])
# The second list is the length of each slice. Each value must be known or static at compile-time. 
# This is because the length dictates the final shape of the sliced tensor

dbg(a)
Nx.slice(a, [1, 2], [2, 3])
# You can also make use of the Access syntax which builds on top of normal slice operations.
a[[1, 0..1]]
# You're accessing index 1 of the first dimension (the second row)
# and indices 0 through 1 of the second dimension.
# In newer versions of Elixir, you can slice an entire axis with `..`
a[[.., 2]]
# This takes every element along the first dimension,
# but only index 2 of the second dimension.
# This takes all elements along both dimensions.
a[[.., ..]]
# If Nx.slice/4 and the Access syntax is not flexible enough for you, 
# you can try one of Nx.take/3, Nx.gather/3, and Nx.take_along_axis/3. 

tokens = Nx.tensor([127, 32, 0, 1, 5, 6])
key = Nx.Random.key(103)

{weights, _new_key} = Nx.Random.uniform(key, shape: {128, 32})

Nx.take(weights, tokens)
a = Nx.tensor([[4, 3, 0, 1, 6, 8], [5, 1, 2, 3, 6, 9]])
indices = Nx.argsort(a, axis: 1, stable: true)
dbg({a, indices})
Nx.take_along_axis(a, indices, axis: 1)
a = Nx.tensor([5, 1, 2, 3, 6, 9])
indices = Nx.argsort(a, direction: :desc)
# Use the sorted indices to reorder the original tensor.
Nx.take(a, indices)
# Nx.gather/3 takes a tensor of indices, where the last axis of the
# indices tensor represents a single coordinate in the source tensor.
t = Nx.tensor([[1, 2], [3, 4]])
Nx.gather(t, Nx.tensor([[[1, 1], [0, 1], [1, 0], [1, 0]], [[1, 1], [0, 1], [1, 0], [1, 0]]]))