Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Median Filter

median_filter/median_filter.livemd

Median Filter

Mix.install([
  {:nx, "~> 0.7.1"},
  {:kino, "~> 0.12.3"},
  {:vega_lite, "~> 0.1.9"},
  {:tucan, "~> 0.3.1"},
  {:kino_vega_lite, "~> 0.1.11"},
  {:exla, "~> 0.7.1"},
  {:benchee, "~> 1.3"},
  {:nx_signal, git: "https://github.com/elixir-nx/nx_signal", branch: "main"}
])

Introduction

In this notebook we are going to implement the Median Filter. It’s a digital filtering technique used to reduce noise from a signal, and we will demonstrate how it works by applying this filter to images affected by salt and pepper noise.

# Image upload widget; its value is read in a later cell with Kino.Input.read/1.
image_input = Kino.Input.image("Upload an image with salt and pepper noise")

Images as tensors

Images represented as tensors contain three dimensions: height, width, and channels. The channels dimension is usually of size 3, and each channel represents the grayscale image made of a primary color. For example, an RGB image contains a red channel, a blue channel, and a green channel.

# Read the uploaded image: the input value carries a file reference plus the
# image height and width.
%{file_ref: file_ref, height: height, width: width} = Kino.Input.read(image_input)

# Raw bytes of the uploaded file.
raw_pixels = file_ref |> Kino.Input.file_path() |> File.read!()

# Interpret the bytes as unsigned 8-bit values laid out as {height, width, 3}.
img_tensor =
  raw_pixels
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({height, width, 3})

Kino.Image.new(img_tensor)

Utilities

We will implement the median filter for 2D tensors, which means we need to convert our original image into a 2D tensor (i.e. a matrix). So we implement a utility function to convert an RGB image into a grayscale image, meaning that the three color channels are combined into a single channel using a weighted sum.

defmodule Image do
  @moduledoc """
  Small helpers for converting an RGB image tensor to grayscale and plotting it.
  """

  # Per-channel weights applied to the R, G and B channels, in that order.
  @rgb_weights Nx.tensor([0.2125, 0.7154, 0.0721])

  @doc ~S"""
  Converts an RGB image tensor into a grayscale tensor.

  The last axis of `img` is contracted against the RGB weights with a dot
  product, and the result is cast to unsigned 8-bit integers.

  The weights follow scikit-image's `rgb2gray`:
  https://github.com/scikit-image/scikit-image/blob/v0.23.2/skimage/color/colorconv.py#L942
  """
  def to_grayscale(img) do
    img
    |> Nx.dot(@rgb_weights)
    |> Nx.as_type({:u, 8})
  end

  @doc """
  Plots a 2D tensor as a grayscale image with `Tucan.imshow/2`.

  The plot is sized to the tensor's own height and width.
  """
  def show(img) do
    {height, width} = Nx.shape(img)

    Tucan.imshow(img, height: height, width: width, color_scheme: :greys, reverse: true)
  end
end

The output is an image with the same height and width, but with the channels dimension removed.

# Convert the RGB tensor to grayscale and plot the result.
img = Image.to_grayscale(img_tensor)
Image.show(img)

Implementation

The following is an implementation of a 2-dimensional median filter. The boundary strategy in this case is given by how Nx.slice/3 handles the edges when we compute a window. Remember that when start_index + length exceeds the dimension, the start_index will be clipped to guarantee that the slice is of the requested length.

