Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Median Filter

median_filter/median_filter.livemd

Median Filter

Mix.install([
  {:nx, "~> 0.7.1"},
  {:kino, "~> 0.12.3"},
  {:vega_lite, "~> 0.1.9"},
  {:tucan, "~> 0.3.1"},
  {:kino_vega_lite, "~> 0.1.11"},
  {:exla, "~> 0.7.1"},
  {:benchee, "~> 1.3"}
])

Introduction

In this notebook we are going to implement the Median Filter. It’s a digital filtering technique used to reduce noise in a signal, and we will demonstrate how it works by applying this filter to images affected by salt and pepper noise.

# Image upload input; its current value is read with Kino.Input.read/1 in a later cell.
image_input = Kino.Input.image("Upload an image with salt and pepper noise")

Images as tensors

Images represented as tensors contain three dimensions: height, width, and channels. The channels dimension is usually of size 3, and each channel represents the grayscale image made of a primary color. For example, an RGB image contains a red channel, a blue channel, and a green channel.

# Pull the file reference and pixel dimensions out of the uploaded image.
%{file_ref: file_ref, height: height, width: width} = Kino.Input.read(image_input)

# Load the raw bytes as unsigned 8-bit values and shape them as
# {height, width, channels} in a single pipeline.
img_tensor =
  file_ref
  |> Kino.Input.file_path()
  |> File.read!()
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({height, width, 3})

Kino.Image.new(img_tensor)

Utilities

We will implement the median filter for 2D tensors, which means we need to convert our original image into a 2D tensor (i.e. a matrix). So we implement a utility function to convert an RGB image into a grayscale image, meaning that the three channels are combined into a single one through a weighted sum.

defmodule Image do
  @moduledoc """
  Helpers for converting and displaying images stored as `Nx` tensors.
  """

  # Weights applied to the last axis of the image by `to_grayscale/1`.
  @rgb_weights Nx.tensor([0.2125, 0.7154, 0.0721])

  @doc ~S"""
  Converts `img` to grayscale by taking the dot product of the image with
  the RGB weights, which contracts the channels axis.

  Reference implementation:
  https://github.com/scikit-image/scikit-image/blob/v0.23.2/skimage/color/colorconv.py#L942
  """
  def to_grayscale(img) do
    Nx.dot(img, @rgb_weights)
  end

  @doc """
  Plots a 2D tensor with `Tucan.imshow/2`, using the tensor's own height and
  width as the plot dimensions and a reversed grey color scheme.
  """
  def plot(img) do
    {rows, cols} = Nx.shape(img)
    Tucan.imshow(img, height: rows, width: cols, color_scheme: :greys, reverse: true)
  end
end

The output is an image with the same height and width, but with the channels dimension removed.

# Collapse the channels axis of the uploaded image, then display the 2D result.
img = Image.to_grayscale(img_tensor)
Image.plot(img)

Implementation

The following is an implementation of a 2-dimensional median filter. The boundary strategy in this case is given by how Nx.slice/3 handles the edges when we compute a window.

defmodule Filters do
  @moduledoc """
  Reference implementation of a 2D median filter written with plain Elixir
  comprehensions (no `defn`).
  """

  # Used when the caller does not pass `:kernel_size`.
  @default_kernel_size 3

  @doc """
  Applies a median filter to the 2D tensor `t`.

  For every position `{i, j}` a `kernel_size x kernel_size` window is taken
  with `Nx.slice/3` starting at that position, and the output value is the
  median of that window. Edge handling is whatever `Nx.slice/3` does when the
  window would run past the tensor (per the Nx documentation, the start index
  is clamped so the window stays in bounds).

  ## Options

    * `:kernel_size` - side length of the square window. Defaults to `3`,
      both when `opts` is omitted and when it is given without this key.

  Returns a tensor with the same shape as `t` and type `{:f, 32}`.
  """
  def median(t, opts \\ []) do
    # Keyword.get/3 with a default, so `median(t, [])` does not end up
    # slicing with a `nil` window size.
    kernel_size = Keyword.get(opts, :kernel_size, @default_kernel_size)
    {height, width} = Nx.shape(t)

    rows =
      for i <- 0..(height - 1) do
        for j <- 0..(width - 1) do
          t
          |> Nx.slice([i, j], [kernel_size, kernel_size])
          |> Nx.median()
          |> Nx.to_number()
        end
      end

    rows
    |> Nx.tensor()
    |> Nx.as_type({:f, 32})
  end
end
# Filter the grayscale image with a 3x3 window.
img_filt_3 = Filters.median(img, kernel_size: 3)

