Median Filter
Mix.install([
{:nx, "~> 0.7.1"},
{:kino, "~> 0.12.3"},
{:vega_lite, "~> 0.1.9"},
{:tucan, "~> 0.3.1"},
{:kino_vega_lite, "~> 0.1.11"},
{:exla, "~> 0.7.1"},
{:benchee, "~> 1.3"}
])
Introduction
In this notebook we are going to implement the Median Filter. It’s a digital filtering technique used to reduce noise from a signal, and we will demonstrate how it works by applying this filter to images affected by salt and pepper noise.
# Image upload widget; its value is read in a later cell with `Kino.Input.read/1`.
image_input = Kino.Input.image("Upload an image with salt and pepper noise")
Images as tensors
Images represented as tensors contain three dimensions: height, width, and channels. The channels dimension is usually of size 3, and each channel represents the grayscale image made of a primary color. For example, an RGB image contains a red channel, a blue channel, and a green channel.
%{file_ref: file_ref, height: height, width: width} = Kino.Input.read(image_input)

# Load the raw bytes of the upload as unsigned 8-bit values and give them the
# {height, width, channels} layout in a single pipeline.
img_tensor =
  file_ref
  |> Kino.Input.file_path()
  |> File.read!()
  |> Nx.from_binary({:u, 8})
  |> Nx.reshape({height, width, 3})

# Render the tensor back as an image to check that it was decoded correctly.
Kino.Image.new(img_tensor)
Utilities
We will implement the median filter for 2D tensors, which means we need to convert our original image into a 2D tensor (i.e. a matrix). So we implement a utility function to convert an RGB image into a grayscale image, meaning that the three channels are combined into a single one through a weighted sum.
defmodule Image do
  @moduledoc """
  Small helpers to convert an RGB image tensor to grayscale and to plot a
  2-dimensional tensor as an image.
  """

  # Per-channel weights applied to the last (channels) axis.
  @rgb_weights Nx.tensor([0.2125, 0.7154, 0.0721])

  @doc ~S"""
  Converts `img` to grayscale by taking the dot product of its last axis with
  the RGB weights, which removes the channels dimension.

  Weights taken from:
  https://github.com/scikit-image/scikit-image/blob/v0.23.2/skimage/color/colorconv.py#L942
  """
  def to_grayscale(img), do: Nx.dot(img, @rgb_weights)

  @doc """
  Plots the 2-dimensional tensor `img` with `Tucan.imshow/2`, using one plot
  unit per tensor row/column and a reversed grey colour scheme.
  """
  def plot(img) do
    {rows, cols} = Nx.shape(img)

    Tucan.imshow(img, height: rows, width: cols, color_scheme: :greys, reverse: true)
  end
end
The output is an image with the same height and width, but with the channels dimension removed.
# Collapse the channels axis with `Image.to_grayscale/1`, then plot the result.
img = Image.to_grayscale(img_tensor)
Image.plot(img)
Implementation
The following is an implementation of a 2-dimensional median filter. The boundary strategy in this case is given by how `Nx.slice/3` handles the edges when we compute a window.
defmodule Filters do
  @moduledoc """
  Reference implementation of a 2-dimensional median filter written with plain
  Elixir comprehensions (no `Nx.Defn`).
  """

  # Window side used when the caller does not pass `:kernel_size`.
  @default_kernel_size 3

  @doc """
  Applies a median filter to the 2-dimensional tensor `t`.

  For every position `{i, j}` a `kernel_size x kernel_size` window starting at
  `{i, j}` is taken with `Nx.slice/3`, and its median becomes the output value
  at `{i, j}`. Windows near the bottom/right border are whatever `Nx.slice/3`
  returns for those start indices.

  ## Options

    * `:kernel_size` - side length of the square window. Defaults to `3`.

  Returns an `{:f, 32}` tensor with the same shape as `t`.
  """
  def median(t, opts \\ []) do
    # The default is resolved here rather than in the argument default, so that
    # an explicit opts list without `:kernel_size` (e.g. `[]`) still works.
    kernel_size = Keyword.get(opts, :kernel_size, @default_kernel_size)
    {height, width} = Nx.shape(t)

    medians =
      for i <- 0..(height - 1) do
        for j <- 0..(width - 1) do
          t
          |> Nx.slice([i, j], [kernel_size, kernel_size])
          |> Nx.median()
          |> Nx.to_number()
        end
      end

    medians
    |> Nx.tensor()
    |> Nx.as_type({:f, 32})
  end
