Discrete Fourier Transform with Nx
Mix.install(
[
{:nx, "~> 0.9.1"},
{:kino, "~> 0.14.2"},
{:kino_vega_lite, "~> 0.1.11"},
{:exla, "~> 0.9.1"},
{:nx_signal, "~> 0.2.0"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
Using complex numbers
Take the complex number “z” defined by: $1+i$.
Since Nx depends on the Complex library, we can define a complex number with Complex.new/2:
z = Complex.new(1,1)
Its absolute value is $\sqrt{2} \approx 1.4142$, and its phase is $\pi/4\approx 0.7854$ radians.
sqrt_2 = :math.sqrt(2)
pi_4 = :math.pi()/4
{Complex.abs(z) == sqrt_2, Complex.phase(z) == pi_4}
Its polar form is: $\sqrt2\, e^{i\pi/4} = \sqrt2\left(\cos(\pi/4)+i\sin(\pi/4)\right)$
We can use Complex.from_polar/2 to build a complex number from its polar definition:
z = Complex.from_polar(:math.sqrt(2), :math.pi()/4)
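As a quick sanity check of the polar form, the conversion can be sketched with plain `:math` calls, without relying on Complex or Nx. The variable names and the `close?` helper are illustrative only and not part of any library.

```elixir
# Polar -> Cartesian: z = r * (cos(theta) + i * sin(theta))
r = :math.sqrt(2)
theta = :math.pi() / 4

re = r * :math.cos(theta)
im = r * :math.sin(theta)

# Cartesian -> polar: modulus and argument
modulus = :math.sqrt(re * re + im * im)
phase = :math.atan2(im, re)

# Floating-point rounding means we compare with a tolerance rather than with ==
close? = fn a, b -> abs(a - b) < 1.0e-12 end

IO.inspect({close?.(re, 1.0), close?.(im, 1.0), close?.(modulus, r), close?.(phase, theta)})
```

The same caution applies when comparing the results of Complex.abs/1 and Complex.phase/1 against reference values: a tolerance is usually more robust than strict equality.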
If we need a tensor from “z”, we would do:
Nx.tensor(z)
We can also use Nx directly to build a complex tensor from its Cartesian coordinates:
t = Nx.complex(1,1)
and compute its norm and phase:
{Nx.abs(t), Nx.phase(t)}
Most of the Nx API works normally with complex numbers and tensors. The function sort is an exception, since it relies on an ordering of the values.
For example, we can pass a tensor to Nx.abs: it will apply the function element-wise.
Nx.stack([t,z]) |> Nx.abs()
We also have the imaginary constant $i$. It is defined within Nx.Constants.
import Nx.Constants
Nx.add(1, i())
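For example, we can use it inside defn to multiply a complex number by $i$, which rotates it by $\pi/2$ in the complex plane: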
defmodule Example do
  import Nx.Defn
  import Nx.Constants, only: [i: 0]

  defn rotate(z) do
    i() * z
  end
end
Example.rotate(z)
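For the example value $z = 1+i$, the product can be worked out by hand, which gives a reference to compare the result against:

$$ i\cdot(1+i) = i + i^2 = -1 + i = \sqrt2\, e^{i\,3\pi/4} $$

The modulus stays at $\sqrt2$ while the phase moves from $\pi/4$ to $3\pi/4$, i.e. a rotation by $\pi/2$.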
Advanced Applications - The Discrete Fourier Transform (DFT)
A signal is a sequence of pairs $[(t_1, f(t_1)), \dots, (t_n, f(t_n))]$, which we call samples.
The “t” numbers can be viewed as time bins or spatial coordinates, depending upon the subject.
A common aspect people analyze in periodic signals is their frequency composition and intensity. For that, we can use the Discrete Fourier Transform. It takes the samples and outputs a sequence of complex numbers. Each of these numbers represents one sinusoidal component. In other words, it outputs the representation of the samples in the frequency domain.
Nx provides the fft function. It uses the Fast Fourier Transform algorithm, an efficient implementation of the DFT.
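To make the definition concrete, here is a minimal sketch of a direct $O(N^2)$ DFT in plain Elixir, assuming the usual convention $X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i k n/N}$. The module `NaiveDFT` and its functions are hypothetical names introduced for illustration. They are not part of Nx, and complex values are represented as `{re, im}` tuples to keep the example dependency-free.

```elixir
defmodule NaiveDFT do
  @moduledoc """
  Illustrative O(N^2) DFT on a list of real samples.
  Complex coefficients are represented as {re, im} tuples.
  """

  # X_k = sum over n of x_n * exp(-2 * pi * i * k * n / N)
  def dft(samples) do
    n_total = length(samples)
    indexed = Enum.with_index(samples)

    for k <- 0..(n_total - 1)//1 do
      Enum.reduce(indexed, {0.0, 0.0}, fn {x, n}, {re, im} ->
        angle = -2 * :math.pi() * k * n / n_total
        {re + x * :math.cos(angle), im + x * :math.sin(angle)}
      end)
    end
  end

  # Magnitude of each complex coefficient
  def magnitudes(coeffs) do
    Enum.map(coeffs, fn {re, im} -> :math.sqrt(re * re + im * im) end)
  end
end

# A constant signal only has a zero-frequency component
[1.0, 1.0, 1.0, 1.0]
|> NaiveDFT.dft()
|> NaiveDFT.magnitudes()
|> IO.inspect()
```

For this constant input the first magnitude is 4 and the others vanish up to floating-point rounding. An FFT computes the same coefficients in $O(N \log N)$ operations, which is why it is preferred in practice.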
Build the signal
The signal we want to analyze is the sum of two sinusoidal signals, one at 5Hz (i.e. 5 periods/s) and one at 20Hz (i.e. 20 periods/s), with respective amplitudes 1 and 0.25.
$$ f(t) = \sin(2\pi\cdot 5\cdot t) + \frac14 \sin(2\pi\cdot 20 \cdot t) $$
We build it first, then decompose and analyze it later on.
We build a time series of n equally spaced points over the given duration interval with the Nx function Nx.linspace.
More precisely, we sample at fs=50Hz (meaning 50 samples per second) and our acquisition time is duration = 1s. We will get $50\cdot 1 = 50$ points.
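defmodule Signal do
  import Nx.Defn
  import Nx.Constants, only: [pi: 0]

  # Sum of a 5Hz sinusoid (amplitude 1) and a 20Hz sinusoid (amplitude 1/4)
  defn source(t) do
    f1 = 5
    f2 = 20
    Nx.sin(2 * pi() * f1 * t) + 1 / 4 * Nx.sin(2 * pi() * f2 * t)
  end

  # Options are passed through a default argument so that defn treats them
  # as compile-time values rather than tensors
  defn sample(opts \\ []) do
    start = opts[:start]
    duration = opts[:duration]
    fs = opts[:fs]
    # duration * fs equally spaced time bins on [start, start + duration)
    bins = Nx.linspace(start, duration + start, n: duration * fs, endpoint: false, type: :f32)
    source(bins)
  end
end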
We sample our signal at fs=50Hz during 1s:
opts = [start: 0, fs: 50, duration: 1]
sample = Signal.sample(opts)
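To make the sampling grid explicit, the time bins can be sketched in plain Elixir. The `sample_times` function is a hypothetical stand-alone helper, not part of Nx. It mirrors the choice of excluding the endpoint, so the last bin sits one step before `start + duration`.

```elixir
# Hypothetical helper: n = duration * fs sample times, spaced by 1 / fs,
# starting at `start` and excluding the endpoint start + duration
sample_times = fn start, duration, fs ->
  n = round(duration * fs)
  for k <- 0..(n - 1)//1, do: start + k / fs
end

times = sample_times.(0, 1, 50)

IO.inspect({length(times), hd(times), List.last(times)})
# => {50, 0.0, 0.98}
```

With fs=50Hz the spacing between consecutive bins is 1/50s = 20ms.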
Analyze the signal with the DFT
The DFT algorithm approximates the original signal. It returns a sequence of complex numbers, one per frequency bin. Each number encodes the amplitude and phase of the corresponding frequency.
> The number at index $i$ of the DFT result is a complex number that approximates the amplitude and phase of the sampled signal at the frequency $i \cdot f_s / n$, where $n$ is the number of samples. Here $f_s / n = 50/50 = 1$Hz, so index $i$ corresponds to the frequency $i$ Hz.
dft = Nx.fft(sample)
We are interested in the amplitudes only here. We use the Nx function Nx.abs to obtain the absolute value at each point.
Furthermore, we limit our study to the first half of the “dft” sequence, because the amplitudes of the DFT of a real-valued signal are symmetric.
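The symmetry can be stated precisely. Writing $X_k$ for the coefficient at index $k$ and $n$ for the number of samples (notation introduced here for illustration), a real-valued signal satisfies:

$$ X_{n-k} = \overline{X_k} \quad\Longrightarrow\quad |X_{n-k}| = |X_k|, \qquad k = 1, \dots, n-1 $$

With $n = 50$, the spike at index 5 is mirrored at index 45 and the spike at index 20 at index 30, so the second half carries no additional amplitude information.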
n = Nx.size(dft)
max_freq_index = div(n, 2)
amplitudes = Nx.abs(dft)[0..max_freq_index]
# for plotting
data1 = %{
frequencies: (for i<- 0..max_freq_index, do: i),
amplitudes: Nx.to_list(amplitudes)
}
VegaLite.new(width: 700, height: 300)
|> VegaLite.data_from_values(data1)
|> VegaLite.mark(:bar)
|> VegaLite.encode_field(:x, "frequencies",
type: :quantitative,
title: "frequency (Hz)",
scale: [domain: [0, 50]]
)
|> VegaLite.encode_field(:y, "amplitudes",
type: :quantitative,
title: "amplitude",
scale: [domain: [0, 30]]
)
Our synthesized signal has spikes at 5Hz and 20Hz. The amplitude of the spike at 20Hz is approximately a fourth of the amplitude of the spike at 5Hz. This is indeed our incoming signal 🎉.
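The heights of the spikes can be predicted by hand. Assuming the usual unnormalized DFT convention, a sinusoid of amplitude $A$ that completes a whole number of periods within the $n$ samples yields a spike of magnitude $A \cdot n / 2$ at its frequency bin. With $n = 50$:

$$ |X_5| = 1\cdot\frac{50}{2} = 25, \qquad |X_{20}| = \frac14\cdot\frac{50}{2} = 6.25 $$

The ratio $6.25 / 25 = 1/4$ matches the ratio of the amplitudes in $f(t)$.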
Visualize the original signal and its IFFT reconstruction
Let’s visualize our original signal. We want a smooth curve, so we will sample 200 equidistant points. We select 2 periods of our 5Hz signal. The duration of the sampling is therefore 2/5s = 400ms. This means that our sampling interval is 2ms (i.e. a sampling rate of 500Hz).
We also add the “reconstructed” signal via the Inverse Discrete Fourier Transform, available as Nx.ifft.
It will give us 50 values spaced by 1000/50=20ms (we sampled at 50Hz during 1s). Since we display 2/5s = 400ms, we take 20 of them. We display them below as a bar chart. The original signal should envelop the reconstructed signal.
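For reference, the inverse transform is usually written as follows, with $X_k$ the DFT coefficients, $n$ the number of samples and $x_m$ the reconstructed sample at index $m$ (symbols introduced here for illustration):

$$ x_m = \frac1n \sum_{k=0}^{n-1} X_k\, e^{2\pi i k m / n}, \qquad m = 0, \dots, n-1 $$

For a real-valued input the imaginary parts of the $x_m$ vanish up to rounding, which is why only the real part is kept in the code that follows.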
#----------- REAL SIGNAL
# compute 200 points of the "real" signal during 2/5=400ms (twice the main period)
r = 2/5
l = round(50*r)
t = NxSignal.fft_frequencies(r, fft_length: 200)
sample = Signal.source(t)
#----------- RECONSTRUCTED IFFT
# compute the reconstructed IFFT signal (50 points) and sample 20 of them
yr = Nx.ifft(dft) |> Nx.real()
data = %{
x: Nx.to_list(t),
y: Nx.to_list(sample)
}
data_r = %{
x: (for i<- 0..l-1, do: i/50),
y: Nx.to_list(yr[0..l-1])
}
VegaLite.new(width: 700, height: 300)
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.data_from_values(data)
|> VegaLite.mark(:line, tooltip: true)
|> VegaLite.encode_field(:x, "x", type: :quantitative, title: "time (s)", scale: [domain: [0, 0.4]])
|> VegaLite.encode_field(:y, "y", type: :quantitative, title: "signal"),
VegaLite.new()
|> VegaLite.data_from_values(data_r)
|> VegaLite.mark(:bar)
|> VegaLite.encode_field(:x, "x", type: :quantitative, scale: [domain: [0, 0.4]])
|> VegaLite.encode_field(:y, "y", type: :quantitative, title: "reconstructed")
|> VegaLite.encode_field(:order, "x")
])
#|> VegaLite.resolve(:scale, y: :independent)
We see that over the 400ms window we have 2 periods of the 5Hz component, and 8 periods of the smaller 20Hz perturbation.