Discrete Fourier Transform with Nx

nx_fft.livemd

Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:kino, "~> 0.14.2"},
    {:kino_vega_lite, "~> 0.1.11"},
    {:exla, "~> 0.9.1"},
    {:nx_signal, "~> 0.2.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Using complex numbers

Take the complex number $z = 1+i$.

Since Nx depends on the Complex library, we can define a complex number with Complex.new/2:

z = Complex.new(1,1)

Its absolute value is $\sqrt{2} \approx 1.4142$, and its phase is $\pi/4\approx 0.7854$ radians.

sqrt_2 = :math.sqrt(2)
pi_4 = :math.pi()/4

{Complex.abs(z) == sqrt_2, Complex.phase(z) == pi_4}

Its polar form is: $\sqrt2 e^{i\pi/4}\coloneqq \sqrt2\lparen \cos(\pi/4)+i\sin(\pi/4)\rparen$

We can use Complex.from_polar/2 to build a complex number from its polar definition:

z = Complex.from_polar(:math.sqrt(2), :math.pi()/4)
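
As a cross-check that does not depend on the Complex library, the conversion between Cartesian and polar form can be sketched with Erlang's :math module. The module PolarSketch and its two functions are hypothetical helpers written for this illustration; they are not part of Complex or Nx, and complex numbers are represented as plain {re, im} tuples.

```elixir
defmodule PolarSketch do
  # Hypothetical helper: complex numbers are plain {re, im} tuples here.

  # Cartesian -> polar: modulus and phase (in radians).
  def to_polar({re, im}) do
    {:math.sqrt(re * re + im * im), :math.atan2(im, re)}
  end

  # Polar -> Cartesian, using r * (cos(phi) + i * sin(phi)).
  def from_polar(r, phi) do
    {r * :math.cos(phi), r * :math.sin(phi)}
  end
end

{r, phi} = PolarSketch.to_polar({1, 1})
{re, im} = PolarSketch.from_polar(r, phi)

# The round trip recovers 1 + i up to floating-point rounding.
IO.inspect({r, phi}, label: "modulus and phase")
IO.inspect({Float.round(re, 10), Float.round(im, 10)}, label: "back to Cartesian")
```

The rounding in the last line only hides floating-point noise; the modulus and phase are the same $\sqrt2$ and $\pi/4$ computed earlier.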

If we need a tensor from “z”, we would do:

Nx.tensor(z)

We can also use Nx directly to build a complex tensor from its Cartesian coordinates:

t = Nx.complex(1,1)

and compute its norm and phase:

{Nx.abs(t), Nx.phase(t)}

Most of the Nx API will work normally with complex numbers and tensors. Nx.sort is an exception since it relies on an ordering of the values, which complex numbers do not have.

For example, we can pass a tensor to Nx.abs: it will apply the function element-wise.

Nx.stack([t,z]) |> Nx.abs()

We also have the imaginary constant $i$. It is defined within Nx.Constants.

import Nx.Constants

Nx.add(1, i())

For example, we can use it inside a defn to multiply a complex number by $i$:

defmodule Example do
  import Nx.Defn
  import Nx.Constants, only: [i: 0]

  defn rotate(z) do
    i() * z
  end
end

Example.rotate(z)
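
To see why the function is named rotate, the product can be worked out by hand for $z = 1+i$ (a worked example, not output copied from the notebook):

$$ i\cdot(1+i) = i + i^2 = -1 + i, \qquad |-1+i| = \sqrt2, \qquad \arg(-1+i) = \frac{3\pi}{4} = \frac{\pi}{4} + \frac{\pi}{2} $$

The modulus is unchanged and the phase grows by $\pi/2$: multiplying by $i$ is a quarter turn counterclockwise.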

Advanced Applications - The Discrete Fourier Transform (DFT)

A signal is a sequence of numbers $[(t_1, f(t_1)), \dots,(t_n, f(t_n))]$ which we call samples.

The “t” numbers can be viewed as time bins or spatial coordinates, depending upon the subject.

A common aspect people tend to analyze in periodic signals is their frequency composition and intensity. For that, we can use the Discrete Fourier Transform. It takes the samples and outputs a sequence of complex numbers. These numbers represent each sinusoidal component. In other words, it outputs the representation of the samples in the frequency domain.

Nx provides the Nx.fft function. It uses the Fast Fourier Transform (FFT), an efficient algorithm for computing the DFT.
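
For reference, the standard definition of the DFT of $N$ samples $x_0,\dots,x_{N-1}$ is $X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i kn/N}$. The sketch below evaluates this sum directly in plain Elixir, in $O(N^2)$ operations, to show what an FFT computes faster. The module NaiveDFT is a hypothetical helper that represents complex numbers as {re, im} tuples; it is not how Nx implements fft.

```elixir
defmodule NaiveDFT do
  # Hypothetical helper: direct O(N^2) evaluation of the DFT sum
  # for a list of real samples. Returns a list of {re, im} tuples.
  def dft(samples) do
    n = length(samples)

    for k <- 0..(n - 1) do
      samples
      |> Enum.with_index()
      |> Enum.reduce({0.0, 0.0}, fn {x, j}, {re, im} ->
        # e^(-2*pi*i*k*j/n) = cos(angle) + i*sin(angle)
        angle = -2 * :math.pi() * k * j / n
        {re + x * :math.cos(angle), im + x * :math.sin(angle)}
      end)
    end
  end

  # Modulus of a {re, im} tuple.
  def magnitude({re, im}), do: :math.sqrt(re * re + im * im)
end

# One full cosine period over 4 samples: the energy lands in bins 1 and 3.
[1.0, 0.0, -1.0, 0.0]
|> NaiveDFT.dft()
|> Enum.map(&Float.round(NaiveDFT.magnitude(&1), 6))
|> IO.inspect(label: "magnitudes per bin")
```

The direct sum is only practical for small inputs, which is the reason FFT algorithms are used instead.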

Build the signal

The signal we want to analyze will be the sum of two sinusoidal signals, one at 5Hz (i.e. 5 periods/s) and one at 20Hz (i.e. 20 periods/s), with amplitudes 1 and 0.25 respectively.

$$ f(t) = \sin(2\pi\cdot 5\cdot t) + \frac14 \sin(2\pi\cdot 20 \cdot t) $$

We build it here, then decompose and analyze it later on.

We build a time series of n equally spaced points over the given duration with the Nx function Nx.linspace.

More precisely, we sample at fs=50Hz (meaning 50 samples per second) and our acquisition time is duration = 1s. We will get $50\cdot 1 = 50$ points.
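
The sampling grid described above can be sketched in plain Elixir, without Nx, to make the arithmetic explicit. The variable names mirror the options used in this notebook; the formula for the sample times assumes the endpoint is excluded, as in the Nx.linspace call inside Signal.sample.

```elixir
# Sampling parameters from the text: 50 Hz for 1 s, starting at t = 0.
fs = 50
duration = 1
start = 0

n = duration * fs

# With the endpoint excluded, sample k is taken at start + k / fs.
times = for k <- 0..(n - 1), do: start + k / fs

IO.inspect(n, label: "number of samples")
IO.inspect(Enum.take(times, 3), label: "first sample times (s)")
IO.inspect(List.last(times), label: "last sample time (s)")
```

The last sample falls one step before the end of the interval, so a periodic signal is not sampled twice at the same phase.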

defmodule Signal do
  import Nx.Defn
  import Nx.Constants, only: [pi: 0]

  defn source(t) do
    f1 = 5
    f2 = 20
    Nx.sin(2 * pi() * f1 * t) + 1 / 4 * Nx.sin(2 * pi() * f2 * t)
  end

  defn sample(opts) do
    start = opts[:start]
    duration = opts[:duration]
    fs = opts[:fs]
    bins = Nx.linspace(start, duration + start, n: duration * fs, endpoint: false, type: :f32)
    source(bins)
  end
end

We sample our signal at fs=50Hz during 1s:

opts = [start: 0, fs: 50, duration: 1]

sample = Signal.sample(opts)

Analyse the signal with DFT

The DFT approximates the original signal from its samples. It returns a sequence of complex numbers, one per frequency bin. Each number represents the amplitude and phase at that frequency.

> The number at index $i$ of the DFT result is a complex number that approximates the amplitude and phase of the sampled signal at the frequency $i$.

dft = Nx.fft(sample)
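
The index-to-frequency correspondence stated above relies on the sampling parameters. In general (a standard DFT fact, written here with $f_s$ for the sampling rate, $N$ for the number of samples and $k$ for the bin index), bin $k$ corresponds to:

$$ f_k = k\cdot\frac{f_s}{N}, \qquad \text{here}\quad f_k = k\cdot\frac{50\,\text{Hz}}{50} = k\,\text{Hz} $$

With a different duration or sampling rate, the index would have to be scaled by $f_s/N$ before being read as a frequency.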

We are interested in the amplitudes only here. We use the Nx function Nx.abs to obtain the absolute value at each point.

Furthermore, we limit our study to the first half of the “dft” sequence because, for a real-valued signal, the amplitudes in the second half mirror those in the first.
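
The symmetry can be checked from the DFT sum. A short derivation, assuming real samples $x_n$ and writing $X_k$ for the $k$-th DFT coefficient of $N$ samples (using $e^{-2\pi i n} = 1$ in the middle step):

$$ X_{N-k} = \sum_{n=0}^{N-1} x_n e^{-2\pi i (N-k)n/N} = \sum_{n=0}^{N-1} x_n e^{2\pi i kn/N} = \overline{X_k} \quad\Longrightarrow\quad |X_{N-k}| = |X_k| $$

With $N = 50$, bins 26 to 49 therefore repeat the magnitudes of bins 24 down to 1, and bin 25 corresponds to half the sampling rate.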

n = Nx.size(dft)
max_freq_index = div(n, 2)

amplitudes = Nx.abs(dft)[0..max_freq_index]

# for plotting
data1 = %{
  frequencies: (for i<- 0..max_freq_index, do: i),
  amplitudes: Nx.to_list(amplitudes)
}

VegaLite.new(width: 700, height: 300)
|> VegaLite.data_from_values(data1)
|> VegaLite.mark(:bar)
|> VegaLite.encode_field(:x, "frequencies",
  type: :quantitative,
  title: "frequency (Hz)",
  scale: [domain: [0, 50]]
)
|> VegaLite.encode_field(:y, "amplitudes",
  type: :quantitative,
  title: "amplitude",
  scale: [domain: [0, 30]]
)

Our synthesized signal has spikes at 5Hz and 20Hz. The amplitude of the spike at 20Hz is approximately a fourth of the amplitude of the spike at 5Hz. This is indeed our incoming signal 🎉.
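
The peak heights in the chart are not the amplitudes 1 and 0.25 themselves. For a real sinusoid of amplitude $A$ that falls exactly on a frequency bin, the unnormalized DFT magnitude at that bin is $A\cdot N/2$ (a standard DFT property, assuming no spectral leakage). A small sketch of the arithmetic, with hypothetical variable names:

```elixir
# Number of samples: 50 Hz for 1 s.
n = 50

# Amplitudes of the two components of f(t), keyed by frequency in Hz.
components = %{5 => 1.0, 20 => 0.25}

# Expected unnormalized DFT magnitude at each bin: A * N / 2.
expected_peaks = Map.new(components, fn {freq, a} -> {freq, a * n / 2} end)
IO.inspect(expected_peaks, label: "expected peak heights")

# Dividing a peak by N / 2 recovers the amplitude.
recovered = Map.new(expected_peaks, fn {freq, peak} -> {freq, peak / (n / 2)} end)
IO.inspect(recovered, label: "recovered amplitudes")
```

Only the ratio between the peaks is needed to confirm the 1 to 0.25 relation; the $N/2$ factor matters when absolute amplitudes are read from the spectrum.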

Visualize the original signal and the IFFT reconstruction

Let’s visualize our original signal. We want a smooth curve so we will sample 200 equidistant points. We select 2 periods of our 5Hz signal. The duration of the sampling is therefore 2/5s = 400ms. This means that our sampling period is 2ms (i.e. a sampling rate of 500Hz).

We also add the “reconstructed” signal via the Inverse Discrete Fourier Transform available as Nx.ifft.
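
For reference, the inverse DFT is usually written as follows. The $1/N$ factor on the inverse is the common convention; that Nx.ifft follows it is an assumption here, although it is consistent with the reconstruction matching the original samples:

$$ x_n = \frac{1}{N}\sum_{k=0}^{N-1} X_k\, e^{2\pi i kn/N}, \qquad n = 0,\dots,N-1 $$

Since the input samples are real, the imaginary part of the result is only rounding noise, which is why the code keeps the real part with Nx.real.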

It will give us 50 values spaced by 1000/50=20ms (we sampled at 50Hz during 1s). Since we display 2/5s = 400ms, we take the first 20 of them. We display them below as a bar chart. The original signal should envelop the reconstructed signal.
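
The step sizes of the two time grids can be recomputed with a few lines of plain Elixir. This is only a sketch of the arithmetic; the variable names are hypothetical and do not appear in the notebook code.

```elixir
# Fine grid used to draw the "real" signal.
window_s = 2 / 5          # two periods of the 5 Hz component, in seconds
fine_points = 200
fine_step_ms = window_s / fine_points * 1000

# Coarse grid of the sampled (and reconstructed) signal.
fs = 50
coarse_step_ms = 1000 / fs
coarse_points = round(fs * window_s)

IO.inspect(fine_step_ms, label: "fine step (ms)")
IO.inspect(coarse_step_ms, label: "coarse step (ms)")
IO.inspect(coarse_points, label: "reconstructed points in the window")
```

The fine grid is ten times denser than the sampled one, which is why the line looks smooth while the bars are sparse.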

#----------- REAL SIGNAL
# compute 200 points of the "real" signal during 2/5=400ms (twice the main period)

r = 2/5
l = round(50*r)

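# 200 evenly spaced time points (in seconds) covering 0 to r; fft_frequencies is used here only as a grid generator
t = NxSignal.fft_frequencies(r, fft_length: 200)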
sample = Signal.source(t)

#----------- RECONSTRUCTED IFFT
# compute the reconstructed IFFT signal (50 points) and sample 20 of them
yr = Nx.ifft(dft) |> Nx.real() 

data = %{
  x: Nx.to_list(t),
  y: Nx.to_list(sample),
}

data_r = %{
  x: (for i<- 0..l-1, do: i/50),
  y: Nx.to_list(yr[0..l-1])
}


VegaLite.new(width: 700, height: 300)
|> VegaLite.layers([
  VegaLite.new()
  |> VegaLite.data_from_values(data)
  |> VegaLite.mark(:line, tooltip: true)
  |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "time (s)", scale: [domain: [0, 0.4]])
  |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "signal"),
  VegaLite.new()
  |> VegaLite.data_from_values(data_r)
  |> VegaLite.mark(:bar)
  |> VegaLite.encode_field(:x, "x", type: :quantitative, scale: [domain: [0, 0.4]])
  |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "reconstructed")
  |> VegaLite.encode_field(:order, "x")
])
#|> VegaLite.resolve(:scale, y: :independent)

We see that during 400ms, we have 2 periods of the slower 5Hz component (5 × 0.4 = 2) and 8 periods of the faster, smaller 20Hz perturbation (20 × 0.4 = 8).