end
# Filter the grayscale image using a 3x3 window.
img_filt_3 = Filters.median(img, kernel_size: 3)
When we show the filtered image, we see a big improvement. There is still a lot of noise on the image, although now we can see more clearly the underlying picture.
# Show the result of the 3x3 filter.
Image.plot(img_filt_3)
Increasing the `kernel_size` parameter will help to reduce the noise even more. However, this will inevitably start to blur the image, given that increasing the kernel size implies computing bigger windows, hence using more neighbours to compute the median.
# Same filter with a 5x5 window, plotted for comparison with the 3x3 result.
img_filt_5 = Filters.median(img, kernel_size: 5)
Image.plot(img_filt_5)
Even though our implementation works fine for this image, we can increase performance by refactoring our code using the `Nx.Defn` module. That will allow us to run our code using a compiler such as EXLA.
Optimization using while construct
# Median filter for 2-D tensors written with `Nx.Defn`, so that it can be run
# through a defn compiler. Two nested `while` loops visit every {i, j} position.
defmodule FiltersDefn do
import Nx.Defn
# Applies the median filter to the 2-D tensor `t`.
#
# `opts[:kernel_shape]` is the window shape; it is later destructured as the
# 2-tuple `{k0, k1}`. The result is built from a `0.0`-filled buffer with the
# same shape as `t`.
defn median(t, opts) do
kernel_shape = opts[:kernel_shape]
{height, width} = Nx.shape(t)
# Zero-filled tensor used only to carry the window shape through the `while`
# state: the loops read `kernel_tensor.shape`, never its values.
kernel_tensor = Nx.broadcast(0.0, kernel_shape)
# Output buffer; one row is written per iteration of the outer loop.
output = Nx.broadcast(0.0, Nx.shape(t))
# Outer loop: `i` runs over the rows while `i < height`. Everything the body
# uses is threaded explicitly through the state tuple.
{result, _} =
while {output, {i = 0, t, kernel_tensor, height, width}}, i < height do
# Fresh 1-D buffer with one slot per column of `t`.
row = Nx.broadcast(0.0, {elem(Nx.shape(t), 1)})
# Inner loop: `j` runs over the columns while `j < width`.
{ith_row, _} =
while {row, {j = 0, t, i, kernel_tensor, width}}, j < width do
{k0, k1} = kernel_tensor.shape
# Median of the `k0 x k1` window that starts at {i, j}, shaped as {1} so it
# can be written into `row` with `Nx.put_slice/3`.
median =
Nx.slice(t, [i, j], [k0, k1])
|> Nx.median()
|> Nx.broadcast({1})
{Nx.put_slice(row, [j], median), {j + 1, t, i, kernel_tensor, width}}
end
# Write the finished row into `output` starting at {i, 0}.
# NOTE(review): `Nx.stack/2` on a single tensor presumably adds a leading axis
# so the row matches the 2-D start index - confirm against the Nx docs.
{Nx.put_slice(output, [i, 0], Nx.stack(ith_row, axis: 0)),
{i + 1, t, kernel_tensor, height, width}}
end
result
end
end
# Positional arguments for `FiltersDefn.median/2`: the image and its options.
params = [img, [kernel_shape: {3, 3}]]
# Options for `Nx.Defn.jit_apply/3`: compile with EXLA.
opts = [compiler: EXLA]
filt = Nx.Defn.jit_apply(&FiltersDefn.median/2, params, opts)
Image.plot(filt)
Refactor using vectorization
defmodule FiltersVec do
  @moduledoc """
  Median filter for 2-dimensional tensors that replaces explicit loops with a
  vectorized `Nx.slice/3` over every start position.
  """

  import Nx.Defn

  @doc """
  Applies a median filter to the 2-dimensional tensor `t`.

  `opts[:kernel_shape]` must be a 2-tuple `{k0, k1}` with the window lengths.
  Returns an `{:f, 32}` tensor with the same shape as `t`.
  """
  defn median(t, opts) do
    {k0, k1} = opts[:kernel_shape]
    shape = Nx.shape(t)

    # Row and column coordinate of every element of `t`.
    row_idx = Nx.iota(shape, axis: 0)
    col_idx = Nx.iota(shape, axis: 1)

    # One {row, col} pair per element, vectorized so that the slice below is
    # evaluated once per pair.
    starts =
      Nx.stack([row_idx, col_idx], axis: -1)
      |> Nx.reshape({:auto, 2})
      |> Nx.vectorize(:elements)

    windows = Nx.slice(t, [starts[0], starts[1]], [k0, k1])

    windows
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(shape)
    |> Nx.as_type({:f, 32})
  end