defmodule Filters do
  @moduledoc """
  Reference implementation of a 2-dimensional median filter written with plain
  Elixir comprehensions (no `defn`).
  """

  @default_kernel_size 3

  @doc """
  Applies a median filter to the 2D tensor `t`.

  For every position `{i, j}` a `kernel_size x kernel_size` window starting at
  that position is sliced out of `t` and replaced by its median. Edge handling
  is delegated to `Nx.slice/3`.

  ## Options

    * `:kernel_size` - side length of the square window. Defaults to `3`,
      including when `opts` is given without this key.

  Returns a tensor with the same shape as `t` and type `{:u, 8}`.
  """
  def median(t, opts \\ []) do
    # Use Keyword.get/3 with a default so that an opts list lacking
    # :kernel_size does not yield nil and crash inside Nx.slice/3.
    kernel_size = Keyword.get(opts, :kernel_size, @default_kernel_size)
    {height, width} = Nx.shape(t)

    rows =
      for i <- 0..(height - 1) do
        for j <- 0..(width - 1) do
          t
          |> Nx.slice([i, j], [kernel_size, kernel_size])
          |> Nx.median()
          |> Nx.to_number()
        end
      end

    rows
    |> Nx.tensor()
    |> Nx.as_type({:u, 8})
  end
end
# Filter the grayscale image with a 3x3 window.
img_filt_3 = Filters.median(img, kernel_size: 3)

When we show the filtered image, we see a big improvement. There is still a lot of noise on the image, although now we can see more clearly the underlying picture.

# Plot the result of the 3x3 median filter.
Image.show(img_filt_3)

Increasing the kernel_size parameter will help to reduce the noise even more. However, this will inevitably start to blur the image, given that increasing the kernel size implies computing bigger windows, hence using more neighbours to compute the median.

# Same filter with a larger 5x5 window, then plot the result.
img_filt_5 = Filters.median(img, kernel_size: 5)
Image.show(img_filt_5)

Even though our implementation works fine for this image, we can increase performance by refactoring our code using Nx.Defn module. That will allow us to run our code using a compiler such as EXLA.

Optimization using while construct

Let’s refactor the implementation of the median filter using Nx numerical definitions with the defn macro. Inside defn we can only use the functions from the Nx module, other defn functions, and the functions and macros defined in Nx.Defn.Kernel.

# 2D median filter written as a numerical definition (defn) using nested
# `while` loops: the outer loop walks rows, the inner loop walks columns.
defmodule FiltersDefn do
  import Nx.Defn

  # Applies a median filter to the 2D tensor `t`.
  #
  # `opts[:kernel_shape]` is the window shape as a 2-element tuple
  # (it is destructured as `{k0, k1}` below).
  # The result has the shape of `t` and is cast to {:u, 8}.
  defn median(t, opts) do
    kernel_shape = opts[:kernel_shape]

    {height, width} = Nx.shape(t)

    # Zero tensor whose only use below is to carry the kernel shape through
    # the loop state; its values are never read, only `kernel_tensor.shape`.
    kernel_tensor = Nx.broadcast(0.0, kernel_shape)
    # Float accumulator with the same shape as the input; filled row by row.
    output = Nx.broadcast(0.0, Nx.shape(t))

    # Everything the loop bodies use is threaded explicitly through the
    # `while` state tuples.
    {result, _} =
      while {output, {i = 0, t, kernel_tensor, height, width}}, i < height do
        # Fresh 1D buffer holding the medians of row `i`.
        row = Nx.broadcast(0.0, {elem(Nx.shape(t), 1)})

        {ith_row, _} =
          while {row, {j = 0, t, i, kernel_tensor, width}}, j < width do
            {k0, k1} = kernel_tensor.shape

            # Median of the k0 x k1 window starting at {i, j}, reshaped to {1}
            # so it can be written into `row` with put_slice.
            median =
              Nx.slice(t, [i, j], [k0, k1])
              |> Nx.median()
              |> Nx.broadcast({1})

            {Nx.put_slice(row, [j], median), {j + 1, t, i, kernel_tensor, width}}
          end

        # NOTE(review): Nx.stack/2 on the single row presumably adds a leading
        # axis ({width} -> {1, width}) so it fits as row `i` of `output` — confirm.
        {Nx.put_slice(output, [i, 0], Nx.stack(ith_row, axis: 0)),
         {i + 1, t, kernel_tensor, height, width}}
      end

    result |> Nx.as_type({:u, 8})
  end
end

By default, Nx uses the compiler Nx.Defn.Evaluator, which is written in pure Elixir, to run numerical functions. However, rather than giving a performance boost, it is slower than the initial implementation. This is because the evaluator interprets the computation graph in Elixir instead of compiling it to optimized native code.

