Mandelbrot set with Elixir Nx and Zigler
Mix.install(
[
{:nx, "~> 0.9"},
{:exla, "~> 0.9"},
{:kino, "~> 0.14.2"},
{:zigler, "~> 0.13.3"},
],
config: [nx: [default_backend: EXLA.Backend]]
)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
Warning
If the cell above raises an error, Zig is probably not installed on your machine. In that case, comment out the Zigler package above.
You will then be unable to run the last modules, where we use inline Zig code.
:observer.start()
Table of contents
- Introduction
- Play with orbits
- Def and Defn
- Algorithm
- Orbit and iteration number
- Some colour palettes
- Computing Mandelbrot set
- Parallelise it
- ExZoom in and out
- Zoom in with Zig
- Static image with Zig and Nx and Kino
Introduction
We want to produce an image that represents the beautiful Mandelbrot set. With one module below, you can zoom into the fractal.
Source:
We use the Nx library with the EXLA backend to speed up the computations.
We also provide equivalent code in Zig that you can run in this Livebook if you want extra speed. This works thanks to the Zigler library. The Zig code returns a binary that Nx is able to consume and Kino to display.
What is a Mandelbrot set?
In a “Mandelbrot image”, each pixel has a colour representing how fast the underlying point “escapes” when calculating its iterates under a certain function.
What is an underlying point?
A pixel has some coordinates [i,j]. For example, in a 1024 × 768 image (WIDTH x HEIGHT), the row number i varies from 0 to 767 and the column number j from 0 to 1023.
We transform these pairs of integers (i,j) into points of a 2D plane. This map “quantizes” the 2D plane.
Here, the 2D “real” plane is defined by the upper left corner, say (-2,1), and the bottom right corner, say (1,-1).
How? We have a linear mapping between the pair (i,j) and a point (x,y) in the defined zone. For example, the pixel (0,0) becomes (-2,1) and the pixel (767, 1023) becomes (1,-1).
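The linear mapping can be sketched in plain Elixir, without Nx. The module name `PlaneMap` and the function `to_point/3` are illustrative names that do not appear in this notebook; the corners are the ones quoted above.

```elixir
defmodule PlaneMap do
  # Hypothetical helper: maps a pixel {row, col} of an h x w image to a point
  # {x, y} of the rectangle defined by its top-left and bottom-right corners.
  def to_point({row, col}, {h, w}, {{tl_x, tl_y}, {br_x, br_y}}) do
    scale_x = (br_x - tl_x) / (w - 1)
    scale_y = (br_y - tl_y) / (h - 1)
    {tl_x + col * scale_x, tl_y + row * scale_y}
  end
end

corners = {{-2, 1}, {1, -1}}
# first pixel: lands on the top-left corner
IO.inspect(PlaneMap.to_point({0, 0}, {768, 1024}, corners))
# last pixel: lands on the bottom-right corner, up to float rounding
IO.inspect(PlaneMap.to_point({767, 1023}, {768, 1024}, corners))
```

The scale uses `w - 1` and `h - 1` so that the last row and column land exactly on the bottom-right corner.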
What is iterating?
We will iterate the function x -> x*x + c, where c is a given number and x the variable.
We start with:
z0 = f(0) = c
z1 = f(z0) = z0 * z0 + c
z2 = f(z1) = z1 * z1 + c
...
Let’s take an example. The module below calculates the iterations x(n) = f(x(n-1)) by a simple recursion.
The set of these iterates is called the orbit of c.
defmodule Simple do
def p(x,c), do: x**2 + c
def iterate(0,c), do: c
def iterate(n,c), do: p(iterate(n-1, c), c)
end
We calculate the first elements of its orbit and evaluate how the point c=1 behaves. It looks like it will diverge to infinity.
c = 1
{ c,
Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c),
}
gives:
{1, 2, 5, 26, 677, 458330, 210066388901}
On the other hand, the point c=-1 seems well behaved: the orbit has only two values, 0 and -1, and is periodic.
c = -1
{ c,
Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c),Simple.iterate(5,c),Simple.iterate(6,c),
}
gives:
{-1, 0, -1, 0, -1, 0, -1}
In the examples above, we took a simple “real” number.
For the Mandelbrot set, we use the complex representation of a point: (x,y) -> x + y*i where i is the imaginary unit, Nx.Constants.i.
So, each pixel (i,j) is mapped to a complex number c = projection(i,j), and we want to evaluate how the iterates behave for this c when the iteration starts from 0.
Iteration number?
We are interested in assigning an iteration number to each c.
The number of iterations needs to be bounded (think of a periodic orbit, which never escapes). Let max_iter be the maximum number of iterations, for example 100.
If the orbit of c remains bounded, we set its iteration number to max_iter.
If it escapes, meaning one iterate has a norm greater than 2, the iteration number is the first index at which the norm of the iterate (its absolute value as a complex number) exceeds 2.
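The rule above can be sketched in plain Elixir, with a complex number written as a `{re, im}` tuple. The module `EscapeTime` is a hypothetical illustration, not the Nx implementation used later; it assumes the convention that the orbit starts at z0 = 0, so that z1 = c, and it compares the squared norm with 4 to avoid a square root.

```elixir
defmodule EscapeTime do
  # Hypothetical plain-Elixir sketch: a complex number is a {re, im} tuple.
  # One step of z -> z*z + c.
  def step({a, b}, {x, y}), do: {a * a - b * b + x, 2 * a * b + y}

  def sq_norm({a, b}), do: a * a + b * b

  # Returns the first index n such that |z_n| > 2, or max_iter when no
  # iterate escapes within max_iter steps. The orbit starts at z_0 = 0.
  def number(c, max_iter), do: loop({0, 0}, c, 0, max_iter)

  defp loop(_z, _c, n, max_iter) when n >= max_iter, do: max_iter

  defp loop(z, c, n, max_iter) do
    if sq_norm(z) > 4, do: n, else: loop(step(z, c), c, n + 1, max_iter)
  end
end

IO.inspect(EscapeTime.number({1, 0}, 100))
IO.inspect(EscapeTime.number({-1, 0}, 100))
```

With this convention, c = 1 escapes at index 3 (the orbit is 0, 1, 2, 5) while c = -1 never escapes and gets max_iter.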
If we inspect the orbits, we discover nice figures. The Livebook below lets you play with the plane and displays orbits on click.
Play with orbits
Livebook to compute orbits:
Def and Defn
We will use two types of functions:
- Elixir functions using def
- Nx functions using defn; these use a special backend (EXLA with CPU, or GPU if any). Within a numerical function, all the arguments are treated as tensors. To pass a non-tensor argument, use an opts keyword.
The points of the 2D plane will be represented as complex numbers as the Mandelbrot map works with complex numbers.
The function z(n+1) = z(n) * z(n) + c takes a complex number and returns a complex number.
Algorithm
Source:
Input: image dimensions (eg w x h of 1500 x 1000), max iteration (eg 100)
Iterate over each pixel (i,j):
- map it into the 2D plane: compute its “complex coordinates”
- compute the iteration number
- compute a colour
- Sum up and draw from the final tensor with Kino.
An example of a zoom into an area. You can go as deep as you want with these fractal objects.
Pixel to complex plane mapping
This module transforms a pair (i,j) into a complex number. It is a numerical function.
> A tensor can be of type complex c64. A numerical Nx function natively understands complex numbers. Note that we use the built-in constant i().
defmodule Pixel do
import Nx.Defn
import Nx.Constants, only: [i: 0]
defn map(index, {h,w}, {top_left_x, top_left_y, bottom_right_x,bottom_right_y}) do
scale_x = (bottom_right_x-top_left_x) / (w-1)
scale_y = (bottom_right_y-top_left_y) / (h-1)
# building a complex typed tensor
top_left_x + index[1] * scale_x + i() * (top_left_y + index[0] * scale_y)
end
end
Orbit and iteration number
This module computes the iteration number for a given input c. It is a numerical function.
How does this work?
- If |c|>2, then this point is unstable. Otherwise, we have to compute for each point whether it stays bounded or not.
- If it is bounded, we get max_iter, otherwise a lower value.
We cannot use the recursive form we used earlier because numerical functions (defn) don’t accept multiple function clauses the way plain Elixir functions do.
Instead we use the Nx while loop. Note how we use the Nx version of cond. In this context, true is the tensor 1.
defmodule Orbit do
import Nx.Defn
defn poly(z,c), do: z*z + c
# for speed, we don't use the norm (square root involved), only the square of the norm
defn sq_norm(z), do: Nx.real(Nx.conjugate(z) * z)
defn number(c,opts) do
max_iter = opts[:imax]
#h = opts[:h]
#w = opts[:w]
condition = (Nx.real(c) +1) ** 2 + (Nx.imag(c)**2)
cond do
# points in the disk of radius 1/4 centred at -1 (the period-2 bulb) are all stable. Save on iterations
Nx.less(condition, 0.0625) ->
max_iter
# these points are unbounded whenever the norm is > 2
Nx.greater(sq_norm(c), 4) ->
0
# evaluate each point as it can be bounded or not in the disk 2
1 ->
# less efficient going up as we need to build a h x w zero tensor
#z0 = Nx.complex(0.0, 0.0) |> Nx.broadcast({h,w}) |> Nx.vectorize([:rows, :cols])
# {i, _, _, _} =
#while {i=0, z=z0, c, max_iter}, Nx.less(i,max_iter) and Nx.less(sq_norm(z), 4) do
# {i+1, poly(z,c), c, max_iter}
#end
# more efficient going down as we can start at the input value "c"
{i, _, _} =
while {i=max_iter, z=c, c}, Nx.greater(i,1) and Nx.less(sq_norm(z), 4) do
{i-1, poly(z,c), c}
end
#i
max_iter - i + 1
end
end
end
#Orbit.number(Nx.complex(0.6, 0.2), imax: 50)
#Nx.complex(0.2, 0.2) |> Nx.broadcast({400,500})
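To check by hand which values the countdown loop of Orbit.number returns, here is a plain-Elixir mirror of it with complex numbers as `{re, im}` tuples. The module `CountdownOrbit` is a hypothetical transcription for illustration only; it does not use Nx and is not a replacement for the module above.

```elixir
defmodule CountdownOrbit do
  # Hypothetical plain-Elixir mirror of the countdown loop; c is {re, im}.
  def number({x, y} = c, max_iter) do
    cond do
      # disk of radius 1/4 centred at -1: treated as stable
      (x + 1) * (x + 1) + y * y < 0.0625 ->
        max_iter

      # |c| > 2: the point escapes immediately
      x * x + y * y > 4 ->
        0

      true ->
        # start at z = c and count i down from max_iter
        i = countdown(max_iter, c, c)
        max_iter - i + 1
    end
  end

  # Keeps iterating while i > 1 and |z|^2 < 4, as in the while condition.
  defp countdown(i, {a, b}, {x, y} = c) when i > 1 and a * a + b * b < 4 do
    countdown(i - 1, {a * a - b * b + x, 2 * a * b + y}, c)
  end

  defp countdown(i, _z, _c), do: i
end

IO.inspect(CountdownOrbit.number({0.5, 0}, 50))
```

For c = 0.5 the iterates are 0.5, 0.75, 1.0625, 1.62890625 and then about 3.15, so the loop stops with i = 46 and the function returns 50 - 46 + 1 = 5. A point that never escapes ends with i = 1 and gets max_iter.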
Some colour palettes
Each iteration number is an integer n. We want to associate a colour [r(n),g(n),b(n)].
This will help us to visualise which point of the complex plane is stable, and if not how fast it escapes.
The choice below is just an example. The first render uses the rgb function whilst the second uses rgb3. Other choices can be made.
defmodule Colour do
import Nx.Defn
defn rgb(n, max_iter) do
n = n / max_iter
cond do
Nx.equal(n, 0) ->
Nx.stack([255, 255, 0])
Nx.less(n, 0.5) ->
s = n * 2
r = 255 * (1 - s)
g = 255 * (1 - s/2)
b = 127 + 128 * s
Nx.stack([r, g, b])
true ->
s = 2 * n - 1
r = 0
g = 127 * (1 - s)
b = 255 * (1 - s)
Nx.stack([r, g, b])
end
end
defn rgb3(n, max_iter) do
v = Nx.remainder(n,16)
cond do
Nx.equal(n, max_iter) -> Nx.stack([0,0,0])
Nx.equal(v,0) -> Nx.stack([66, 30, 15])
Nx.equal(v,1) -> Nx.stack([25, 7, 26])
Nx.equal(v,2) -> Nx.stack([9, 1, 47])
Nx.equal(v,3) -> Nx.stack([4, 4, 73])
Nx.equal(v,4) -> Nx.stack([0, 7, 100])
Nx.equal(v,5) -> Nx.stack([12, 44, 138])
Nx.equal(v,6) -> Nx.stack([24, 82, 177])
Nx.equal(v,7) -> Nx.stack([57, 125, 209])
Nx.equal(v,8) -> Nx.stack([134, 181, 229])
Nx.equal(v,9) -> Nx.stack([211, 236, 248])
Nx.equal(v,10) -> Nx.stack([241, 233, 191])
Nx.equal(v,11) -> Nx.stack([248, 201, 95])
Nx.equal(v,12) -> Nx.stack([255, 170, 0])
Nx.equal(v,13) -> Nx.stack([204, 128, 0])
Nx.equal(v,14) -> Nx.stack([153, 87, 0])
Nx.equal(v,15) -> Nx.stack([106, 52, 3])
1 -> Nx.stack([0,0,0])
end
end
defn rgb4(n, max_iter) do
v = Nx.remainder(n,16)
cond do
Nx.equal(n, max_iter) -> Nx.stack([0,0,0, 255])
Nx.equal(v,0) -> Nx.stack([66, 30, 15, 255])
Nx.equal(v,1) -> Nx.stack([25, 7, 26, 255])
Nx.equal(v,2) -> Nx.stack([9, 1, 47, 255])
Nx.equal(v,3) -> Nx.stack([4, 4, 73, 255])
Nx.equal(v,4) -> Nx.stack([0, 7, 100, 255])
Nx.equal(v,5) -> Nx.stack([12, 44, 138, 255])
Nx.equal(v,6) -> Nx.stack([24, 82, 177, 255])
Nx.equal(v,7) -> Nx.stack([57, 125, 209, 255])
Nx.equal(v,8) -> Nx.stack([134, 181, 229, 255])
Nx.equal(v,9) -> Nx.stack([211, 236, 248, 255])
Nx.equal(v,10) -> Nx.stack([241, 233, 191, 255])
Nx.equal(v,11) -> Nx.stack([248, 201, 95, 255])
Nx.equal(v,12) -> Nx.stack([255, 170, 0, 255])
Nx.equal(v,13) -> Nx.stack([204, 128, 0, 255])
Nx.equal(v,14) -> Nx.stack([153, 87, 0, 255])
Nx.equal(v,15) -> Nx.stack([106, 52, 3, 255])
1 -> Nx.stack([0,0,0, 255])
end
end
end
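To see what the gradient in Colour.rgb produces for a single iteration number, here is a scalar sketch in plain Elixir. The module `PlainColour` is hypothetical, and the components are truncated to integers on the assumption that the final cast to u8 in the rendering pipeline truncates floats.

```elixir
defmodule PlainColour do
  # Hypothetical scalar version of the gradient in Colour.rgb.
  # Components are truncated to integers (assumed to match the u8 cast).
  def rgb(0, _max_iter), do: [255, 255, 0]

  def rgb(n, max_iter) do
    t = n / max_iter

    {r, g, b} =
      if t < 0.5 do
        # first half: fade towards blue
        s = 2 * t
        {255 * (1 - s), 255 * (1 - s / 2), 127 + 128 * s}
      else
        # second half: fade from blue to black
        s = 2 * t - 1
        {0, 127 * (1 - s), 255 * (1 - s)}
      end

    Enum.map([r, g, b], &trunc/1)
  end
end

IO.inspect(PlainColour.rgb(25, 100))
IO.inspect(PlainColour.rgb(100, 100))
```

A point that reaches max_iter (a stable point) comes out black, while points that escape immediately get the yellow of the first clause.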
Computing the Mandelbrot set
The final module
We reassemble the tensor into the desired format for Kino to consume it and display.
> Note that if you want to pass arguments into a defn function without having them treated as tensors, you need to use a keyword list or a map.
defmodule Mandelbrot do
import Nx.Defn
defn compute(opts) do
top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = - 1.2;
defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}
h = opts[:h]
w = opts[:w]
max_iter = opts[:max_iter]
# build the tensor [[0,0], ..., [0,m], [1,0], ..., [n,m]]. Thanks to PValente
iota_rows = Nx.iota({h}, type: :u16) |> Nx.vectorize(:rows)
iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
cross_product = Nx.stack([iota_rows, iota_cols])
Pixel.map(cross_product,{h,w}, defining_points)
|> Orbit.number(imax: max_iter) #, h: h, w: w)
|> Colour.rgb(max_iter)
|> Nx.devectorize()
|> Nx.reshape({h, w, 3})
|> Nx.as_type(:u8)
end
end
Depending on your machine, the computation below can be lengthy.
For a quick evaluation, use low values such as h = w = 300 and max_iter = 100.
h = 400; w = 502
Mandelbrot.compute(h: h, w: w, max_iter: 100)
|> Kino.Image.new()
Parallelise it
When the resolution of the image increases, it becomes worthwhile to parallelise the computations.
We divide the image into horizontal bands with a fixed number of rows. Then we spawn as many tasks as needed.
> This is only worthwhile if the image is large enough, as it comes with a non-negligible overhead.
Lastly, another possible optimisation is to notice that the full image is symmetric about the real axis. You can compute half of the image and stack the reversed tensor onto it. This gives an extra boost.
This is done in StreamMandelbrot.paral_sym.
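The symmetry argument can be checked on a tiny grid in plain Elixir: the escape index of (x, -y) equals that of (x, y), so the lower rows are the upper rows in reverse order. The module `Mirror` and its functions are hypothetical names for this sketch and are independent of the Nx code below.

```elixir
defmodule Mirror do
  # Hypothetical helper: escape index of c = {x, y}, starting from z = 0.
  def number({x, y}, max_iter), do: loop(0.0, 0.0, x, y, 0, max_iter)

  defp loop(_a, _b, _x, _y, n, max_iter) when n >= max_iter, do: max_iter

  defp loop(a, b, x, y, n, max_iter) do
    if a * a + b * b > 4 do
      n
    else
      loop(a * a - b * b + x, 2 * a * b + y, x, y, n + 1, max_iter)
    end
  end

  # Rows for the upper half plane only, top row first.
  def upper_half(ys, xs, max_iter) do
    for y <- ys, do: for(x <- xs, do: number({x, y}, max_iter))
  end

  # The lower half is the upper half with its rows in reverse order.
  def full(ys, xs, max_iter) do
    top = upper_half(ys, xs, max_iter)
    top ++ Enum.reverse(top)
  end
end

xs = [-2.0, -1.0, 0.0, 0.5]
IO.inspect(Mirror.full([1.0, 0.5], xs, 50))
```

Only half of the rows go through the iteration loop; the other half costs a list reversal.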
defmodule StreamMandelbrot do
import Nx.Defn
# bottom_right_y = 0 <--- we compute only the half-plane y >= 0
defn compute_half(cross_product, %{h: h, w: w, max_iter: max_iter}) do
top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = 0;
defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}
Pixel.map(cross_product,{h,w}, defining_points)
|> Orbit.number(imax: max_iter)
|> Colour.rgb3(max_iter)
|> Nx.devectorize()
|> Nx.as_type(:u8)
end
#--------------------------------------------------------------------------------
# Fix the number of rows and spawn concurrently and compute half (for full image)
#--------------------------------------------------------------------------------
def paral_sym(%{h: h, w: w} = opts) do
half_h = div(h,2)
rows_per_task = 30
half_new = half_h - rem(half_h, rows_per_task)
nb_tasks = div(half_new, rows_per_task)
opts = Map.merge(opts, %{h: half_new })
tasks_half_tensor =
for task_id <- 0.. nb_tasks-1 do
Task.async(fn ->
iota_rows =
Nx.iota({rows_per_task}, type: :u16)
|> Nx.add(task_id * rows_per_task)
|> Nx.vectorize(:rows)
# full width
iota_cols =
Nx.iota({w}, type: :u16)
|> Nx.vectorize(:cols)
cross_product = Nx.stack([iota_rows, iota_cols])
Nx.Defn.jit_apply(fn t ->
computed_rows = compute_half(t, opts)
{task_id, computed_rows}
end,
[cross_product]
)
end)
end
half_tensor =
Task.await_many(tasks_half_tensor, :infinity)
|> Enum.sort(fn {x,_}, {y,_} -> Nx.to_number(x)<= Nx.to_number(y) end)
|> Enum.map(&(elem(&1, 1)))
|> Nx.stack()
|> Nx.devectorize()
flah_tensor =
Nx.tensor(half_tensor, names: [:cpus, :rows, :cols, :rgbs])
|> Nx.reverse(axes: [:cpus, :rows])
Nx.stack([half_tensor, flah_tensor]) |> Nx.reshape({half_new * 2, w, 3})
end
end
We get the expected performance boost. We are able to increase the resolution.
> For fun, we changed the colour palette.
h = 800; w = 1001
StreamMandelbrot.paral_sym( %{h: h, w: w, max_iter: 150})
|> Kino.Image.new()
ExZoom in and out
In this module, we are able to click on a point of the image to zoom in and dive into the fractal.
We pass binaries between the browser and the Livebook, which is used as a server. This works thanks to Kino.JS.Live, which supports binary data.
Given some bounds of the 2D plane within our canvas, we pass “quantized” coordinates of the 2D plane. For each (i,j) of the canvas, we compute in the browser the “real” coordinates (x,y) and pass them as a binary to the Livebook. For example, with 600 x 800 points as floats (8 bytes per float for f64, as in the code below), we pass 600x800x2x8 = 7_680_000 bytes = 7.68MB of data.
> send the binary within an ArrayBuffer container, and allocate a Float64Array view.
With Nx.from_binary(binary, :f64), we can enter the Nx world. We parallelise the computations by running Task.async when we chunk the rows, and reassemble the resulting tensors by task_id.
We then send the resulting tensor back to the browser with Nx.to_binary; the size becomes 600x800x4 = 1.92MB. This is because we are now sending u8 integers representing the R,G,B,A of each point.
We don’t want to send data as strings (data:image base 64) as this increases the weight and thus slows down the rendering.
Since the binary already holds RGBA values, we can display it in the canvas by building an ImageData.
> use the Uint8Array container and allocate a Uint8ClampedArray view. Then use the ImageData from the Canvas API and window.createImageBitmap to fill in the canvas.
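The payload sizes can be worked out in plain Elixir. This sketch follows the choices visible in the module's code, namely a Float64Array on the browser side, `:f64` in `Nx.from_binary`, and four u8 channels per pixel on the way back. The module `Payload` and its function names are hypothetical.

```elixir
defmodule Payload do
  # Hypothetical helpers following the f64 / RGBA choices of the module's code.
  # Two f64 coordinates (8 bytes each) per pixel travel from the browser.
  def coords_bytes(h, w), do: h * w * 2 * 8

  # Four u8 channels (R, G, B, A) per pixel travel back.
  def rgba_bytes(h, w), do: h * w * 4

  # Packs {x, y} points as consecutive native-endian f64 values.
  def encode(points) do
    for {x, y} <- points, into: <<>> do
      <<x::float-native-size(64), y::float-native-size(64)>>
    end
  end

  # Reads the {x, y} points back from such a binary.
  def decode(binary) do
    for <<x::float-native-size(64), y::float-native-size(64) <- binary>>, do: {x, y}
  end
end

IO.inspect({Payload.coords_bytes(600, 800), Payload.rgba_bytes(600, 800)})
IO.inspect(Payload.decode(Payload.encode([{-2.0, 1.0}, {1.0, -1.0}])))
```

The round trip shows the layout the server receives: the x and y of each pixel sit next to each other, 16 bytes per pixel.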
defmodule ExZoom do
import Nx.Defn
@h 600
@w 800
@imax 250
@zoom_ratio 5
@rows_per_task 20
# Nx computations ------------------------------------
defn compute(slice, %{max_iter: max_iter}) do
Orbit.number(slice, imax: max_iter)
|> Colour.rgb4(max_iter)
|> Nx.devectorize()
|> Nx.as_type(:u8)
end
defn to_complex(t), do: Nx.complex(t[0], t[1])
# we want to "complexify" the points thus we need to vectorize both dimensions
defn complexify(t, opts) do
h = opts[:h]
w = opts[:w]
Nx.reshape(t,{h, w, 2})
|> Nx.vectorize([:r, :c])
|> to_complex()
end
# <--------------------------------------------------
def consume(binary, %{h: h, w: w, max_iter: max_iter}) do
t = Nx.from_binary(binary, :f64) |> complexify(h: h, w: w)
# the height is truncated as you cannot stack tensor of different shapes easily
nb_tasks = div(h, @rows_per_task)
opts = %{max_iter: max_iter}
computations =
for task_id <- 0..nb_tasks-1 do
Task.async(fn ->
# we shift the start index by the number of rows already consumed
start_row = task_id * @rows_per_task
slice =
Nx.devectorize(t)
|> Nx.slice([start_row, 0], [@rows_per_task, w])
|> Nx.vectorize([:row, :col])
# apply our Nx pipeline to the slice
Nx.Defn.jit_apply(fn t ->
computed_rows = compute(t, opts)
{task_id, computed_rows}
end,
[slice]
)
end)
end
binary =
Task.await_many(computations, :infinity)
# re-order as per the task_id we returned
|> Enum.sort(fn {x,_}, {y,_} -> Nx.to_number(x)<= Nx.to_number(y) end)
|> Enum.map(&(elem(&1, 1)))
|> Nx.stack()
# send to the browser as binary
|> Nx.to_binary()
# the browser needs to know the nb of rows as it may be truncated
{nb_tasks * @rows_per_task, binary}
end
# Kino.JS.Live ------------------------------------------------------
use Kino.JS
use Kino.JS.Live
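def canvas(h, w) do
# markup reconstructed from what main.js reads: the "myCanvas" id and the canvas width/height
"""
<canvas id="myCanvas" width="#{w}" height="#{h}"></canvas>
"""
end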
def start(), do: Kino.JS.Live.new(__MODULE__, canvas(@h, @w))
@impl true
def init(html, ctx) do
ctx = assign(ctx, %{max_iter: @imax, h: @h, w: @w})
{:ok, assign(ctx, html: html)}
end
@impl true
def handle_connect(ctx) do
broadcast_event(ctx, "zoom_ratio", %{"zoom_ratio" => @zoom_ratio})
{:ok, ctx.assigns.html, ctx}
end
@impl true
def handle_event("clicked", {:binary,info, buffer}, %{assigns: assigns} = ctx) do
%{h: h, w: w, max_iter: max_iter} = assigns
# dbg(byte_size(buffer))
case info do
%{"boundary" => %{"max_x" => right, "max_y" => top, "min_x" => left, "min_y" => bottom}} ->
IO.inspect({right-left, top-bottom}, label: "zoom dim: ")
_ ->
IO.inspect({3, 1}, label: "zoom dim: ")
end
{nb_rows, binary} = consume(buffer, %{h: h, w: w, max_iter: max_iter})
broadcast_event(ctx,"new", {:binary, %{nb_rows: nb_rows}, binary})
{:noreply, ctx}
end
asset "main.js" do
"""
export function init(ctx, html) {
ctx.root.innerHTML = html;
const canvas = document.getElementById("myCanvas");
const canvasCtx = canvas.getContext("2d");
const canvasWidth = canvas.width, canvasHeight = canvas.height;
// initialize the zoom factor
let zoomRatio = undefined;
ctx.handleEvent("zoom_ratio", ({zoom_ratio}) => zoomRatio = zoom_ratio);
// initial bounds
let boundary = {min_x: -2, max_y: 1, max_x: 1, min_y: -1};
// Function to convert canvas coordinate to complex coordinate
function toCartesian(canvasX, canvasY, bounds) {
const {min_x, max_x, min_y, max_y} = bounds;
const x = min_x + (canvasX / (canvasWidth-1)) * (max_x - min_x);
const y = max_y - (canvasY / (canvasHeight-1)) * (max_y - min_y);
return { x, y };
}
function recalculateBounds(canvasX, canvasY, bounds, options = {}) {
// Get the clicked point's corresponding 2D plane coordinates
const { x: newCenterX, y: newCenterY } = toCartesian(canvasX, canvasY, bounds)
const { zoomIn = true} = options;
const {min_x, max_x, min_y, max_y} = bounds;
// Compute the width and height of the new zoomed rectangle
const oldWidth = max_x - min_x;
const oldHeight = max_y - min_y;
const factor = zoomIn ? 1/zoomRatio : zoomRatio;
const newWidth = oldWidth * factor;
const newHeight = oldHeight * factor;
return {
min_x: newCenterX - newWidth / 2,
max_x: newCenterX + newWidth / 2,
min_y: newCenterY - newHeight / 2,
max_y: newCenterY + newHeight / 2
};
}
// received from server -----------------------------------
ctx.handleEvent("new", ([{nb_rows}, binary]) => {
//const len = binary.byteLength;
const imageData = new ImageData(
new Uint8ClampedArray(binary),
canvasWidth,
nb_rows
);
// we send RGBA, 4 bytes so we can immediately render
// to canvas with "createImageBitmap"
createImageBitmap(imageData).then(bitmap =>
canvasCtx.drawImage(bitmap, 0, 0)
);
});
// push from browser -------------------------------------
canvas.addEventListener("click", (e) => {
//if (!e.ctrlKey) return; // <-- click + ctrl pressed to zoom it
const rect = canvas.getBoundingClientRect();
const canvasX = e.clientX - rect.left;
const canvasY = e.clientY - rect.top;
const newBounds = recalculateBounds(canvasX, canvasY, boundary, {zoomIn: !e.shiftKey});
//-- send data to the server -->
Object.assign(boundary, newBounds);
const buffer = getPoints(boundary)
ctx.pushEvent("clicked", [{boundary: newBounds}, buffer])
});
// compute all cartesians coordinates in the current canvas
function getPoints(bounds) {
const floatArray = new Float64Array(canvasWidth * canvasHeight * 2); // 2 floats per pixel, 8 bytes each (f64)
let index = 0;
for (let n = 0; n < canvasWidth * canvasHeight; n++) {
const { x, y } = toCartesian(
n % canvasWidth, // i
Math.floor(n / canvasWidth), // j
bounds
);
floatArray[n * 2] = x;
floatArray[n * 2 + 1] = y;
}
return floatArray.buffer;
}
// first render---------------
const buffer = getPoints(boundary)
ctx.pushEvent("clicked", [{}, buffer])
}
"""
end
end
You can CLICK to zoom into the point of your choice. The zoom factor is 5 (you can change it via the module attribute @zoom_ratio).
You can zoom out with SHIFT+CLICK.
ExZoom.start()
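The zoom step performed by recalculateBounds in main.js amounts to a few lines of arithmetic. Here is a transcription in Elixir for readers who prefer to follow it outside the browser; the module `ZoomBounds` and its argument layout are hypothetical and are not used by ExZoom.

```elixir
defmodule ZoomBounds do
  # Hypothetical Elixir transcription of the JavaScript zoom step.
  # bounds is %{min_x:, max_x:, min_y:, max_y:}; {cx, cy} is the clicked pixel.
  def recalculate({cx, cy}, {canvas_w, canvas_h}, bounds, ratio, zoom_in? \\ true) do
    %{min_x: min_x, max_x: max_x, min_y: min_y, max_y: max_y} = bounds
    width = max_x - min_x
    height = max_y - min_y

    # clicked pixel -> point of the plane (the canvas y axis grows downwards)
    x = min_x + cx / (canvas_w - 1) * width
    y = max_y - cy / (canvas_h - 1) * height

    # shrink the rectangle when zooming in, enlarge it when zooming out
    factor = if zoom_in?, do: 1 / ratio, else: ratio
    new_w = width * factor
    new_h = height * factor

    %{
      min_x: x - new_w / 2,
      max_x: x + new_w / 2,
      min_y: y - new_h / 2,
      max_y: y + new_h / 2
    }
  end
end

bounds = %{min_x: -2, max_x: 1, min_y: -1, max_y: 1}
IO.inspect(ZoomBounds.recalculate({400, 300}, {801, 601}, bounds, 5))
```

The new rectangle is centred on the clicked point and keeps the aspect ratio of the old one, since width and height are scaled by the same factor.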
Zoom in with Zig
We can gain extra speed by using Zig code in Elixir within a Livebook.
The Livebook orchestrates between the client Javascript and the Zig snippet.
Zigler offers remarkable documentation.
!! You may need to have Zig installed on your machine.
In the Livebook, we add the dependencies (in the first cell):
Mix.install([{:zigler, "~> 0.13.3"},{:zig_get, "~> 0.13.1"},])
With Zigler, you can inline Zig code as below.
The code below runs the same algorithm and spawns OS threads (one per band of rows) for concurrency.
> We use the beam memory allocator from the library.
We return the data from Zig to Elixir as a binary, which is then passed to Javascript.
To get extra speed and less bandwidth consumption, the Livebook passes it to the browser, where it is wrapped in an ImageData container and rendered into the canvas via window.createImageBitmap.
You can CLICK into the image and explore with the power of Zig (SHIFT+CLICK to zoom out). We use a factor of 5.
We pass the boundary of the canvas (in the 2D-plane coordinates) to Zig.
The Zig NIF computes by spawning OS threads (as many as the number of CPU cores).
It returns the RGBA values for each point of the canvas as a binary.
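The way the rows are shared between threads in the Zig code below can be sketched in Elixir: every thread gets the integer quotient of rows by threads, and the last one also takes the remainder. The module `Bands` is a hypothetical illustration of that split, not part of the notebook.

```elixir
defmodule Bands do
  # Hypothetical sketch of the row split: each thread gets div(rows, cpus)
  # rows and the last thread also takes the remaining rows.
  # Returns a list of {start_row, row_count} tuples, one per thread.
  def split(rows, cpus) do
    base = div(rows, cpus)
    extra = rem(rows, cpus)

    for id <- 0..(cpus - 1) do
      count = if id == cpus - 1, do: base + extra, else: base
      {id * base, count}
    end
  end
end

IO.inspect(Bands.split(10, 4))
```

With 10 rows and 4 threads the bands start at rows 0, 2, 4 and 6, and the last band holds 4 rows. The bands never overlap, which is why each thread can write into the shared output buffer without locking.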
defmodule ZigZoom do
use Zig, otp_app: :zigler,
# nifs: [..., generate: [:threaded]]
release_mode: :fast
~Z"""
const beam = @import("beam");
const std = @import("std");
const Point = struct { x: f64, y: f64 };
const point_size = @sizeOf(Point);
const Bounds = struct { topLeft: Point, bottomRight: Point, cols: usize, rows: usize };
/// nif: generate/5 Threaded
pub fn generate(
rows: usize,
cols: usize,
imax: usize,
topLeft: [2]f64,
bottomRight: [2]f64
) !beam.term {
const colours = try beam.allocator.alloc(u8, rows * cols * 4);
defer beam.allocator.free(colours);
const cpus = try std.Thread.getCpuCount();
var threads = try beam.allocator.alloc(std.Thread, cpus);
defer beam.allocator.free(threads);
const base_rows_per_thread = rows / cpus;
const remainder_rows = rows % cpus;
const bounds = Bounds{
.topLeft = Point{ .x = topLeft[0], .y = topLeft[1] },
.bottomRight = Point{ .x = bottomRight[0], .y = bottomRight[1] },
.cols = cols,
.rows = rows,
};
for (0..cpus) |thread_id| {
const start_row = thread_id * base_rows_per_thread;
const extra_rows = if (thread_id == cpus - 1) remainder_rows else 0;
const rows_per_thread = base_rows_per_thread + extra_rows;
const args = .{colours, start_row, rows_per_thread, imax, bounds};
threads[thread_id] = try std.Thread.spawn(.{}, processThread, args);
}
for (threads[0..cpus]) |thread| {
thread.join();
}
return beam.make(colours, .{ .as = .binary });
}
// mutate the "colours" slice
fn processThread(
colours: []u8,
start_row: usize,
rows_per_thread: usize,
imax: usize,
bounds: Bounds
) void {
var idx: usize = start_row;
const end = start_row + rows_per_thread;
while (idx < end) : (idx += 1) {
for (0..bounds.cols) |j| {
const point = mapPixel(.{ idx, j }, bounds);
const iterNumber = iterationNumber(point, imax);
const colour = createRgba(iterNumber, imax);
const colour_idx = (idx * bounds.cols + j) * 4;
@memcpy(colours[colour_idx..colour_idx + 4], &colour);
}
}
}
fn iterationNumber(p: Point, imax: usize) ?usize {
if (p.x > 0.6 or p.x < -2.1) return null;
if (p.y > 1.2 or p.y < -1.2) return null;
// period-2 bulb: disk of radius 1/4 centred at -1
if ((p.x + 1) * (p.x + 1) + p.y * p.y < 0.0625) return null;
var x2: f64 = 0;
var y2: f64 = 0;
var w: f64 = 0;
for (0..imax) |j| {
if (x2 + y2 > 4) return j;
const x: f64 = x2 - y2 + p.x;
const y: f64 = w - x2 - y2 + p.y;
x2 = x * x;
y2 = y * y;
w = (x + y) * (x + y);
}
return null;
}
fn createRgba(iter: ?usize, imax: usize) [4]u8 {
// If it didn't escape, return black
if (iter == null) return [_]u8{ 0, 0, 0, 255 };
if (iter.? < imax and iter.? > 0) {
const i = iter.? % 16;
return switch (i) {
0 => [_]u8{ 66, 30, 15, 255 },
1 => [_]u8{ 25, 7, 26, 255 },
2 => [_]u8{ 9, 1, 47, 255 },
3 => [_]u8{ 4, 4, 73, 255 },
4 => [_]u8{ 0, 7, 100, 255 },
5 => [_]u8{ 12, 44, 138, 255 },
6 => [_]u8{ 24, 82, 177, 255 },
7 => [_]u8{ 57, 125, 209, 255 },
8 => [_]u8{ 134, 181, 229, 255 },
9 => [_]u8{ 211, 236, 248, 255 },
10 => [_]u8{ 241, 233, 191, 255 },
11 => [_]u8{ 248, 201, 95, 255 },
12 => [_]u8{ 255, 170, 0, 255 },
13 => [_]u8{ 204, 128, 0, 255 },
14 => [_]u8{ 153, 87, 0, 255 },
15 => [_]u8{ 106, 52, 3, 255 },
else => [_]u8{ 0, 0, 0, 255 },
};
}
return [_]u8{ 0, 0, 0, 255 };
}
// canvas-(i,j) are rows,cols and become (x,y) 2D-coordinates within the bounds
fn mapPixel(pixel: [2]usize, ctx: Bounds) Point {
const x: f64 = ctx.topLeft.x + @as(f64, @floatFromInt(pixel[1])) / @as(f64, @floatFromInt(ctx.cols)) * (ctx.bottomRight.x - ctx.topLeft.x);
const y: f64 = ctx.topLeft.y - @as(f64, @floatFromInt(pixel[0])) / @as(f64, @floatFromInt(ctx.rows)) * (ctx.topLeft.y - ctx.bottomRight.y);
return Point{ .x = x, .y = y };
}
"""
use Kino.JS
use Kino.JS.Live
@h 1000
@w 1200
@imax 500
@factor 5
# pass the zoom factor via a dataset this time
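def canvas(h, w, factor) do
# markup reconstructed from what main.js reads: the "myCanvas" id, the canvas size and dataset.zoom
"""
<canvas id="myCanvas" width="#{w}" height="#{h}" data-zoom="#{factor}"></canvas>
"""
end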
def start(), do: Kino.JS.Live.new(__MODULE__, canvas(@h, @w, @factor))
@impl true
def init(html, ctx) do
ctx = assign(ctx, %{max_iter: @imax, h: @h, w: @w})
{:ok, assign(ctx, html: html)}
end
@impl true
def handle_connect(ctx) do
{:ok, ctx.assigns.html, ctx}
end
@impl true
def handle_event("clicked", %{"boundary" => %{"tl" => tl, "br"=> br}}, ctx) do
%{assigns: %{h: h, w: w, max_iter: max_iter}} = ctx
IO.inspect(calc_dims(tl,br), label: "dims: ")
colours = ZigZoom.generate(h, w, max_iter, tl, br)
broadcast_event(ctx,"new", {:binary, %{}, colours})
{:noreply, ctx}
end
defp calc_dims(tl, br) do
{Enum.at(br,0)-Enum.at(tl,0), Enum.at(tl,1)-Enum.at(br,1)}
end
asset "main.js" do
"""
export function init(ctx, html) {
ctx.root.innerHTML = html;
const canvas = document.getElementById("myCanvas");
const canvasCtx = canvas.getContext("2d");
const canvasWidth = canvas.width, canvasHeight = canvas.height;
const zoomRatio = canvas.dataset.zoom;
// initial bounds
let boundary = {tl: [-2.01, 1.01], br: [1.01,-1.01]}
// Calculate new bounds from clicked points with zoom factor
function recalculateBounds(canvasX, canvasY, bounds, options = {}) {
const {tl, br} = bounds;
const { zoomIn = true} = options;
const oldWidth = br[0] - tl[0];
const oldHeight = tl[1] - br[1];
const new_center_x = tl[0] + (canvasX / (canvasWidth-1)) * (oldWidth);
const new_center_y = tl[1] - (canvasY / (canvasHeight-1)) * (oldHeight);
const factor = zoomIn ? 1/zoomRatio : zoomRatio;
const newWidth = oldWidth * factor;
const newHeight = oldHeight * factor;
const new_boundary = {
tl: [new_center_x - newWidth/2, new_center_y + newHeight/2],
br: [new_center_x + newWidth/2, new_center_y - newHeight/2]
}
Object.assign(boundary, new_boundary);
return new_boundary;
}
// received from server -----------------------------------
ctx.handleEvent("new", ([_, binary]) => {
const imageData = new ImageData(
new Uint8ClampedArray(binary),
canvasWidth,
canvasHeight
);
// Render to canvas
canvasCtx.clearRect(0,0,canvasWidth, canvasHeight);
createImageBitmap(imageData).then(bitmap =>
canvasCtx.drawImage(bitmap, 0, 0)
);
});
// push from browser -------------------------------------
canvas.addEventListener("click", (e) => {
//if (!e.ctrlKey) return; // <-- click + ctrl pressed to zoom it
const rect = canvas.getBoundingClientRect();
console.log(rect)
const canvasX = e.clientX - rect.left;
const canvasY = e.clientY - rect.top;
console.log(e.clientX, e.clientY, rect.left, rect.top, canvasX, canvasY)
const newBounds = recalculateBounds(canvasX, canvasY, boundary, {zoomIn: !e.shiftKey});
//-- send data to the server -->
ctx.pushEvent("clicked", {boundary: newBounds})
});
// first render---------------
ctx.pushEvent("clicked", {boundary})
}
"""
end
end
ZigZoom.start()
Static image with Zig and Nx and Kino
The image is still calculated by Zig but we use Nx to consume the binary data and Kino to render it (as a B64 string, less performant).
defmodule Zigit do
use Zig, otp_app: :zigler,
#nifs: [..., generate_mandelbrot: [:threaded]],
release_mode: :fast
~Z"""
const beam = @import("beam");
const std = @import("std");
const Cx = std.math.Complex(f64);
const topLeft = Cx{ .re = -2.1, .im = 1.2 };
const bottomRight = Cx{ .re = 0.6, .im = -1.2 };
const w = bottomRight.re - topLeft.re;
const h = bottomRight.im - topLeft.im;
const Context = struct {res_x: usize, res_y: usize, imax: usize};
/// nif: generate_mandelbrot/3 Threaded
pub fn generate_mandelbrot(res_x: usize, res_y: usize, max_iter: usize) !beam.term {
const pixels = try beam.allocator.alloc(u8, res_x * res_y * 3);
defer beam.allocator.free(pixels);
const resolution = Context{ .res_x = res_x, .res_y = res_y, .imax = max_iter };
const res = try createBands(pixels, resolution);
return beam.make(res, .{ .as = .binary });
}
// <--- threaded version
fn createBands(pixels: []u8, ctx: Context) ![]u8 {
const cpus = try std.Thread.getCpuCount();
var threads = try beam.allocator.alloc(std.Thread, cpus);
defer beam.allocator.free(threads);
// half of the total rows
const rows_to_process = ctx.res_y / 2 + ctx.res_y % 2;
// one band is one count of cpus
// const nb_rows_per_band = rows_to_process / cpus + rows_to_process % cpus;
const rows_per_band = (rows_to_process + cpus - 1) / cpus;
for (0..cpus) |cpu_count| {
const start_row = cpu_count * rows_per_band;
// Stop if there are no rows to process
if (start_row >= rows_to_process) break;
const end_row = @min(start_row + rows_per_band, rows_to_process);
const args = .{ ctx, pixels, start_row, end_row };
threads[cpu_count] = try std.Thread.spawn(.{}, processRows, args);
}
for (threads[0..cpus]) |thread| {
thread.join();
}
return pixels;
}
fn processRows(ctx: Context, pixels: []u8, start_row: usize, end_row: usize) void {
for (start_row..end_row) |current_row| {
processRow(ctx, pixels, current_row);
}
}
fn processRow(ctx: Context, pixels: []u8, row_id: usize) void {
// Calculate the symmetric row
const sym_row_id = ctx.res_y - 1 - row_id;
if (row_id <= sym_row_id) {
// loop over columns
for (0..ctx.res_x) |col_id| {
const c = mapPixel(.{ @as(usize, @intCast(row_id)), @as(usize, @intCast(col_id)) }, ctx);
const iter = iterationNumber(c, ctx.imax);
const colour = createRgb(iter, ctx.imax);
const p_idx = (row_id * ctx.res_x + col_id) * 3;
// copy the 3 colour bytes (the slice end is exclusive)
@memcpy(pixels[p_idx..p_idx+3], &colour);
// Process the symmetric row (if it's different from current row)
if (row_id != sym_row_id) {
const sym_p_idx = (sym_row_id * ctx.res_x + col_id) * 3;
@memcpy(pixels[sym_p_idx..sym_p_idx+3], &colour);
}
}
}
}
fn mapPixel(pixel: [2]usize, ctx: Context) Cx {
const px_width = ctx.res_x - 1;
const px_height = ctx.res_y - 1;
const scale_x = w / @as(f64, @floatFromInt(px_width));
const scale_y = h / @as(f64, @floatFromInt(px_height));
const re = topLeft.re + scale_x * @as(f64, @floatFromInt(pixel[1]));
const im = topLeft.im + scale_y * @as(f64, @floatFromInt(pixel[0]));
return Cx{ .re = re, .im = im };
}
fn iterationNumber(c: Cx, imax: usize) ?usize {
if (c.re > 0.6 or c.re < -2.1) return 0;
if (c.im > 1.2 or c.im < -1.2) return 0;
// period-2 bulb: disk of radius 1/4 centred at -1
if ((c.re + 1) * (c.re + 1) + c.im * c.im < 0.0625) return null;
var z = Cx{ .re = 0.0, .im = 0.0 };
for (0..imax) |j| {
if (sqnorm(z) > 4) return j;
z = Cx.mul(z, z).add(c);
}
return null;
}
fn sqnorm(z: Cx) f64 {
return z.re * z.re + z.im * z.im;
}
fn createRgb(iter: ?usize, imax: usize) [3]u8 {
// If it didn't escape, return black
if (iter == null) return [_]u8{ 0, 0, 0 };
// Normalize time to [0,1[ now that we know it isn't "null"
const normalized = @as(f64, @floatFromInt(iter.?)) / @as(f64, @floatFromInt(imax));
if (normalized < 0.5) {
const scaled = normalized * 2;
return [_]u8{ @as(u8, @intFromFloat(255 * (1 - scaled))), @as(u8, @intFromFloat(255.0 * (1 - scaled / 2))), @as(u8, @intFromFloat(127 + 128 * scaled)) };
} else {
const scaled = (normalized - 0.5) * 2.0;
return [_]u8{ 0, @as(u8, @intFromFloat(127 * (1 - scaled / 2))), @as(u8, @intFromFloat(255 * (1 - scaled))) };
}
}
"""
end
We run the Zig code. It returns a binary that we are able to consume with Nx and display the image.
To draw an image of 1.8M pixels (1200 x 1500), it takes a few milliseconds. Although the image is written as Base64, it is still fast.
h = 1200; w = 1500
max_iter = 300;
Zigit.generate_mandelbrot(w, h, max_iter)
|> Nx.from_binary(:u8)
|> Nx.reshape({h, w, 3})
|> Kino.Image.new()
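The reshape to {h, w, 3} relies on the layout of the binary: rows are stored one after another, with 3 bytes per pixel. This can be illustrated in plain Elixir on a tiny binary; the module `FlatImage` and its functions are hypothetical names for this sketch.

```elixir
defmodule FlatImage do
  # Hypothetical helper: splits a flat RGB binary (3 bytes per pixel,
  # rows stored one after another) into h rows of w {r, g, b} tuples.
  def rows(binary, h, w) when byte_size(binary) == h * w * 3 do
    pixels = for <<r, g, b <- binary>>, do: {r, g, b}
    Enum.chunk_every(pixels, w)
  end

  # Byte offset of pixel {row, col}: (row * width + col) * 3.
  def offset(row, col, w), do: (row * w + col) * 3
end

# a 2 x 2 image: 4 pixels, 12 bytes
binary = <<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>>
IO.inspect(FlatImage.rows(binary, 2, 2))
IO.inspect(FlatImage.offset(1, 0, 2))
```

The offset formula is the same index computation the Zig code uses when it writes a colour into the pixel buffer, which is why the binary can be reshaped without any reordering.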