Convolution with Nx

Mix.install(
  [
    {:nx, "~> 0.9.2"},
    {:scidata, "~> 0.1.11"},
    {:exla, "~> 0.9.2"},
    {:kino, "~> 0.14.2"},
    {:stb_image, "~> 0.6.9"},
    {:axon, "~> 0.7.0"},
    {:kino_vega_lite, "~> 0.1.11"},
    {:nx_signal, "~> 0.2.0"},
    # {:emlx, github: "elixir-nx/emlx", override: true}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
  #config: [nx: [default_backend: {EMLX.Backend, device: :gpu}]]
)

# Nx.Defn.default_options(compiler: EMLX)

What is a convolution?

Formally, the convolution product is a mathematical operation that can be understood as a sum of “sliding products” between two tensors.
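
For two 1D sequences $f$ (the signal) and $g$ (the kernel), both taken to be zero outside of their sample points, this sum of “sliding products” can be written as:

$$
(f * g)[n] = \sum_{k} f[k] \cdot g[n-k]
$$

For instance, with $f = [1, 2, 3]$ and $g = [1, 1]$, we get $f * g = [1, 3, 5, 3]$.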

This operation is widely used in signal processing and has applications in various fields such as image processing, audio filtering, and machine learning.

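We will use Nx.conv in the context of convolution layers in a model. This means that the coefficients of the kernel are learnt.

If we impose the kernel ourselves instead, keep in mind that Nx.conv slides the kernel without flipping it (strictly speaking, a cross-correlation). It therefore coincides with the mathematical convolution only if the kernel is real and symmetric.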

How to use Nx.conv

You have the following parameters:

  • padding
  • shape
  • stride
  • permutation

The shape is expected to be: {batch_size, channel, dim_1, ...}

The stride parameter defines how the kernel slides along the axes. A stride of 1 means the kernel moves one step at a time, while larger strides result in skipping over positions.
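
As a small illustration of the effect of the stride, here is a sketch on a toy 1D signal. It assumes Nx is available from the setup cell above; the names input, kernel, stride_1 and stride_2 are only for this example.

```elixir
# A 1D signal of 8 points, shaped {batch_size, channel, dim_1}
input = Nx.iota({1, 1, 8}, type: :f32)
# A kernel of 3 ones, shaped {output_channels, input_channels, dim_1}
kernel = Nx.broadcast(1.0, {1, 1, 3})

# stride 1: the kernel starts at positions 0, 1, 2, 3, 4, 5
stride_1 = Nx.conv(input, kernel, strides: [1], padding: :valid)
# stride 2: the kernel starts at positions 0, 2, 4
stride_2 = Nx.conv(input, kernel, strides: [2], padding: :valid)

{Nx.shape(stride_1), Nx.shape(stride_2)}
# => {{1, 1, 6}, {1, 1, 3}}
```

With a stride of 2, the output holds every second value of the stride-1 result.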

The batch_size dimension is crucial when working with colored images: by placing the color channels in it, you convolve each color channel separately while maintaining the output’s color information.

In contrast, for a convolution layer in a neural network, you might average all color channels and work with a grayscale image instead. This reduces computational complexity while still capturing essential features.

Example with a vector

If t is a vector, i.e. Nx.shape(t) = {n}, you can use t as an input to a convolution by doing Nx.reshape(t, {1, 1, n}). This respects the pattern:

batch_size=1, channel=1, dim_1=n.
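
A minimal sketch of this reshape, assuming Nx from the setup cell above. The vector and the two-point kernel are arbitrary values chosen for the illustration.

```elixir
t = Nx.tensor([1.0, 2.0, 3.0, 4.0])
{n} = Nx.shape(t)

# batch_size=1, channel=1, dim_1=n
input = Nx.reshape(t, {1, 1, n})
# the kernel follows the same 3D pattern
kernel = Nx.tensor([1.0, 1.0]) |> Nx.reshape({1, 1, 2})

# without padding, each output is the sum of two neighbouring values
Nx.conv(input, kernel) |> Nx.to_flat_list()
# => [3.0, 5.0, 7.0]
```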

Example with an image

For example, you have a tensor that represents an RGBA image (coloured). Its shape is typically:

{h,w,c} = {256,256,4}.

You may want to use 4 as the batch_size, to batch per colour layer. To do so, we permute the axes of the tensor:

  • you add a new dimension: Nx.new_axis(t, 0). The new shape of the tensor is {1, 256, 256, 4}.
  • you set the input_permutation option to [3, 0, 1, 2], because the second dimension must be the channel. This is why we had to add a “fake” dimension: it lets us respect the pattern:

batch_size=4, channels=1, dim_1 = 256, dim_2=256
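
To visualise the target layout, the sketch below applies the same axis order with an explicit Nx.transpose on a stand-in tensor. It only illustrates the shapes and assumes Nx from the setup cell above; in the convolution itself, the permutation is passed as an option.

```elixir
# a stand-in for an RGBA image of shape {h, w, c}
t = Nx.broadcast(Nx.tensor(0, type: :u8), {256, 256, 4})

with_axis = Nx.new_axis(t, 0)
permuted = Nx.transpose(with_axis, axes: [3, 0, 1, 2])

{Nx.shape(with_axis), Nx.shape(permuted)}
# => {{1, 256, 256, 4}, {4, 1, 256, 256}}
```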

Padding a tensor

You will notice that we “pad” the tensors. This means that we add zeros to the signal in order to build an input that is zero outside of its $N$ sample points. This is done below with padding: :same, in the polynomial product, and in the image filtering.

Here are some examples of how to use Nx.pad with a 1D tensor. Each axis takes a tuple {low, high, interior}: the number of values to add before, after, and between the elements.

t = Nx.tensor([1,2,3])
one_zero_left = Nx.pad(t, 0, [{1,0,0}])
two_5_right = Nx.pad(t, 5, [{0,2,0}])

[t, one_zero_left, two_5_right] 
|> Enum.map(&Nx.to_list/1)
[[1, 2, 3], [0, 1, 2, 3], [1, 2, 3, 5, 5]]
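
The third element of each tuple, left at 0 so far, inserts values between the elements, and a negative width removes elements instead of adding them. A short sketch, assuming Nx from the setup cell above:

```elixir
t = Nx.tensor([1, 2, 3])

# interior padding: one zero between each pair of elements
interior = Nx.pad(t, 0, [{0, 0, 1}])
# negative low padding: drops the first element
cropped = Nx.pad(t, 0, [{-1, 0, 0}])

[interior, cropped] |> Enum.map(&Nx.to_list/1)
# => [[1, 0, 2, 0, 3], [2, 3]]
```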

With a 2D tensor, you need one padding tuple per axis. For example, take:

m = Nx.tensor([[1,2],[3,4]])
#Nx.Tensor<
  s32[2][2]
  EXLA.Backend
  [
    [1, 2],
    [3, 4]
  ]
>

We will successively:

  • add a row of zeros before the first row and after the last row
  • add a column of zeros before the first column and after the last column
  • “surround” the tensor with zeroes

{Nx.pad(m, 0, [{1,1,0}, {0,0,0}]), Nx.pad(m, 0, [{0,0,0}, {1,1,0}]), 
  Nx.pad(m, 0, [{1,1,0}, {1,1,0}])}
{#Nx.Tensor<
   s32[4][2]
   EXLA.Backend
   [
     [0, 0],
     [1, 2],
     [3, 4],
     [0, 0]
   ]
 >,
 #Nx.Tensor<
   s32[2][4]
   EXLA.Backend
   [
     [0, 1, 2, 0],
     [0, 3, 4, 0]
   ]
 >,
 #Nx.Tensor<
   s32[4][4]
   EXLA.Backend
   [
     [0, 0, 0, 0],
     [0, 1, 2, 0],
     [0, 3, 4, 0],
     [0, 0, 0, 0]
   ]
 >}

Example: edge detector with a discrete differential kernel

The MNIST dataset contains images of handwritten digits from 0 to 9. Each image has a format of 28x28 pixels with 1 channel (grey levels from 0 to 255), organized row-wise.

This means that the pixel $(i,j)$ is at the $i$-th row and $j$-th column.

# note: we download the MNIST *test* set (10,000 images)
{{images_binary, type, images_shape} = _test_images, test_labels} = Scidata.MNIST.download_test()
{labels, {:u,8}, label_shape} = test_labels

images = images_binary |> Nx.from_binary(type) |> Nx.reshape(images_shape)

{{_nb, _channel, _height, _width}, _type}  = {images_shape, type}
{{10000, 1, 28, 28}, {:u, 8}}

The labels describe what each image is. The image at index 15 is a “5”.

Nx.from_binary(labels, :u8)[15] |> Nx.to_number()
5

To be sure, we display this image with StbImage.

> We transform the tensor as StbImage expects the format {height, width, channel} whilst the MNIST uses {channel, height, width}.

five = images[15] 

Nx.transpose(five, axes: [1,2,0])
|> StbImage.from_nx()
|> StbImage.resize(500, 500)
%StbImage{
  data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
  shape: {500, 500, 1},
  type: {:u, 8}
}

We can detect edges by performing a discrete differentiation, where the kernel computes the difference between adjacent pixel values. This can be viewed as a discrete approximation of a derivative.

Since the image is slightly blurred, we compare pixels that are two units apart instead of immediate neighbours. The kernel is defined as $[-1,0,1]$: it computes the difference between the two pixels on either side of the current one (a central difference). Note that this spacing comes from the kernel itself; the strides option of Nx.conv keeps its default value of 1.

This operation can be applied along the width axis to detect vertical edges or along the height axis to detect horizontal edges.

Given pixel coordinates $(i,j)$ where $i$ is the row and $j$ the column, and with padding: :same so that the output is centred on $(i,j)$, we compute:

  • vertical edges with horizontally neighbouring pixels: $\mathrm{pix}(i, j-1) \cdot (-1) + \mathrm{pix}(i, j+1) \cdot (1)$, and the kernel is reshaped into {1, 1, 1, 3}.

  • horizontal edges with vertically neighbouring pixels: $\mathrm{pix}(i-1, j) \cdot (-1) + \mathrm{pix}(i+1, j) \cdot (1)$, and the kernel is reshaped into {1, 1, 3, 1}.
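
To see the kernel at work before using a real image, here is a sketch on a single toy row of pixels with a bright band in the middle. It assumes Nx from the setup cell above; the row values are made up for the illustration.

```elixir
# one row of 6 pixels, shaped {batch_size, channel, height, width}
row =
  Nx.tensor([0.0, 0.0, 10.0, 10.0, 0.0, 0.0])
  |> Nx.reshape({1, 1, 1, 6})

kernel = Nx.tensor([-1.0, 0.0, 1.0]) |> Nx.reshape({1, 1, 1, 3})

# each output is (right neighbour) - (left neighbour), with zeros outside the row
Nx.conv(row, kernel, padding: :same) |> Nx.to_flat_list()
# => [0.0, 10.0, 10.0, -10.0, -10.0, 0.0]
```

The response is positive where the intensity rises, negative where it falls, and zero in flat regions.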

Below, we compute two transformed images using these vertical and horizontal edge detectors and display the results.

edge_kernel = Nx.tensor([-1, 0, 1]) 

vertical_detection = Nx.conv(
  Nx.reshape(five, {1, 1, 28, 28}) |> Nx.as_type(:f32), 
  Nx.reshape(edge_kernel, {1, 1, 1, 3}), 
  padding: :same
)

horizontal_detection = Nx.conv(
  Nx.reshape(five, {1, 1, 28, 28}) |> Nx.as_type(:f32), 
  Nx.reshape(edge_kernel, {1, 1, 3, 1}),
  padding: :same
)

Nx.concatenate([vertical_detection, horizontal_detection], axis: -1)
|> Nx.reshape({28, 56, 1}) 
|> Nx.clip(0, 255) # negative differences do not fit in :u8
|> Nx.as_type({:u, 8})
|> StbImage.from_nx()
|> StbImage.resize(500, 1000)
%StbImage{
  data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
  shape: {500, 1000, 1},
  type: {:u, 8}
}

Using an Axon model with a convolution layer to predict vertical edges

Let’s use Axon to recover the convolution kernel. Our model will consist of a single convolution layer.

We will give our model a series of images and their transforms (by the “known” kernel).

We use the first 15 images (indices 0 to 14), which leaves out the “5” at index 15 that we keep for the final check. We apply the “vertical” convolution transform to these images and feed them to the model to help it “learn” how to detect vertical edges.

# we produce a sample of 15 tuples {input, output = conv(input, kernel)}
# from the images at indices 0..14 (the "5" at index 15 is held out)
l = 14
kernel = Nx.tensor([-1.0,0.0,1.0])  |> Nx.reshape({1,1,1,3}) 

training_data = 
  images[0..l] 
  |> Nx.reshape({l+1, 1,28,28})
  |> Nx.as_type(:f32)
  |> Nx.to_batched(1)
  |> Stream.map(fn img -> 
    {img, Nx.conv(img, kernel, padding: :same)} 
  end)
#Stream<[
  enum: 0..14,
  funs: [#Function<50.105594673/1 in Stream.map/2>, #Function<50.105594673/1 in Stream.map/2>]
]>

The model is a simple convolution layer:

model = 
  Axon.input("x", shape: {nil, 1, 28, 28})
  |> Axon.conv(1, 
    channels: :first, 
    padding: :same, 
    kernel_initializer: :zeros, 
    use_bias: false, 
    kernel_size: 3
  )
#Axon<
  inputs: %{"x" => {nil, 1, 28, 28}}
  outputs: "conv_0"
  nodes: 2
>

We train the model with the dataset. Note that it is important to use compiler: EXLA for the computation speed.

optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2)
#optimizer = Polaris.Optimizers.adabelief(learning_rate: 1.0e-2)

params = 
  Axon.Loop.trainer(model, :mean_squared_error, optimizer)
  |> Axon.Loop.run(training_data, Axon.ModelState.empty(), epochs: 20,  compiler: EXLA)

16:41:17.507 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 0, loss: 0.0000000
Epoch: 1, Batch: 0, loss: 4095.9575195
Epoch: 2, Batch: 0, loss: 2760.0224609
Epoch: 3, Batch: 0, loss: 2029.3372803
Epoch: 4, Batch: 0, loss: 1624.5377197
Epoch: 5, Batch: 0, loss: 1374.6849365
Epoch: 6, Batch: 0, loss: 1202.7775879
Epoch: 7, Batch: 0, loss: 1075.4449463
Epoch: 8, Batch: 0, loss: 976.2529907
Epoch: 9, Batch: 0, loss: 895.9296265
Epoch: 10, Batch: 0, loss: 828.9421387
Epoch: 11, Batch: 0, loss: 771.8013916
Epoch: 12, Batch: 0, loss: 722.1952515
Epoch: 13, Batch: 0, loss: 678.5291138
Epoch: 14, Batch: 0, loss: 639.6677246
Epoch: 15, Batch: 0, loss: 604.7778931
Epoch: 16, Batch: 0, loss: 573.2318115
Epoch: 17, Batch: 0, loss: 544.5447998
Epoch: 18, Batch: 0, loss: 518.3340454
Epoch: 19, Batch: 0, loss: 494.2913818
#Axon.ModelState<
  Parameters: 9 (36 B)
  Trainable Parameters: 9 (36 B)
  Trainable State: 0, (0 B)
>

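We can now check what our Axon model learnt on the “5” at index 15, which was not part of the training sample. We display side by side:

  • the image predicted by our model (left)
  • its “vertical” convolution with the known kernel (right)
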
five =  Nx.reshape(images[15], {1,1,28,28}) |> Nx.as_type(:f32)

model_five = 
  Axon.predict(model, params, five)

conv_five = Nx.conv(
  Nx.reshape(images[15] , {1, 1, 28, 28}) |> Nx.as_type(:f32), 
  Nx.reshape(edge_kernel , {1, 1, 1, 3}), 
  padding: :same
)

Nx.concatenate([model_five, conv_five], axis: -1)
|> Nx.reshape({28, 56, 1}) 
|> Nx.clip(0,255)
|> Nx.as_type(:u8)
|> StbImage.from_nx()
|> StbImage.resize(500, 1500)
%StbImage{
  data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
  shape: {500, 1500, 1},
  type: {:u, 8}
}

Convolution live example

We aim to apply a blurring filter to our webcam feed.

This is achieved by convolving the image with a Gaussian kernel, which computes a weighted average (the “Gaussian mean”) of the pixels within the kernel’s area.

Most of the tensor manipulations used below are explained in this post:

The gaussian kernel

Let’s see how the “blur” kernel is built.

  • produce two vectors of integers x and y as [-2, -1, 0, 1, 2].
  • compute the Gaussian coefficients for the 5x5=25 pairs $(x_i,y_j)$: $$ \dfrac1{2\pi\sigma^2}\cdot \exp\Big(-\dfrac{x_i^2+y_j^2}{2\sigma^2}\Big) $$
  • Normalize: Scale the coefficients so their sum equals 1. This ensures that the resulting kernel preserves the overall brightness of the image.
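
The three steps above can be sketched with plain broadcasting, outside of defn. This is an alternative illustration rather than the module used later: sigma = 1.0 is an arbitrary choice, and the constant factor $\frac{1}{2\pi\sigma^2}$ is omitted because the final normalization cancels it. It assumes Nx from the setup cell above.

```elixir
sigma = 1.0

# the integers [-2, -1, 0, 1, 2]
range = Nx.iota({5}) |> Nx.subtract(2)

# a column and a row: broadcasting them produces the 5x5 pairs (x_i, y_j)
x = Nx.reshape(range, {5, 1})
y = Nx.reshape(range, {1, 5})
squared_distance = Nx.add(Nx.multiply(x, x), Nx.multiply(y, y))

# exp(-(x^2 + y^2) / (2 * sigma^2)) for each pair
gaussian =
  squared_distance
  |> Nx.negate()
  |> Nx.divide(2 * sigma * sigma)
  |> Nx.exp()

# scale the coefficients so that they sum to 1
kernel = Nx.divide(gaussian, Nx.sum(gaussian))

{Nx.shape(kernel), Nx.to_number(Nx.sum(kernel))}
```

The result has shape {5, 5}, its coefficients sum to 1 up to floating-point error, and the largest one sits at the centre.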

The convolution

We then perform a convolution between the input (a tensor representing an image) and the given kernel.

The input image tensor is expected to have the shape {height, width, channel} with a data type of :u8, representing integers ranging from 0 to 255. For example, a 256x256 color PNG image would have the shape {256, 256, 4}, where the 4 channels correspond to RGBA values.

To prepare the image for computation, the tensor is converted to floats in the range [0, 1] by dividing each element by 255.

To reduce the computational load, we process only every second pixel. This effectively downsamples the image from, for example, 256x256 to 128x128. This is achieved using strides [2, 2], which skip every other row and column.

As noted earlier, the convolution operation expects the input format {batch_size, channel, dim_1, dim_2, ...}.

To accommodate this, we batch the computations by color channel. First, we add an additional dimension to the tensor using Nx.new_axis, and then permute the axes to bring the channel dimension into the batch position. This transforms the tensor from {256, 256, 4} to {4, 1, 256, 256}. The permutation is specified using the input_permutation option.

Once the convolution is complete:

  • The extra singleton dimension is removed using Nx.squeeze.
  • The resulting floats are rescaled back to the range [0, 255].
  • The values are rounded to the nearest integer and cast back to the type :u8, as integers are required to produce a valid RGB image.
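
The last two steps can be sketched on a few made-up float values, assuming Nx from the setup cell above. The value 1.2 stands for a result that overshoots the [0, 1] range and has to be clipped.

```elixir
blurred = Nx.tensor([0.0, 0.4, 1.0, 1.2])

blurred
|> Nx.multiply(255)       # back to the [0, 255] scale
|> Nx.round()             # nearest integer, still stored as a float
|> Nx.clip(0, 255)        # guard against values outside the u8 range
|> Nx.as_type(:u8)
|> Nx.to_flat_list()
# => [0, 102, 255, 255]
```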
defmodule Filter do
  import Nx.Defn
  import Nx.Constants, only: [pi: 0]

  defn gaussian_blur_kernel(opts \\ []) do
    sigma = opts[:sigma]
    range = Nx.iota({5}) |> Nx.subtract(2)

    x = Nx.vectorize(range, :x)
    y = Nx.vectorize(range, :y)

    # Apply Gaussian function to each of the 5x5 elements
    kernel =
      Nx.exp(-(x * x + y * y) / (2 * sigma * sigma))
    kernel = kernel / (2 * pi() * sigma * sigma)
    kernel = Nx.devectorize(kernel)
    kernel / Nx.sum(kernel)
  end

  defn apply_kernel(image, kernel, opts \\ []) do
    # you work on half of the image, reducing from 256 to 128 with strides: [2,2]
    opts = keyword!(opts, strides: [2, 2])

    input_type = Nx.type(image)
    # use floats in [0..1] instead of u8 in [0..255]
    image = image / 255

    {m, n} = Nx.shape(kernel)
    kernel = Nx.reshape(kernel, {1,1,m,n})

    # the image has the shape {:h,:w, :c}
    # insert a new dimension ( Nx.new_axis does not need the current shape)
    # and swap {1, h, w, c} to {c, 1, h, w} for the calculations to be batched on each colour
    # then cast into :u8
    image
    |> Nx.new_axis(0)
    |> Nx.conv(kernel,
      padding: :same,
      input_permutation: [3, 0, 1, 2],
      output_permutation: [3, 0, 1, 2],
      strides: opts[:strides]
    )
    |> Nx.squeeze(axes: [0])
    |> Nx.multiply(255)
    # round to the nearest integer before casting back to the input type
    |> Nx.round()
    |> Nx.clip(0, 255)
    |> Nx.as_type(input_type)
  end
end
{:module, Filter, <<70, 79, 82, 49, 0, 0, 20, ...>>, true}

We can take a look into a kernel:

Filter.gaussian_blur_kernel(sigma: 0.5)
#Nx.Tensor<
  f32[x: 5][y: 5]
  EXLA.Backend
  [
    [6.962478948935313e-8, 2.8088648832635954e-5, 2.075485826935619e-4, 2.8088648832635954e-5, 6.962478948935313e-8],
    [2.8088648832635954e-5, 0.0113317696377635, 0.08373107761144638, 0.0113317696377635, 2.8088648832635954e-5],
    [2.075485826935619e-4, 0.08373107761144638, 0.6186935901641846, 0.08373107761144638, 2.075485826935619e-4],
    [2.8088648832635954e-5, 0.0113317696377635, 0.08373107761144638, 0.0113317696377635, 2.8088648832635954e-5],
    [6.962478948935313e-8, 2.8088648832635954e-5, 2.075485826935619e-4, 2.8088648832635954e-5, 6.962478948935313e-8]
  ]
>

The code below is a custom Kino.JS.Live module. It captures the embedded webcam feed from your computer and produces a blurred version using the convolution module defined above.

The JavaScript code captures frames from the webcam into an OffscreenCanvas. Every other frame is sent to the Elixir backend as a binary, where it is processed with the convolution filter. Kino.JS.Live simplifies this exchange: it uses a WebSocket, and its pushEvent mechanism accepts binary data.

The binary output of the convolution is converted to JPEG (:jpg) format, offering better compression than PNG (:png)—typically achieving a compression ratio of around 1:4.

This compressed binary is then sent back to the JavaScript client via a WebSocket using the broadcast_event function on the server side and the handleEvent callback on the client side.

To restore the original resolution of 256x256 (after reducing it to 128x128 for processing), the canvas.drawImage function is used, specifying both the source and target dimensions.

defmodule Streamer do
  use Kino.JS
  use Kino.JS.Live

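  def html() do
    # Minimal markup (an assumption): main.js only requires
    # a <video id="invid"> and a <canvas id="canvas">
    """
    <video id="invid" width="256" height="256" muted playsinline></video>
    <canvas id="canvas" width="256" height="256"></canvas>
    """
  end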

  def start(), do: Kino.JS.Live.new(__MODULE__, html())

  @impl true
  def handle_connect(ctx), do: {:ok, ctx.assigns.html, ctx}

  @impl true
  def init(html, ctx), do: {:ok, assign(ctx, html: html)}

  # it receives a video frame in binary form
  @impl true
  def handle_event("new frame", {:binary,_, buffer}, ctx) do
    image = 
      Nx.from_binary(buffer, :u8) 
      # change shape from the flat {262144} to {256, 256, 4}
      |> Nx.reshape({256, 256, 4})

    # apply a {5,5} gaussian kernel
    kernel = Filter.gaussian_blur_kernel(sigma: 1)   
    output = 
      Filter.apply_kernel(image, kernel) 
      |> StbImage.from_nx()
      |> StbImage.to_binary(:jpg)
    
    broadcast_event(ctx, "processed_frame", {:binary, %{}, output})
    {:noreply, ctx}
  end

  asset "main.js" do
    """
    export async function init(ctx, html) {
      ctx.root.innerHTML = html;
    
      const height = 256, width = 256,
        video = document.getElementById("invid"),
        renderCanvas = document.getElementById("canvas"),
        renderCanvasCtx = renderCanvas.getContext("2d"),
        offscreen = new OffscreenCanvas(width, height),
        offscreenCtx = offscreen.getContext("2d", { willReadFrequently: true });

      try {
        const stream = await navigator.mediaDevices.getUserMedia({
          video: { width, height },
        });
        video.srcObject = stream;
        let frameCounter = 0;
        video.play();
    
        const processFrame = async (_now, _metadata) => {
          frameCounter++;
          // process only 1 frame out of 2
          if (frameCounter % 2 == 0) {
            offscreenCtx.drawImage(video, 0, 0, width, height);
            const {data}  = offscreenCtx.getImageData(0, 0, width, height);
            ctx.pushEvent("new frame", [{}, data.buffer]);
          }
          requestAnimationFrame(processFrame);
        };
    
        requestAnimationFrame(processFrame);
      } catch (err) {
        console.error("Webcam access error:", err);
      }

      ctx.handleEvent("processed_frame", async ([{}, binary]) => {
        const bitmap = await createImageBitmap(new Blob([binary], { type: "image/jpeg" }));
        renderCanvasCtx.drawImage(bitmap, 0, 0, bitmap.width, bitmap.height, 0,0, 256,256)
      });
    }
    """
  end
end
{:module, Streamer, <<70, 79, 82, 49, 0, 0, 17, ...>>, :ok}
Streamer.start()