When we show the filtered image, we see a big improvement. There is still a lot of noise on the image, although now we can see more clearly the underlying picture.

# Display the result of the 3x3 filter.
Image.plot(img_filt_3)

Increasing the kernel_size parameter will help to reduce the noise even more. However, this will inevitably start to blur the image, given that increasing the kernel size implies computing bigger windows, hence using more neighbours to compute the median.

# Same filter with a larger 5x5 window, then display the result.
img_filt_5 = Filters.median(img, kernel_size: 5)
Image.plot(img_filt_5)

Even though our implementation works fine for this image, we can increase performance by refactoring our code using the Nx.Defn module. That will allow us to run our code using a compiler such as EXLA.

Optimization using while construct

# Median filter written with `defn` and two nested `while` loops, so it can be
# handed to a defn compiler (the notebook runs it through EXLA).
defmodule FiltersDefn do
  import Nx.Defn

  # Applies a median filter to the 2D tensor `t`.
  #
  # `opts[:kernel_shape]` is the `{rows, cols}` shape of the window. The outer
  # loop walks the rows (`i`), the inner loop walks the columns (`j`), and each
  # output element is the median of the window sliced from `t` at `[i, j]`.
  # Returns a tensor with the same shape as `t`, built from `0.0` broadcasts.
  defn median(t, opts) do
    kernel_shape = opts[:kernel_shape]

    {height, width} = Nx.shape(t)

    # Placeholder tensor: its values are never read, only its `.shape` is used
    # inside the inner loop to recover the window dimensions.
    kernel_tensor = Nx.broadcast(0.0, kernel_shape)
    # Accumulator for the filtered image, filled in one row per outer iteration.
    output = Nx.broadcast(0.0, Nx.shape(t))

    # Loop state is `{accumulator, {counter, ...values used in the body}}`;
    # everything the body refers to is threaded through this tuple.
    {result, _} =
      while {output, {i = 0, t, kernel_tensor, height, width}}, i < height do
        # Fresh buffer for row `i`, one entry per column of `t`.
        row = Nx.broadcast(0.0, {elem(Nx.shape(t), 1)})

        {ith_row, _} =
          while {row, {j = 0, t, i, kernel_tensor, width}}, j < width do
            {k0, k1} = kernel_tensor.shape

            # Median of the window starting at [i, j], reshaped to {1} so it
            # can be written into `row` with put_slice.
            median =
              Nx.slice(t, [i, j], [k0, k1])
              |> Nx.median()
              |> Nx.broadcast({1})

            {Nx.put_slice(row, [j], median), {j + 1, t, i, kernel_tensor, width}}
          end

        # Write the finished row into `output` starting at [i, 0].
        # NOTE(review): Nx.stack/2 on the single row presumably adds the
        # leading axis put_slice needs for a 2D target; confirm.
        {Nx.put_slice(output, [i, 0], Nx.stack(ith_row, axis: 0)),
         {i + 1, t, kernel_tensor, height, width}}
      end

    result
  end
end
# Run the `while`-based filter on the grayscale image with a 3x3 window,
# compiled with EXLA, and display the result.
params = [img, [kernel_shape: {3, 3}]]
opts = [compiler: EXLA]

# The function capture uses a plain `&`; the HTML-escaped `&amp;` form does not compile.
filt = Nx.Defn.jit_apply(&FiltersDefn.median/2, params, opts)
Image.plot(filt)

Refactor using vectorization

# Median filter that replaces explicit loops with a vectorized slice.
defmodule FiltersVec do
  import Nx.Defn

  # Applies a median filter to the 2D tensor `t` using the `{rows, cols}`
  # window given in `opts[:kernel_shape]`.
  #
  # One `[row, col]` index pair is built per element of `t` and vectorized
  # under the `:elements` axis, so a single `Nx.slice/3` call takes the window
  # for every position. The per-window medians are then devectorized, reshaped
  # to `t.shape` and cast to `{:f, 32}`.
  defn median(t, opts) do
    {win_rows, win_cols} = opts[:kernel_shape]

    row_idx = Nx.iota(t.shape, axis: 0)
    col_idx = Nx.iota(t.shape, axis: 1)

    # One [row, col] pair per element, vectorized over :elements.
    coords =
      [row_idx, col_idx]
      |> Nx.stack(axis: -1)
      |> Nx.reshape({:auto, 2})
      |> Nx.vectorize(:elements)

    windows = Nx.slice(t, [coords[0], coords[1]], [win_rows, win_cols])

    windows
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(t.shape)
    |> Nx.as_type({:f, 32})
  end