# Run the defn-based filter without passing a :compiler option.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
args = [img, [kernel_shape: {3, 3}]]
filt = Nx.Defn.jit_apply(&FiltersDefn.median/2, args)
Image.show(filt)

By using EXLA, Elixir will dispatch the JIT compilation to the XLA compiler, which is highly optimized to run tensor operations.

# Same call, now compiled with EXLA (reuses `args` from the previous cell).
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
filt = Nx.Defn.jit_apply(&FiltersDefn.median/2, args, compiler: EXLA)
Image.show(filt)

Refactor using vectorization

Taking the defn code optimization a step further, we can refactor the function to use vectorization. Since Nx.slice/3 supports vectorized indices, we can create a vectorized tensor of indices, compute the image slices on each pair of indices $(i,j)$, and finally compute the median of each of those slices.

defmodule FiltersVec do
  @moduledoc """
  2D median filter implemented with vectorized slicing instead of explicit loops.
  """

  import Nx.Defn

  @doc """
  Applies a median filter to the 2D tensor `t`.

  `opts[:kernel_shape]` is the `{k0, k1}` window shape. One window is sliced
  per `{row, column}` position of `t`, the median of each window is taken, and
  the medians are reshaped back to `t.shape` and cast to `{:u, 8}`.
  """
  defn median(t, opts) do
    {k0, k1} = opts[:kernel_shape]

    # Row and column index of every element of `t`.
    row_idx = Nx.iota(t.shape, axis: 0)
    col_idx = Nx.iota(t.shape, axis: 1)

    # One {row, column} pair per element, vectorized over the :elements axis.
    positions =
      [row_idx, col_idx]
      |> Nx.stack(axis: -1)
      |> Nx.reshape({:auto, 2})
      |> Nx.vectorize(:elements)

    # A k0 x k1 window per vectorized position.
    windows = Nx.slice(t, [positions[0], positions[1]], [k0, k1])

    windows
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(t.shape)
    |> Nx.as_type({:u, 8})
  end
end
# Run the vectorized filter through EXLA and plot the result.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
args = [img, [kernel_shape: {3, 3}]]
filt = EXLA.jit_apply(&FiltersVec.median/2, args)
Image.show(filt)

Performance comparison

As we will see from the benchmarks, the vectorized version compiled with EXLA is by far the fastest one.

# Benchmark the three implementations on the same grayscale image with a 3x3 window.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
opts = [kernel_shape: {3, 3}]

benchmarks =
  Benchee.run(%{
    "median no defn" => fn -> Filters.median(img, kernel_size: 3) end,
    "median defn-while (EXLA)" => fn ->
      EXLA.jit_apply(&FiltersDefn.median/2, [img, opts])
    end,
    "median vectorized (EXLA)" => fn ->
      EXLA.jit_apply(&FiltersVec.median/2, [img, opts])
    end
  })

# Return :ok so the cell does not render the benchmark result term.
:ok

Extending to n-dim

NxSignal is a library that provides digital signal processing utilities. It’s built on top of Nx, meaning we can use compilers and backends for faster computations.

A consequence of the work done in this notebook was the addition of the median filter into NxSignal, with support for n-dimensional tensors.

Kino.Image.new(img_tensor)

# Filter the full RGB tensor with NxSignal's median filter. The kernel shape
# {5, 5, 1} has one entry per axis of the {height, width, 3} image tensor.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
img_reconstructed =
  (&NxSignal.Filters.median/2)
  |> EXLA.jit_apply([img_tensor, [kernel_shape: {5, 5, 1}]])
  |> Nx.as_type(:u8)

Kino.Image.new(img_reconstructed)

Extra

# Grayscale version of the filtered RGB image, plotted as a 2D matrix.
img_reconstructed_2d = Image.to_grayscale(img_reconstructed)

Image.show(img_reconstructed_2d)

Other filters

