# Convolution with Nx
Mix.install(
[
{:nx, "~> 0.9.2"},
{:scidata, "~> 0.1.11"},
{:exla, "~> 0.9.2"},
{:kino, "~> 0.14.2"},
{:stb_image, "~> 0.6.9"},
{:axon, "~> 0.7.0"},
{:kino_vega_lite, "~> 0.1.11"},
{:nx_signal, "~> 0.2.0"},
{:emlx, github: "elixir-nx/emlx", override: true}
],
#config: [nx: [default_backend: EXLA.Backend]]
config: [nx: [default_backend: {EMLX.Backend, device: :gpu}]]
)
Nx.Defn.default_options(compiler: EMLX)
## What is a convolution?

We will use `Nx.conv` in the context of convolution layers in a model.
Formally, the convolution product is a mathematical operation that can be understood as a sum of “sliding products” between two tensors.
This operation is widely used in signal processing and has applications in various fields such as image processing, audio filtering, and machine learning.
### How do you compute it?
There is a formula that sums over the number of samples or points of the signal we consider.
$$(s\circledast f)[n] = \sum_{i=-N}^{N} s[n-i]\,f[i]$$
### Link with the Discrete Fourier Transform
Consider a signal $s(t)$ that we want to transform using a filtering signal $f(t)$.
The convolution product $(s \circledast f)(t)$ has a remarkable property when viewed in the frequency domain.
Denote by $\hat{s}(\omega)$ and $\hat{f}(\omega)$ the Discrete Fourier Transforms of $s$ and $f$ respectively (the DFT is computed with a Fast Fourier Transform). The key relationship is: a term-by-term multiplication in the frequency domain is the DFT of a (circular) convolution.
$$\hat{s}(\omega) \cdot \hat{f}(\omega) = \widehat{s\circledast f}(\omega)$$
The key insight is that convolution in the time or spatial domain becomes a multiplication in the frequency domain, providing methods for signal transformation and filtering.
For instance, to remove specific frequencies, one can construct a filter $\hat{f}$ that is:
- 1 in the frequencies we keep, say $|\omega|<\omega_c$
- 0 in the frequencies to be removed, i.e. when $|\omega|>\omega_c$
By taking the inverse Fourier transform of this filter, $\mathcal{F}^{-1}(\hat{f})$, we can implement the desired frequency filtering. The inverse of such a “window” function is:
$$f(t)=\dfrac{\sin(\omega_c t)}{\pi t}$$
Furthermore, you can see that one way to compute a convolution is to pad the tensors correctly, take the FFTs, multiply them term by term, and take the inverse Fourier transform of this product:
$$s\circledast f = \mathcal{F}^{-1}(\hat{s}\cdot \hat{f})$$
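As a quick illustration of this identity, here is a minimal sketch (with made-up vectors `s` and `f`): pad both tensors to the full output length, multiply their FFTs term by term, and take the inverse FFT.

```elixir
# Minimal sketch of convolution via FFT (illustrative values only).
s = Nx.tensor([1.0, 2.0, 3.0])
f = Nx.tensor([1.0, 1.0])

# pad both signals to the full convolution length: 3 + 2 - 1 = 4
len = Nx.size(s) + Nx.size(f) - 1
s_pad = Nx.pad(s, 0.0, [{0, len - Nx.size(s), 0}])
f_pad = Nx.pad(f, 0.0, [{0, len - Nx.size(f), 0}])

# term-by-term product in the frequency domain, then back to the time domain
Nx.fft(s_pad)
|> Nx.multiply(Nx.fft(f_pad))
|> Nx.ifft()
|> Nx.real()
# ≈ [1.0, 3.0, 5.0, 3.0], the linear convolution of s and f
```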
## How to use Nx.conv
The function `Nx.conv` as written is useful for `Axon` and neural networks, where the coefficients are learnt by an iterative process.
However, if the kernel is given, you will notice that the function as written computes not the convolution but the correlation. The good part is that the two results coincide if and only if the kernel is symmetric.
To use it for 1D vectors in particular, you need to `Nx.reverse` the kernel. Furthermore, you need to respect the expected input shape, so you need to `Nx.reshape` the tensor.
You have the following parameters:
- padding
- shape
- stride
- permutation
The shape is expected to be: `{batch_size, channel, dim_1, ...}`
The stride parameter defines how the kernel slides along the axes. A stride of 1 means the kernel moves one step at a time, while larger strides result in skipping over positions.
The batch_size parameter is crucial when working with colored images, as it allows you to convolve each color channel separately while maintaining the output’s color information.
In contrast, for a convolution layer in a neural network, you might average all color channels and work with a grayscale image instead. This reduces computational complexity while still capturing essential features.
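To make the shape and the stride concrete, here is a small sketch with made-up values: a length-6 input, a length-2 kernel and a stride of 2, so the kernel is only applied at positions 0, 2 and 4.

```elixir
# Illustrative example of the {batch_size, channel, dim_1} shape and of strides.
input = Nx.iota({1, 1, 6}, type: :f32)   # [[[0, 1, 2, 3, 4, 5]]]
kernel = Nx.tensor([[[1.0, 1.0]]])       # shape {1, 1, 2}

Nx.conv(input, kernel, strides: [2])
# => [[[1.0, 5.0, 9.0]]] instead of the 5 values obtained with the default stride of 1
```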
### Example with a vector
If `Nx.shape(t) == {n}`, you can simply use `t` as an input of a convolution by reshaping it with `Nx.reshape(t, {1, 1, n})`. This respects the pattern:
`batch_size=1, channel=1, dim_1=n`.
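A minimal sketch of this reshaping, with arbitrary values and an arbitrary kernel:

```elixir
# a 1D signal and a 1D kernel reshaped to the {batch_size, channel, dim_1} pattern
t = Nx.tensor([1.0, 2.0, 3.0, 4.0])
k = Nx.tensor([1.0, 0.0, -1.0])

Nx.conv(
  Nx.reshape(t, {1, 1, Nx.size(t)}),
  # reverse the kernel to compute a true convolution rather than a correlation
  Nx.reverse(Nx.reshape(k, {1, 1, Nx.size(k)})),
  padding: :same
)
```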
### Example with an image
For example, you have a tensor that represents an RGBA (coloured) image. Its shape is typically `{h, w, c} = {256, 256, 4}`.
You may want to set 4 for the `batch_size`, in order to batch per colour layer. We will use a permutation on the tensor. How?

- you add a new dimension: `Nx.new_axis(t, 0)`. The new shape of the tensor is `{1, 256, 256, 4}`.
- you set the `input_permutation` option to `[3, 0, 1, 2]` because the second dimension must be the channel. You understand why you had to add a "fake" dimension: it lets us respect the pattern `batch_size=4, channels=1, dim_1=256, dim_2=256`, as sketched below.
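Here is a sketch of these two steps on a hypothetical RGBA tensor `t` of shape `{256, 256, 4}`, convolved with an arbitrary 3x3 averaging kernel:

```elixir
# hypothetical RGBA image and a {out_channels, in_channels, 3, 3} averaging kernel
t = Nx.iota({256, 256, 4}, type: :f32)
kernel = Nx.broadcast(1.0 / 9.0, {1, 1, 3, 3})

t
# {256, 256, 4} -> {1, 256, 256, 4}
|> Nx.new_axis(0)
# read the input as batch_size=4 (axis 3), channels=1 (axis 0), dims 256x256 (axes 1, 2)
|> Nx.conv(kernel,
  padding: :same,
  input_permutation: [3, 0, 1, 2],
  output_permutation: [3, 0, 1, 2]
)
|> Nx.shape()
# => {1, 256, 256, 4}
```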
## Padding a tensor

You will notice that we "pad" the tensors. This means that we add zeros to the signal in order to build an input that is zero outside of the range of points considered ($N$ here). This is done below with `padding: :same`, as well as in the polynomial product and in the image filtering.
Some examples on how to use the function `Nx.pad` with a 1D tensor to add numbers to a tensor:
t = Nx.tensor([1,2,3])
one_zero_left = Nx.pad(t, 0, [{1,0,0}])
two_5_right = Nx.pad(t, 5, [{0,2,0}])
[t, one_zero_left, two_5_right]
|> Enum.map(&Nx.to_list/1)
[[1, 2, 3], [0, 1, 2, 3], [1, 2, 3, 5, 5]]
With a 2D-tensor, you need to add an extra dimension. For example, take:
m = Nx.tensor([[1,2],[3,4]])
#Nx.Tensor<
s32[2][2]
EMLX.Backend
[
[1, 2],
[3, 4]
]
>
We will successively:
- add zero-padding on the first and last row
- add zero-padding on the first and last column
- "surround" the tensor with zeros
{Nx.pad(m, 0, [{1,1,0}, {0,0,0}]), Nx.pad(m, 0, [{0,0,0}, {1,1,0}]),
Nx.pad(m, 0, [{1,1,0}, {1,1,0}])}
{#Nx.Tensor<
s32[4][2]
EMLX.Backend
[
[0, 0],
[1, 2],
[3, 4],
[0, 0]
]
>,
#Nx.Tensor<
s32[2][4]
EMLX.Backend
[
[0, 1, 2, 0],
[0, 3, 4, 0]
]
>,
#Nx.Tensor<
s32[4][4]
EMLX.Backend
[
[0, 0, 0, 0],
[0, 1, 2, 0],
[0, 3, 4, 0],
[0, 0, 0, 0]
]
>}
## Example of convolution and FFT
We again use our signal composed of two sine waves, one at 5Hz (period of 1/5s) and one at 20Hz (period of 1/20s).
$s=\sin(2 \pi 5t ) + \frac14\sin(2\pi 20 t)$
If we want to filter frequencies above $f_c=10$Hz, our filter signal in the frequency domain can be:
$\hat{f}(\omega) = 1$ when $|\omega|<\omega_c = 2\pi f_c$
$\hat{f}(\omega) = 0$ when $|\omega|>\omega_c = 2\pi f_c$
Its “Fourier inverse” is:
$$\dfrac{\sin(2\pi f_c t)}{\pi t}$$
Let’s verify the following:
- if we compute the FFT of our signal $s$, multiply by our window function $\hat{f}$, and then take its IFFT, the resulting signal should have the perturbation at 20Hz removed,
- If we perform a “standard” convolution of $s$ with the corresponding time-domain representation of the filter $\mathcal{F}^{-1}(\hat{f})$, compute its FFT, and analyze the result, we should find that the transformed signal is free of significant perturbations at 20Hz. Specifically, the FFT of the convolved signal should not exhibit a notable amplitude at 20Hz.
### Direct convolution in the time (or spatial) domain

We first evaluate the "direct" convolution and plot the convolution of the signal with the filter.
defmodule C do
import Nx.Defn
import Nx.Constants, only: [pi: 0]
defn bins(opts) do
start = opts[:start]
end_point = opts[:end_point]
fs = opts[:fs]
Nx.linspace(start, end_point, n: fs, endpoint: true, type: :f32)
end
defn sample(t) do
f1 = 5; f2 = 20;
Nx.sin(2 * pi() * f1 * t ) + 1/4 * Nx.sin(2 * pi() * f2 * t)
end
defn filter(t, opts) do
fc = opts[:fc]
Nx.sin(2*pi()*fc*t) / (pi()*t)
end
def conv(s,f) do
s = Nx.reshape(s, {1,1,Nx.size(s)})
f = Nx.reshape(f, {1,1,Nx.size(f)})
Nx.conv(s, Nx.reverse(f), padding: :same)
end
defn base(t) do
f1 = 5;
Nx.sin(2 * pi() * f1 * t )
end
end
fs = 200
# center the curve for the FFT of the conv
opts = [start: -0.5, end_point: 0.5, fs: fs ]
x = C.bins(opts)
y_signal = C.sample(x)
fc = 10
y_filter = C.filter(x, fc: fc)
y_conv_s_f = C.conv(y_signal, y_filter) |> Nx.flatten()
signal = %{x: Nx.to_list(x), y: Nx.to_list(y_signal)}
#filter = %{x: Nx.to_list(x), y: Nx.to_list(y_filter)}
conv_s_f = %{x: Nx.to_list(x) , y: Nx.to_list(y_conv_s_f)}
VegaLite.new(width: 700, height: 400)
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.data_from_values(signal, only: ["x", "y"])
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "x", type: :quantitative ,title: "time samples (s)")
|> VegaLite.encode_field(:y, "y", type: :quantitative, title: "signal amplitude"),
VegaLite.new()
|> VegaLite.data_from_values(conv_s_f, only: ["x", "y"])
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "y", type: :quantitative, title: "convolution"),
])
|> VegaLite.resolve(:scale, y: :independent)
In the plot above, we don’t see any perturbation on top of the main period. The signal repeats itself 5 times during this 1s.
To confirm this, we plot the FFT of the convolution $(s\circledast f)(x)$.
It confirms that there is only one spike, at 5Hz. This is what we wanted.
n = 25
amplitudes = Nx.fft(y_conv_s_f) |> Nx.abs()
data_conv = %{
x: (for i<- 0..n, do: i),
y: amplitudes[0..n] |> Nx.to_list()
}
VegaLite.new(width: 700, height: 400)
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.data_from_values(data_conv, only: ["x", "y"])
|> VegaLite.mark(:bar)
|> VegaLite.encode_field(:x, "x", type: :quantitative, title: "frequency bins")
|> VegaLite.encode_field(:y, "y", type: :quantitative, title: "Amplitude")
])
### Filter in the frequency domain and inverse FFT
We now work in the frequency domain with the Fourier transform $\hat{s}(\omega)$ of the signal.
We apply a filter $\hat{f}$ in the frequency domain. This function is equal to 1 for $|f| < f_c$ and 0 otherwise. We multiply the FFT of the signal by this mask, then take the inverse FFT (and its real part) to come back to the time domain.
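A minimal sketch of this step (the variable names `mask` and `inverse` are illustrative), reusing `fs`, `fc` and `y_signal` defined above. Since the window lasts 1s, FFT bin $k$ corresponds approximately to $k$ Hz for $k$ below $f_s/2$, and to $(k - f_s)$ Hz above.

```elixir
# assumed bounds, matching the opts used for the time samples above
start = -0.5
end_point = 0.5

# frequency (in Hz) associated with each FFT bin
k = Nx.iota({fs})
freqs = Nx.select(Nx.less(k, div(fs, 2)), k, Nx.subtract(k, fs))

# low-pass mask: 1 below the cut-off frequency fc, 0 above
mask = Nx.less(Nx.abs(freqs), fc) |> Nx.as_type(:f32)

# multiply the FFT of the signal by the mask and come back to the time domain
inverse =
  Nx.fft(y_signal)
  |> Nx.multiply(mask)
  |> Nx.ifft()
  |> Nx.real()
```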
# -------- plot the IFFT
data = %{
  x: Nx.linspace(start, end_point, n: n + 1, endpoint: false) |> Nx.to_list(),
  y: inverse[0..n] |> Nx.to_list()
}

# -------- fitting curve
xs = Nx.linspace(start, end_point, n: fs, endpoint: false)
signal = %{x: xs |> Nx.to_list(), y: C.base(xs) |> Nx.to_list()}

VegaLite.new(height: 400, width: 700)
|> VegaLite.layers([
  VegaLite.new()
  |> VegaLite.data_from_values(data, only: ["x", "y"])
  |> VegaLite.mark(:bar)
  |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "time(s)", scale: [domain: [-0.5, 0.5]])
  |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "amplitude"),
  VegaLite.new()
  |> VegaLite.data_from_values(signal, only: ["x", "y"])
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "time(s)", scale: [domain: [-0.5, 0.5]])
  |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "single signal"),
])
|> VegaLite.resolve(:scale, y: :independent)
## Example: polynomial multiplication using convolution
You can use the convolution product to compute the coefficients of the product of two polynomials.
Let's visualize the convolution product.
Take two polynomials $P_1 = 1+X+X^2$ and $P_2 = 2+X^2$. Their product is:
$$2+2X+3X^2+X^3+X^4$$
We calculate the convolution of $P_1=[1,1,1]$ with $P_2=[2,0,1]$ (coefficients listed in ascending order of degree; inside `Nx.conv` the "kernel" $P_2$ is reversed to $[1,0,2]$, since the function computes a correlation).
We are going to slide the reversed $P_2$ below $P_1$ and sum the term-by-term products, with the rule that it always overlaps $P_1$ by at least one term. This gives us 5 possibilities.
We therefore pad the first tensor with 2 zeros to the left and to the right, in order to calculate 5 sums of 3 products.
    t1       0 0 1 1 1 0 0
    rev(t2)  1 0 2            0x1 + 0x0 + 1x2 = 2
               1 0 2          0x1 + 1x0 + 1x2 = 2
                 1 0 2        1x1 + 1x0 + 1x2 = 3
                   1 0 2      1x1 + 1x0 + 0x2 = 1
                     1 0 2    1x1 + 0x0 + 0x2 = 1
    result   2 2 3 1 1
We apply a padding of 2 "zeros" to the left and to the right of the first tensor in the `Nx.conv` function below.
defmodule NxPoly do
  import Nx.Defn

  defn tprod(t1, t2) do
    t1 = Nx.stack(t1) |> Nx.as_type(:f32)
    t1 = Nx.reshape(t1, {1, 1, Nx.size(t1)})
    t2 = Nx.stack(t2) |> Nx.as_type(:f32)
    t2 = Nx.reshape(t2, {1, 1, Nx.size(t2)})
    Nx.conv(t1, Nx.reverse(t2), padding: [{2, 2}])
  end
end
NxPoly.tprod([1,1,1], [2,0,1]) |> Nx.as_type(:u8) |> Nx.flatten() |> Nx.to_list()
[2, 2, 3, 1, 1]
We obtain the coefficients of the product: $2 + 2X + 3X^2 + X^3 + X^4$
## Example: edge detector with a discrete differential kernel
The MNIST dataset contains images of handwritten digits. Each image is 28x28 pixels with 1 channel (grey levels, integer values from 0 to 255), organised row-wise.
This means that the pixel $(i,j)$ is at the $i$-th row and $j$-th column.
{{images_binary, type, images_shape} = _train_images, train_labels} = Scidata.MNIST.download_test()
{labels, {:u, 8}, label_shape} = train_labels
images = images_binary |> Nx.from_binary(type) |> Nx.reshape(images_shape)
{{_nb, _channel, _height, _width}, _type} = {images_shape, type}
{{10000, 1, 28, 28}, {:u, 8}}
The labels describe what the image is. The image at index 15 is a "5".
Nx.from_binary(labels, :u8)[15] |> Nx.to_number()
5
To be sure, we display this image with `StbImage`.
> We transform the tensor as `StbImage` expects the format {height, width, channel} whilst the MNIST uses {channel, height, width}.
five = images[15]
Nx.transpose(five, axes: [1,2,0]) |> StbImage.from_nx() |> StbImage.resize(500, 500)
%StbImage{ data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
shape: {500, 500, 1}, type: {:u, 8} }
We can __detect edges__ by performing a discrete differentiation, where the kernel computes the difference between adjacent pixel values. This can be viewed as a discrete approximation of a derivative.
Since the image is slightly blurred, we use the kernel $[-1,0,1]$, which computes the difference between pixel values separated by two positions, skipping the pixel in between.
This operation can be applied along the width axis to detect vertical edges or along the height axis to detect horizontal edges.
Given pixel coordinates $(i,j)$ where $i$ is the row and $j$ the column, we compute:
* vertical edges with horizontally adjacent pixels: $\mathrm{pix}(i, j) \cdot (-1) + \mathrm{pix}(i, j+2) \cdot (1)$, with the kernel reshaped into $\{1,1,1,3\}$,
* horizontal edges with vertically adjacent pixels: $\mathrm{pix}(i, j) \cdot (-1) + \mathrm{pix}(i+2, j) \cdot (1)$, with the kernel reshaped into $\{1,1,3,1\}$.
Below, we compute two transformed images using these vertical and horizontal edge detectors and display the results.
edge_kernel = Nx.tensor([-1, 0, 1])
vertical_detection = Nx.conv( Nx.reshape(five, {1, 1, 28, 28}) |> Nx.as_type(:f32), Nx.reshape(edge_kernel, {1, 1, 1, 3}), padding: :same )
horizontal_detection = Nx.conv( Nx.reshape(five, {1, 1, 28, 28}) |> Nx.as_type(:f32), Nx.reshape(edge_kernel, {1, 1, 3, 1}), padding: :same )
Nx.concatenate([vertical_detection, horizontal_detection], axis: -1) |> Nx.reshape({28, 56, 1}) |> Nx.as_type({:u, 8}) |> StbImage.from_nx() |> StbImage.resize(500, 1000)
%StbImage{ data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
shape: {500, 1000, 1}, type: {:u, 8} }
## Recover the kernel knowing the output
Suppose we have an image and its corresponding transformed version, produced by a convolution.
Our goal is to recover the convolution kernel "by hand". We will use a simple gradient descent approach.
This involves iteratively updating the kernel through an optimization process. The update to the kernel is proportional to the gradient of a loss function, which we aim to minimize.
The loss function measures the difference between the known output (referred to as `vertical_detection` in this case) and the current output image, computed as `Nx.conv(input, current_kernel)`.
We start with an initial kernel of zeros, working with floats instead of integers.
The algorithm uses a parameter `eps` to control the step size (rate of change) during each iteration. This parameter is critical for achieving stable convergence:
* If `eps` is too large, the iterations may become unstable.
* If `eps` is too small, the process may become slow and could converge to a local minimum instead of a potential global optimum (as no convexity is assumed here).
five = Nx.reshape(five, {1, 1, 28, 28})
kernel_init = Nx.tensor([0.0, 0.0, 0.0]) |> Nx.reshape({1, 1, 1, 3})
defmodule GradDesc do
  import Nx.Defn

  defn curr_loss(input, kernel, out) do
    curr_out = Nx.conv(input |> Nx.as_type(:f32), kernel, padding: :same)
    Nx.mean((curr_out - out) ** 2)
  end

  defn update(kernel, input, out, opts) do
    der = grad(kernel, fn t -> curr_loss(input, t, out) end)
    kernel - der * opts[:eps]
  end

  defn loop(input, kernel, output, opts) do
    {k, _, _, _} =
      # for simplicity, we only set the number of steps to run.
      while {k = kernel, i = 0, inp = input, out = output}, i <= opts[:itermax] do
        {update(k, inp, out, opts), i + 1, inp, out}
      end

    {k, curr_loss(input, k, output)}
  end
end
IO.puts("Initial loss: #{round(Nx.to_number(GradDesc.curr_loss(five, kernel_init, vertical_detection)))}")
{k, loss} = GradDesc.loop(five, kernel_init, vertical_detection, itermax: 10, eps: 1.0e-4)
IO.puts("Loss after iterations: #{Nx.to_number(loss)}")
Initial loss: 4647
Loss after iterations: 0.004979725461453199
:ok
We expect to find a kernel close to $[-1,0,1]$ after 10 iterations with a learning rate of 1e-4.
Nx.to_list(k)|> List.flatten()
[-0.9989683032035828, -3.817594006250147e-6, 0.998961329460144]
## Using an Axon model with a convolution layer to predict vertical edges
Let's use `Axon` to recover the convolution kernel. Our model will consist of a single convolution layer.
We will give our model a series of images together with their transforms (computed with the "known" kernel).
The first 15 images do not contain the number 5. We apply the "vertical" convolution transform to these images and feed the pairs to the model to help it "learn" how to detect vertical edges.
We produce a sample of 15 tuples `{input, output = conv(input, kernel)}` containing digits different from 5.
l = 14
kernel = Nx.tensor([-1.0, 0.0, 1.0]) |> Nx.reshape({1, 1, 1, 3})
training_data =
  images[0..l]
  |> Nx.reshape({l + 1, 1, 28, 28})
  |> Nx.as_type(:f32)
  |> Nx.to_batched(1)
  |> Stream.map(fn img ->
    {img, Nx.conv(img, kernel, padding: :same)}
  end)
#Stream<[ enum: 0..14, funs: [#Function<50.105594673/1 in Stream.map/2>, #Function<50.105594673/1 in Stream.map/2>] ]>
The model is a simple convolution layer:
model =
  Axon.input("x", shape: {nil, 1, 28, 28})
  |> Axon.conv(1,
    channels: :first,
    padding: :same,
    kernel_initializer: :zeros,
    use_bias: false,
    kernel_size: 3
  )
#Axon<
  inputs: %{"x" => {nil, 1, 28, 28}}
  outputs: "conv_0"
  nodes: 2
>
We train the model with the dataset. Note that it is important to pass a compiler (here `compiler: EMLX`) for computation speed.
optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2)
# optimizer = Polaris.Optimizers.adabelief(learning_rate: 1.0e-2)
params =
  Axon.Loop.trainer(model, :mean_squared_error, optimizer)
  |> Axon.Loop.run(training_data, Axon.ModelState.empty(), epochs: 20, compiler: EMLX)
09:41:31.761 [debug] Forwarding options: [compiler: EMLX] to JIT compiler
Epoch: 0, Batch: 0, loss: 0.0000000
Epoch: 1, Batch: 0, loss: 4096.1328125
Epoch: 2, Batch: 0, loss: 2760.1577148
Epoch: 3, Batch: 0, loss: 2029.4346924
Epoch: 4, Batch: 0, loss: 1624.6121826
Epoch: 5, Batch: 0, loss: 1374.7448730
Epoch: 6, Batch: 0, loss: 1202.8275146
Epoch: 7, Batch: 0, loss: 1075.4884033
Epoch: 8, Batch: 0, loss: 976.2915649
Epoch: 9, Batch: 0, loss: 895.9644165
Epoch: 10, Batch: 0, loss: 828.9739990
Epoch: 11, Batch: 0, loss: 771.8309326
Epoch: 12, Batch: 0, loss: 722.2226563
Epoch: 13, Batch: 0, loss: 678.5550537
Epoch: 14, Batch: 0, loss: 639.6925049
Epoch: 15, Batch: 0, loss: 604.8014526
Epoch: 16, Batch: 0, loss: 573.2542725
Epoch: 17, Batch: 0, loss: 544.5662842
Epoch: 18, Batch: 0, loss: 518.3546143
Epoch: 19, Batch: 0, loss: 494.3110657
#Axon.ModelState<
  Parameters: 9 (36 B)
  Trainable Parameters: 9 (36 B)
  Trainable State: 0, (0 B)
>
We can now check what our Axon model learnt. We compare:
* the "vertical" convolution of an image,
* the image predicted by our model.
five = Nx.reshape(images[15], {1,1,28,28}) |> Nx.as_type(:f32)
model_five = Axon.predict(model, params, five)
conv_five = Nx.conv( Nx.reshape(images[15] , {1, 1, 28, 28}) |> Nx.as_type(:f32), Nx.reshape(edge_kernel , {1, 1, 1, 3}), padding: :same )
Nx.concatenate([model_five, conv_five], axis: -1)
|> Nx.reshape({28, 56, 1})
|> Nx.clip(0, 255)
|> Nx.as_type(:u8)
|> StbImage.from_nx()
|> StbImage.resize(500, 1500)
%StbImage{ data: <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...>>,
shape: {500, 1500, 1}, type: {:u, 8} }
## Convolution live example
We aim to apply a blurring filter to our webcam feed.
This is achieved by convolving the image with a Gaussian kernel, which computes a weighted average (the "Gaussian mean") of the pixels within the kernel's area.
Most of the tensor manipulations used below are explained in this post:
#### The gaussian kernel
Let's see how the "blur" kernel is built.
* produce two vectors of integers `x` and `y` as `[-2, -1, 0, 1, 2]`.
* compute the Gaussian coefficients for the 5x5=25 couples $(x_i, y_j)$:
$$
\dfrac{1}{2\pi\sigma^2}\, \exp\left(-\dfrac{x_i^2+y_j^2}{2\sigma^2}\right)
$$
* Normalize: Scale the coefficients so their sum equals 1. This ensures that the resulting kernel preserves the overall brightness of the image.
#### The convolution
We then perform a convolution between the input (a tensor representing an image) and the given kernel.
The input image tensor is expected to have the shape `{height, width, channel}` with a data type of `:u8`, representing integers ranging from 0 to 255. For example, a 256x256 color PNG image would have the shape `{256, 256, 4}`, where the 4 channels correspond to RGBA values.
To prepare the image for computation, the tensor is converted to floats in the range [0, 1] by dividing each element by 255.
To reduce the computational load, we process only every second pixel. This effectively downsamples the image from, for example, 256x256 to 128x128. This is achieved using strides [2, 2], which skip every other row and column.
As noted earlier, the convolution operation expects the input format `{batch_size, channel, dim_1, dim_2, ...}`.
To accommodate this, we batch the computations by colour channel. First, we add an additional dimension to the tensor using `Nx.new_axis`, and then tell the convolution to read the channel dimension as the batch. The tensor of shape `{256, 256, 4}` becomes `{1, 256, 256, 4}` and is interpreted as `{4, 1, 256, 256}`. The permutation is specified using the `input_permutation` parameter.
Once the convolution is complete:
* The extra singleton dimension is removed using `Nx.squeeze`.
* The resulting floats are rescaled back to the range [0, 255].
* The values are rounded to the nearest integer and cast back to the type `:u8`, as integers are required to produce a valid RGB image.
defmodule Filter do
  import Nx.Defn
  import Nx.Constants, only: [pi: 0]

  defn gaussian_blur_kernel(opts \\ []) do
    sigma = opts[:sigma]
    range = Nx.iota({5}) |> Nx.subtract(2)
    x = Nx.vectorize(range, :x)
    y = Nx.vectorize(range, :y)
    # Apply the Gaussian function to each of the 5x5 elements
    kernel = Nx.exp(-(x * x + y * y) / (2 * sigma * sigma))
    kernel = kernel / (2 * pi() * sigma * sigma)
    kernel = Nx.devectorize(kernel)
    kernel / Nx.sum(kernel)
  end

  defn apply_kernel(image, kernel, opts \\ []) do
    # we work on half of the image, reducing from 256 to 128 with strides: [2, 2]
    opts = keyword!(opts, strides: [2, 2])
    input_type = Nx.type(image)
    # use floats in [0..1] instead of u8 in [0..255]
    image = image / 255
    {m, n} = Nx.shape(kernel)
    kernel = Nx.reshape(kernel, {1, 1, m, n})

    # the image has the shape {h, w, c}:
    # insert a new dimension (Nx.new_axis does not need the current shape)
    # and read {1, h, w, c} as {c, 1, h, w} so the computation is batched on each colour,
    # then cast back into :u8
    image
    |> Nx.new_axis(0)
    |> Nx.conv(kernel,
      padding: :same,
      input_permutation: [3, 0, 1, 2],
      output_permutation: [3, 0, 1, 2],
      strides: opts[:strides]
    )
    |> Nx.squeeze(axes: [0])
    |> Nx.multiply(255)
    |> Nx.clip(0, 255)
    |> Nx.as_type(input_type)
  end
end
We can take a look at a kernel:
Filter.gaussian_blur_kernel(sigma: 0.5)
#Nx.Tensor<
  f32[x: 5][y: 5]
  EMLX.Backend
  [
[6.962477527849842e-8, 2.8088637918699533e-5, 2.0754853903781623e-4, 2.8088637918699533e-5, 6.962477527849842e-8],
[2.8088637918699533e-5, 0.011331766843795776, 0.08373105525970459, 0.011331766843795776, 2.8088637918699533e-5],
[2.0754853903781623e-4, 0.08373105525970459, 0.618693470954895, 0.08373105525970459, 2.0754853903781623e-4],
[2.8088637918699533e-5, 0.011331766843795776, 0.08373105525970459, 0.011331766843795776, 2.8088637918699533e-5],
[6.962477527849842e-8, 2.8088637918699533e-5, 2.0754853903781623e-4, 2.8088637918699533e-5, 6.962477527849842e-8]
  ]
>
The code below is a custom `Kino.JS.Live` module. It captures the embedded webcam feed from your computer and produces a blurred version using the convolution module defined above.
The JavaScript code captures frames from the webcam into an `OffscreenCanvas`. Every other frame is processed with a convolution filter, and the resulting frame is sent to the Elixir backend as a binary. Kino.JS.Live simplifies this process using a WebSocket and the `pushEvent` mechanism that accepts binary data.
The binary output of the convolution is converted to JPEG (:jpg) format, offering better compression than PNG (:png)—typically achieving a compression ratio of around 1:4.
This compressed binary is then sent back to the JavaScript client via a WebSocket using the `broadcast_event` function on the server side and the `handleEvent` callback on the client side.
To restore the original resolution of 256x256 (after reducing it to 128x128 for processing), the `canvas.drawImage` function is used, specifying both the source and target dimensions.
defmodule Streamer do
  use Kino.JS
  use Kino.JS.Live

  def html() do
    """
    <!-- minimal markup expected by main.js: the input video element and the render canvas -->
    <video id="invid" width="256" height="256"></video>
    <canvas id="canvas" width="256" height="256"></canvas>
    """
  end

  def start(), do: Kino.JS.Live.new(__MODULE__, html())

  @impl true
  def handle_connect(ctx), do: {:ok, ctx.assigns.html, ctx}

  @impl true
  def init(html, ctx), do: {:ok, assign(ctx, html: html)}

  # it receives a video frame in binary form
  @impl true
  def handle_event("new frame", {:binary, _info, buffer}, ctx) do
    image =
      Nx.from_binary(buffer, :u8)
      # change shape from {262144} to {256, 256, 4}
      |> Nx.reshape({256, 256, 4})

    # apply a {5,5} Gaussian kernel
    kernel = Filter.gaussian_blur_kernel(sigma: 1)

    output =
      Filter.apply_kernel(image, kernel)
      |> StbImage.from_nx()
      |> StbImage.to_binary(:jpg)

    broadcast_event(ctx, "processed_frame", {:binary, %{}, output})
    {:noreply, ctx}
  end
  asset "main.js" do
"""
export async function init(ctx, html) {
ctx.root.innerHTML = html;
const height = 256, width = 256,
video = document.getElementById("invid"),
renderCanvas = document.getElementById("canvas"),
renderCanvasCtx = renderCanvas.getContext("2d"),
offscreen = new OffscreenCanvas(width, height),
offscreenCtx = offscreen.getContext("2d", { willReadFrequently: true });
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: { width, height },
});
video.srcObject = stream;
let frameCounter = 0;
video.play();
const processFrame = async (_now, _metadata) => {
frameCounter++;
// process only 1 frame out of 2
if (frameCounter % 2 == 0) {
offscreenCtx.drawImage(video, 0, 0, width, height);
const {data} = offscreenCtx.getImageData(0, 0, width, height);
ctx.pushEvent("new frame", [{}, data.buffer]);
}
requestAnimationFrame(processFrame);
};
requestAnimationFrame(processFrame);
} catch (err) {
console.error("Webcam access error:", err);
}
ctx.handleEvent("processed_frame", async ([{}, binary]) => {
const bitmap = await createImageBitmap(new Blob([binary], { type: "image/jpeg" }));
renderCanvasCtx.drawImage(bitmap, 0, 0, bitmap.width, bitmap.height, 0,0, 256,256)
});
}
"""
  end
end
Streamer.start()
nil
As a side note, the [WebGPU API](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API) is available, meaning that the browser can call the GPU directly to run computations.
## Convolution and music