end
# Run the vectorized filter on the grayscale image with a 3x3 window,
# compiled with EXLA, and display the result.
params = [img, [kernel_shape: {3, 3}]]
opts = [compiler: EXLA]

# The function capture uses a plain `&`; the HTML-escaped `&amp;` form does not compile.
filt = Nx.Defn.jit_apply(&FiltersVec.median/2, params, opts)
Image.plot(filt)

Performance Comparison

# Benchmark the three implementations on the same grayscale image with a
# 3x3 window. The two `defn` versions are run through EXLA.
opts = [kernel_shape: {3, 3}]

Benchee.run(%{
  "median no defn" => fn -> Filters.median(img, kernel_size: 3) end,
  "median defn-while" => fn -> EXLA.jit_apply(&FiltersDefn.median/2, [img, opts]) end,
  "median vectorized" => fn -> EXLA.jit_apply(&FiltersVec.median/2, [img, opts]) end
})

Bonus - extending to n-dim

# Second upload for the n-dimensional example: the image is kept as a
# {height, width, 3} tensor, without the grayscale conversion.
image_input = Kino.Input.image("Upload an image with salt and pepper noise")
%{file_ref: file_ref, height: height, width: width} = Kino.Input.read(image_input)

# Raw file bytes -> unsigned 8-bit tensor -> {height, width, 3}.
img_tensor =
  file_ref
  |> Kino.Input.file_path()
  |> File.read!()
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({height, width, 3})

# Preview the uploaded image.
Kino.Image.new(img_tensor)
defmodule FiltersVec2 do
  @moduledoc """
  Vectorized median filter for tensors of any rank.
  """

  import Nx.Defn

  @doc """
  Applies a median filter to `t`.

  One start index per element of `t` is built and vectorized under the
  `:elements` axis, so a single `Nx.slice/3` call takes the window for every
  position. The per-window medians are devectorized, reshaped to `t.shape`
  and cast to `{:f, 32}`.

  ## Options

    * `:kernel_shape` - (required) tuple with the window length along each
      axis of `t`. It must have as many elements as `t` has axes.

  Raises `ArgumentError` if an unknown option is given, if `:kernel_shape` is
  missing or not a tuple, or if its size differs from the rank of `t`.
  """
  defn median(t, opts) do
    validate_median_opts!(t, opts)

    idx =
      t
      |> idx_tensor()
      |> Nx.vectorize(:elements)

    t
    |> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape]))
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(t.shape)
    |> Nx.as_type({:f, 32})
  end

  # Checks the options before any tensor work is traced.
  deftransformp validate_median_opts!(t, opts) do
    opts = Keyword.validate!(opts, [:kernel_shape])
    kernel_shape = opts[:kernel_shape]

    # Without this check a missing option surfaces as an unrelated error
    # from computing the rank of `nil`.
    if not is_tuple(kernel_shape) do
      raise ArgumentError,
        message: "expected :kernel_shape to be a tuple, got: #{inspect(kernel_shape)}"
    end

    if Nx.rank(t) != tuple_size(kernel_shape) do
      raise ArgumentError, message: "kernel shape must be of the same rank as the tensor"
    end
  end

  # Builds a {number_of_elements, rank} tensor holding the index of every element of `t`.
  deftransformp idx_tensor(t) do
    t
    |> Nx.axes()
    |> Enum.map(&Nx.iota(t.shape, axis: &1))
    |> Nx.stack(axis: -1)
    |> Nx.reshape({:auto, Nx.rank(t)})
  end

  # Splits the (vectorized) index tensor into one start index per axis of `t`.
  deftransformp start_indices(t, idx_tensor) do
    t
    |> Nx.axes()
    |> Enum.map(&idx_tensor[&1])
  end

  # Nx.slice/3 takes the window lengths as a list, not a tuple.
  deftransformp kernel_lengths(kernel_shape) do
    Tuple.to_list(kernel_shape)
  end
end
# Filter the full {height, width, 3} image. The kernel spans 5x5 pixels and a
# single channel, so each channel is filtered on its own.
params = [img_tensor, [kernel_shape: {5, 5, 1}]]
opts = [compiler: EXLA]

# The function capture uses a plain `&`; the HTML-escaped `&amp;` form does not compile.
# The result is cast back to unsigned 8-bit before it is displayed as an image.
filt =
  Nx.Defn.jit_apply(&FiltersVec2.median/2, params, opts)
  |> Nx.as_type({:u, 8})

Kino.Image.new(filt)