end
# Positional arguments for `FiltersVec.median/2`: the image and its options.
params = [img, [kernel_shape: {3, 3}]]
# Options for `Nx.Defn.jit_apply/3`: compile with EXLA.
opts = [compiler: EXLA]
filt = Nx.Defn.jit_apply(&FiltersVec.median/2, params, opts)
Image.plot(filt)
Performance Comparison
# Benchmark the three implementations on the same image with a 3x3 window.
# `opts` here holds the filter options passed to the two defn-based versions.
opts = [kernel_shape: {3, 3}]
Benchee.run(%{
"median no defn" => fn -> Filters.median(img, kernel_size: 3) end,
"median defn-while" => fn -> EXLA.jit_apply(&FiltersDefn.median/2, [img, opts]) end,
"median vectorized" => fn -> EXLA.jit_apply(&FiltersVec.median/2, [img, opts]) end
})
Bonus - extending to n-dim
# New upload widget for the n-dimensional example; rebinds `image_input`.
image_input = Kino.Input.image("Upload an image with salt and pepper noise")
%{file_ref: file_ref, height: height, width: width} = Kino.Input.read(image_input)
# Read the uploaded bytes as unsigned 8-bit values and shape them as
# {height, width, 3}; the channels axis is kept this time.
img_tensor =
file_ref
|> Kino.Input.file_path()
|> File.read!()
|> Nx.from_binary({:u, 8})
|> Nx.reshape({height, width, 3})
Kino.Image.new(img_tensor)
defmodule FiltersVec2 do
  @moduledoc """
  Vectorized median filter for tensors of any rank.
  """

  import Nx.Defn

  @doc """
  Applies a median filter to the tensor `t`.

  ## Options

    * `:kernel_shape` - (required) tuple with the window length along each
      axis of `t`; it must have as many elements as `t` has axes.

  Returns an `{:f, 32}` tensor with the same shape as `t`.

  Raises `ArgumentError` when an unknown option is given, when
  `:kernel_shape` is missing or not a tuple, or when its rank differs from the
  rank of `t`.
  """
  defn median(t, opts) do
    validate_median_opts!(t, opts)

    # One start coordinate per element of `t`, vectorized so that the slice
    # below is evaluated once per coordinate.
    idx =
      t
      |> idx_tensor()
      |> Nx.vectorize(:elements)

    t
    |> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape]))
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(t.shape)
    |> Nx.as_type({:f, 32})
  end

  # Checks the options before any tensor work is traced.
  deftransformp validate_median_opts!(t, opts) do
    opts = Keyword.validate!(opts, [:kernel_shape])
    kernel_shape = opts[:kernel_shape]

    # Without this check a missing or malformed option would only fail later,
    # inside the rank computation, with an unrelated error.
    if not is_tuple(kernel_shape) do
      raise ArgumentError,
        message:
          "expected :kernel_shape to be a tuple of window lengths, got: " <>
            inspect(kernel_shape)
    end

    if Nx.rank(t) != tuple_size(kernel_shape) do
      raise ArgumentError, message: "kernel shape must be of the same rank as the tensor"
    end
  end

  # Builds an {n, rank} tensor whose rows are the coordinates of every element of `t`.
  deftransformp idx_tensor(t) do
    t
    |> Nx.axes()
    |> Enum.map(&Nx.iota(t.shape, axis: &1))
    |> Nx.stack(axis: -1)
    |> Nx.reshape({:auto, Nx.rank(t)})
  end

  # Splits the coordinate tensor into one start index per axis of `t`.
  deftransformp start_indices(t, idx_tensor) do
    t
    |> Nx.axes()
    |> Enum.map(&idx_tensor[&1])
  end

  # `Nx.slice/3` takes the window lengths as a list.
  deftransformp(kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape))
end
# Filter the colour image directly: a 5x5 window over height/width and a
# length of 1 along the channels axis.
params = [img_tensor, [kernel_shape: {5, 5, 1}]]
opts = [compiler: EXLA]
# Cast the float result to unsigned 8-bit before handing it to `Kino.Image.new/1`.
filt =
Nx.Defn.jit_apply(&FiltersVec2.median/2, params, opts)
|> Nx.as_type({:u, 8})
Kino.Image.new(filt)