defmodule MoreFilters do
  @moduledoc """
  Additional image filters built on `Nx.Defn`: a mean filter for 3D tensors
  and a Sobel filter for 2D tensors.
  """

  import Nx.Defn

  @doc """
  Applies a mean filter to the 3D tensor `t`.

  `opts[:kernel_shape]` is the `{k0, k1, k2}` window shape. One window is
  sliced per element position of `t` and replaced by its mean. The result has
  the shape of `t` and type `{:f, 32}`.
  """
  defn mean(t, opts) do
    {k0, k1, k2} = opts[:kernel_shape]

    # One {i, j, k} index triple per element, vectorized over :elements.
    idx =
      [Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1), Nx.iota(t.shape, axis: 2)]
      |> Nx.stack(axis: -1)
      |> Nx.reshape({:auto, 3})
      |> Nx.vectorize(:elements)

    t
    |> Nx.slice([idx[0], idx[1], idx[2]], [k0, k1, k2])
    |> Nx.mean()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(t.shape)
    |> Nx.as_type({:f, 32})
  end

  @doc """
  Applies the Sobel operator to the 2D tensor `t`.

  Builds the horizontal and vertical 3x3 Sobel kernels and delegates to
  `sobel_n/3`.
  """
  def sobel(t) do
    gx = Nx.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
    gy = Nx.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]])

    sobel_n(t, gx, gy)
  end

  @doc """
  Computes the gradient magnitude of the 2D tensor `t` for the 3x3 kernels
  `tx` and `ty`.

  For every position `{i, j}` the 3x3 window starting there is multiplied
  element-wise by each kernel and summed, giving `gx` and `gy`; the output
  value is `sqrt(gx^2 + gy^2)`. Magnitudes are clipped to `0..255` before
  the cast to `{:u, 8}`.
  """
  defn sobel_n(t, tx, ty) do
    {height, width} = Nx.shape(t)

    output = Nx.broadcast(0.0, Nx.shape(t))

    {result, _} =
      while {output, {i = 0, t, height, width, tx, ty}}, i < height do
        # Fresh 1D buffer holding the magnitudes of row `i`.
        row = Nx.broadcast(0.0, {elem(Nx.shape(t), 1)})

        {ith_row, _} =
          while {row, {j = 0, t, i, width, tx, ty}}, j < width do
            slice = Nx.slice(t, [i, j], [3, 3])
            gx = Nx.sum(Nx.multiply(tx, slice))
            gy = Nx.sum(Nx.multiply(ty, slice))

            sobel =
              (gx * gx + gy * gy)
              |> Nx.sqrt()
              |> Nx.broadcast({1})

            {Nx.put_slice(row, [j], sobel), {j + 1, t, i, width, tx, ty}}
          end

        {Nx.put_slice(output, [i, 0], Nx.stack(ith_row, axis: 0)),
         {i + 1, t, height, width, tx, ty}}
      end

    # With 8-bit input the magnitude can reach roughly 1442 (|gx| and |gy| up
    # to 4 * 255), which does not fit in u8. Saturate explicitly rather than
    # relying on the out-of-range behaviour of the float -> u8 cast.
    result
    |> Nx.clip(0, 255)
    |> Nx.as_type({:u, 8})
  end
end

Mean Filter - Blur effect

# Blur the filtered RGB image with a 9x9 mean window applied per channel.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
(&MoreFilters.mean/2)
|> EXLA.jit_apply([img_reconstructed, [kernel_shape: {9, 9, 1}]])
|> Nx.as_type(:u8)
|> Kino.Image.new()

Invert colors

# Invert the image: compute |pixel - 255| for every element, then cast back
# to u8 for display.
img_reconstructed
|> Nx.subtract(Nx.broadcast(255, img_reconstructed.shape))
|> Nx.abs()
|> Nx.as_type(:u8)
|> Kino.Image.new()

Sobel Filter - Edge detection

# Run the Sobel edge detector on the grayscale image and plot the result.
# The capture operator is a plain `&`; `&amp;` was an HTML-escaping artifact.
(&MoreFilters.sobel/1)
|> EXLA.jit_apply([img_reconstructed_2d])
|> Image.show()