
Mandelbrot set with Elixir Nx and Zigler

Mix.install(
  [
    {:nx, "~> 0.9"},
    {:exla, "~> 0.9"},
    {:kino, "~> 0.14.2"},
    {:zigler, "~> 0.13.3"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Nx.Defn.global_default_options(compiler: EXLA, client: :host)

Warning

If the installation above fails, Zig is probably not installed on your machine. Comment out the Zigler package above. You will then not be able to run the last modules, which use inline Zig code.

:observer.start()

Table of contents

  1. Introduction
  2. Play with orbits
  3. Def and Defn
  4. Algorithm
  5. Orbit and iteration number
  6. Some colour palettes
  7. Computing Mandelbrot set
  8. Parallelise it
  9. ExZoom in and out
  10. Zoom in with Zig
  11. Static image with Zig and Nx and Kino

Introduction

We want to produce an image that represents the beautiful Mandelbrot set. With one module below, you can zoom into the fractal.

Source:

We use the Nx library with the EXLA backend to speed up the computations.

For extra speed, this Livebook also runs the equivalent code in Zig, thanks to the Zigler library. The Zig code returns a binary that Nx can consume and Kino can display.

What is a Mandelbrot set?

In a “Mandelbrot image”, each pixel has a colour representing how fast the underlying point “escapes” when calculating its iterates under a certain function.

What is an underlying point?

A pixel has some coordinates [i,j]. For example, in a 1024 × 768 image (WIDTH x HEIGHT), the row number i varies from 0 to 767 and the column number j from 0 to 1023.

We transform these pairs of integers (i,j) into points of a 2D plane. This map “quantizes” the 2D plane.

Here, the “real” 2D plane is defined by the upper left corner, say (-2,1), and the bottom right corner, say (1,-1).

How? We have a linear mapping between the pair (i,j) and a point (x,y) in the defined zone. For example, the pixel (0,0) becomes (-2,1) and the pixel (767, 1023) becomes (1,-1).
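The linear mapping can be sketched in plain Elixir, without Nx. The module name `PixelSketch` and the function `to_plane/3` are hypothetical helpers introduced only for this illustration; the notebook's own version is the `Pixel` module further down.

```elixir
defmodule PixelSketch do
  # Hypothetical helper, not part of the notebook: maps a pixel {row, col}
  # of an h x w image to the point {x, y} of the rectangle defined by its
  # top-left {x0, y0} and bottom-right {x1, y1} corners.
  def to_plane({row, col}, {h, w}, {{x0, y0}, {x1, y1}}) do
    x = x0 + col * (x1 - x0) / (w - 1)
    y = y0 + row * (y1 - y0) / (h - 1)
    {x, y}
  end
end

corners = {{-2, 1}, {1, -1}}
PixelSketch.to_plane({0, 0}, {768, 1024}, corners)      # {-2.0, 1.0}
PixelSketch.to_plane({767, 1023}, {768, 1024}, corners) # {1.0, -1.0}
```

Dividing by `w - 1` and `h - 1` makes the last column and the last row land exactly on the bottom right corner.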

What is iterating?

We will iterate the function x -> x*x + c, where c is a given number and x the variable.

We start with:

z0 = f(0) = c
z1 = f(z0) = z0 * z0 + c
z2 = f(z1) = z1 * z1 + c
...

Let’s take an example. The module below calculates the iterates x(n) = f(x(n-1)) by a simple recursion.

The set of these iterates is called the orbit of c.

defmodule Simple do
  def p(x,c), do: x**2 + c

  def iterate(0,c), do: c
  def iterate(n,c), do: p(iterate(n-1, c), c)
end

We calculate the first elements of the orbit of the point c=1 and look at how it behaves. It looks like it will diverge to infinity.

c = 1
{ c,
  Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c)
}

gives:

{1, 2, 5, 26, 677, 458330, 210066388901}

On the other hand, the point c=-1 seems well behaved: its orbit has only two values, 0 and -1, and is periodic.

c = -1

{ c,
  Simple.iterate(1,c), Simple.iterate(2,c), Simple.iterate(3,c), Simple.iterate(4,c), Simple.iterate(5,c), Simple.iterate(6,c)
}

gives:

{-1, 0, -1, 0, -1, 0, -1}

In the examples above, we took a simple “real” number.

For the Mandelbrot set, we use the complex representation of a point: (x,y) -> x + y*i where i is the imaginary number, Nx.Constants.i.

So, each pixel (i,j) is mapped to a complex number c = projection(i,j), and we want to evaluate how the iterates of c behave under this iteration, starting from 0 (so that the first iterate is c).

Iteration number?

We are interested in assigning an iteration number to each c.

The number of iterations needs to be bounded (think of a periodic orbit, which never escapes). Let max_iter be the maximum number of iterations, for example 100.

If the orbit of c remains bounded, we set its iteration number to max_iter.

If it escapes, meaning that one iterate has a norm greater than 2, its iteration number is the first index at which the norm of the iterate (its absolute value as a complex number) exceeds 2.
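The rule can be sketched in plain Elixir as follows. This is only an illustration under stated assumptions: the module `EscapeSketch` is hypothetical, a complex number is represented as a tuple `{re, im}`, and the orbit is indexed from z0 = c. The Nx version used for the image is the `Orbit` module further down, and its counting convention differs slightly.

```elixir
defmodule EscapeSketch do
  # Hypothetical plain-Elixir helper: a complex number is a tuple {re, im}.
  # Returns the index of the first iterate whose squared norm exceeds 4,
  # or max_iter if no iterate escapes within max_iter steps.
  def iteration_number(c, max_iter), do: loop(c, c, 0, max_iter)

  defp loop(_z, _c, n, max_iter) when n >= max_iter, do: max_iter

  defp loop({x, y}, {cx, cy} = c, n, max_iter) do
    if x * x + y * y > 4 do
      n
    else
      # z -> z*z + c, written with real and imaginary parts
      loop({x * x - y * y + cx, 2 * x * y + cy}, c, n + 1, max_iter)
    end
  end
end

EscapeSketch.iteration_number({1, 0}, 100)   # 2: the orbit 1, 2, 5, ... escapes at index 2
EscapeSketch.iteration_number({-1, 0}, 100)  # 100: the periodic orbit never escapes
```

Comparing the squared norm with 4 avoids computing a square root at every step.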

If we inspect the orbits, we discover nice figures. The Livebook below lets you play with the plane and displays orbits on click.

Play with orbits

Livebook to compute orbits:

Run in Livebook

Def and Defn

We will use two types of functions:

  • Elixir functions using def
  • Nx functions using defn; these use a special backend (EXLA with CPU or GPU if any). Within a numerical function, all the arguments are treated as tensors. To pass a non-tensor argument, use an opts keyword list.

The points of the 2D plane will be represented as complex numbers as the Mandelbrot map works with complex numbers.

The function z(n+1) = z(n) * z(n) + c takes a complex number and returns a complex number.

Algorithm

Source:

Input: image dimensions (eg w x h of 1500 x 1000), max iteration (eg 100)

Iterate over each pixel (i,j):

  • map it into the 2D plane: compute its “complex coordinates”
  • compute the iteration number
  • compute a colour
  • Sum-up and draw from the final tensor with Kino.

An example of a zoom into an area. You can go as deep as you want with these fractal objects.

Detail of the Mandelbrot set

Pixel to complex plane mapping

This module transforms a pair (i,j) into a complex number. It is a numerical function.

> A tensor can be of the complex type c64. A numerical Nx function natively understands complex numbers. Note that we use the built-in constant i().

defmodule Pixel do
  import Nx.Defn
  import Nx.Constants, only: [i: 0]

  defn map(index, {h,w}, {top_left_x, top_left_y, bottom_right_x,bottom_right_y}) do

    scale_x = (bottom_right_x-top_left_x) / (w-1)
    scale_y = (bottom_right_y-top_left_y) / (h-1)

    # building a complex typed tensor
    top_left_x + index[1] * scale_x + i() * (top_left_y + index[0] * scale_y)
  end
end

Orbit and iteration number

This module computes the iteration number for a given input c. It is a numerical function.

How does this work?

  • If |c|>2, then this point is unstable. Otherwise, we have to compute for each point whether it stays bounded or not.

  • If it is bounded, we get max_iter, otherwise a lower value.

We cannot use the recursive form from earlier because numerical functions don’t accept multiple function clauses as plain Elixir does. Instead we use the Nx while loop. Note how we use the Nx version of cond. In this context, true is the tensor 1.

defmodule Orbit do
  import Nx.Defn

  defn poly(z,c), do: z*z + c

  # for speed, we don't use the norm (square root involved), only the square of the norm
  defn sq_norm(z), do: Nx.real(Nx.conjugate(z) * z)

  defn number(c,opts) do
    max_iter = opts[:imax]
    #h = opts[:h]
    #w = opts[:w]
    condition = (Nx.real(c) +1) ** 2 + (Nx.imag(c)**2)
    cond do
      # points in the period-2 bulb (disk of radius 1/4 centred at -1) are all stable. Save on iterations
      Nx.less(condition, 0.0625) ->
        max_iter
      # these points are unbounded whenever the norm is > 2
      Nx.greater(sq_norm(c), 4) ->
        0
      # evaluate each point as it can be bounded or not in the disk 2
      1 ->
        
          # less efficient going up as we need to build a h x w zero tensor
          #z0 = Nx.complex(0.0, 0.0) |> Nx.broadcast({h,w}) |> Nx.vectorize([:rows, :cols])
          # {i, _, _, _} =
            #while {i=0, z=z0, c, max_iter}, Nx.less(i,max_iter) and Nx.less(sq_norm(z), 4) do
            #    {i+1, poly(z,c), c, max_iter}
            #end

          # more efficient going down as we can start at the input value "c"
          {i, _, _} =
            while {i = max_iter, z = c, c}, Nx.greater(i, 1) and Nx.less(sq_norm(z), 4) do
              {i - 1, poly(z, c), c}
            end

          # convert the countdown into an iteration count
          max_iter - i + 1
    end
  end
end




#Orbit.number(Nx.complex(0.6, 0.2), imax: 50)
#Nx.complex(0.2, 0.2) |> Nx.broadcast({400,500})

Some colour palettes

Each iteration number is an integer n. We want to associate a colour [r(n),g(n),b(n)].

This will help us to visualise which point of the complex plane is stable, and if not how fast it escapes.

The choice below is just an example. The first render uses the rgb function, the second uses rgb3, and the zoomable version uses rgb4, which is the same palette as rgb3 with an alpha channel. Other choices can be made.

defmodule Colour do
  import Nx.Defn

  defn rgb(n, max_iter) do
    n = n / max_iter
    cond do
      Nx.equal(n, 0) ->
        Nx.stack([255, 255, 0])
      Nx.less(n, 0.5) ->
        s = n * 2
        r = 255 * (1 - s)
        g = 255 * (1 - s/2)
        b = 127 + 128 * s
        Nx.stack([r, g, b])
      true ->
        s = 2 * n - 1
        r = 0
        g = 127 * (1 - s)
        b = 255 * (1 - s)
        Nx.stack([r, g, b])
    end
  end

  defn rgb3(n, max_iter) do
    v = Nx.remainder(n,16)
    cond do
      Nx.equal(n, max_iter) -> Nx.stack([0,0,0])
      Nx.equal(v,0) -> Nx.stack([66, 30, 15])
      Nx.equal(v,1) -> Nx.stack([25, 7, 26])
      Nx.equal(v,2) -> Nx.stack([9, 1, 47])
      Nx.equal(v,3) -> Nx.stack([4, 4, 73])
      Nx.equal(v,4) -> Nx.stack([0, 7, 100])
      Nx.equal(v,5) -> Nx.stack([12, 44, 138])
      Nx.equal(v,6) -> Nx.stack([24, 82, 177])
      Nx.equal(v,7) -> Nx.stack([57, 125, 209])
      Nx.equal(v,8) -> Nx.stack([134, 181, 229])
      Nx.equal(v,9) -> Nx.stack([211, 236, 248])
      Nx.equal(v,10) -> Nx.stack([241, 233, 191])
      Nx.equal(v,11) -> Nx.stack([248, 201, 95])
      Nx.equal(v,12) -> Nx.stack([255, 170, 0])
      Nx.equal(v,13) -> Nx.stack([204, 128, 0])
      Nx.equal(v,14) -> Nx.stack([153, 87, 0])
      Nx.equal(v,15) -> Nx.stack([106, 52, 3])
      1 -> Nx.stack([0,0,0])
    end
  end

  defn rgb4(n, max_iter) do
    v = Nx.remainder(n,16)
    cond do
      Nx.equal(n, max_iter) -> Nx.stack([0,0,0, 255])
      Nx.equal(v,0) -> Nx.stack([66, 30, 15, 255])
      Nx.equal(v,1) -> Nx.stack([25, 7, 26, 255])
      Nx.equal(v,2) -> Nx.stack([9, 1, 47, 255])
      Nx.equal(v,3) -> Nx.stack([4, 4, 73, 255])
      Nx.equal(v,4) -> Nx.stack([0, 7, 100, 255])
      Nx.equal(v,5) -> Nx.stack([12, 44, 138, 255])
      Nx.equal(v,6) -> Nx.stack([24, 82, 177, 255])
      Nx.equal(v,7) -> Nx.stack([57, 125, 209, 255])
      Nx.equal(v,8) -> Nx.stack([134, 181, 229, 255])
      Nx.equal(v,9) -> Nx.stack([211, 236, 248, 255])
      Nx.equal(v,10) -> Nx.stack([241, 233, 191, 255])
      Nx.equal(v,11) -> Nx.stack([248, 201, 95, 255])
      Nx.equal(v,12) -> Nx.stack([255, 170, 0, 255])
      Nx.equal(v,13) -> Nx.stack([204, 128, 0, 255])
      Nx.equal(v,14) -> Nx.stack([153, 87, 0, 255])
      Nx.equal(v,15) -> Nx.stack([106, 52, 3, 255])
      1 -> Nx.stack([0,0,0, 255])
    end
  end
end
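
The 16-colour palette of rgb3 is a lookup on the iteration number modulo 16. Outside of defn, the same idea can be sketched as a tuple lookup. The module `PaletteSketch` is a hypothetical plain-Elixir illustration: it works on integers, not tensors, so it cannot replace the numerical functions above.

```elixir
defmodule PaletteSketch do
  # Same 16 colours as Colour.rgb3, stored in a tuple for constant-time lookup.
  @palette {
    {66, 30, 15}, {25, 7, 26}, {9, 1, 47}, {4, 4, 73},
    {0, 7, 100}, {12, 44, 138}, {24, 82, 177}, {57, 125, 209},
    {134, 181, 229}, {211, 236, 248}, {241, 233, 191}, {248, 201, 95},
    {255, 170, 0}, {204, 128, 0}, {153, 87, 0}, {106, 52, 3}
  }

  # Points that never escaped (n == max_iter) are painted black.
  def colour(n, max_iter) when n == max_iter, do: {0, 0, 0}
  # Escaping points cycle through the palette every 16 iterations.
  def colour(n, _max_iter), do: elem(@palette, rem(n, 16))
end

PaletteSketch.colour(17, 100)   # {25, 7, 26}
PaletteSketch.colour(100, 100)  # {0, 0, 0}
```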

Computing the Mandelbrot set

The final module

We reassemble the tensor into the format that Kino needs to display it.

> Note that if you want to pass arguments to a defn function without having them treated as tensors, you need to use a keyword list or a map.

defmodule Mandelbrot do
  import Nx.Defn

  defn compute(opts) do
    top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = - 1.2;
    defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}

    h = opts[:h]
    w = opts[:w]
    max_iter = opts[:max_iter]

    # build the tensor [[0,0], ..., [0,m], [1,0], ..., [n,m]]. Thanks to PValente
    iota_rows = Nx.iota({h}, type: :u16) |> Nx.vectorize(:rows)
    iota_cols = Nx.iota({w}, type: :u16) |> Nx.vectorize(:cols)
    cross_product = Nx.stack([iota_rows, iota_cols])

    Pixel.map(cross_product,{h,w}, defining_points)
      |> Orbit.number(imax: max_iter) #, h: h, w: w)
      |> Colour.rgb(max_iter)
      |> Nx.devectorize()
      |> Nx.reshape({h, w, 3})
      |> Nx.as_type(:u8)
  end
end

Depending on your machine, the computation below can be lengthy. For a quick evaluation, set low values such as h = w = 300 and max_iter = 100.

h = 400; w = 502
Mandelbrot.compute(h: h, w: w, max_iter: 100)
|> Kino.Image.new()

Parallelise it

When the resolution of the image increases, it becomes worthwhile to parallelise the computations.

We divide the image into horizontal bands with a fixed number of rows. Then we spawn as many tasks as needed.

> This is only worthwhile if the image is large enough, as it comes with a non-negligible overhead.

Lastly, another possible optimisation is to notice that the full image is symmetric with respect to the horizontal axis. You can compute half of the image and stack the reversed tensor onto it. This gives an extra boost.

This is done in StreamMandelbrot.paral_sym.
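
The two ideas (bands of rows computed in concurrent tasks, then mirroring) can be sketched without Nx. Everything in `BandSketch` is a hypothetical illustration: `compute_row/2` stands in for the real per-row computation and simply returns a list of `w` copies of the row index, so that the ordering is easy to check. It assumes `h >= 2`.

```elixir
defmodule BandSketch do
  # Hypothetical stand-in for the per-row Nx computation: here a "row" is
  # just a list of w copies of its own row index.
  def compute_row(row, w), do: List.duplicate(row, w)

  # Computes the top half of the image in bands of rows_per_task rows, one
  # task per band, then mirrors it to obtain the bottom half.
  def image(h, w, rows_per_task) do
    half = div(h, 2)

    top =
      0..(half - 1)
      |> Enum.chunk_every(rows_per_task)
      |> Enum.map(fn rows ->
        Task.async(fn -> Enum.map(rows, &compute_row(&1, w)) end)
      end)
      # await_many returns the results in the order of the task list
      |> Task.await_many(:infinity)
      |> Enum.concat()

    top ++ Enum.reverse(top)
  end
end

BandSketch.image(6, 2, 2)
# [[0, 0], [1, 1], [2, 2], [2, 2], [1, 1], [0, 0]]
```

Because `Task.await_many/2` returns the results in the same order as the list of tasks, this sketch does not need to sort the bands by task id.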

defmodule StreamMandelbrot do
  import Nx.Defn


  # bottom_right_y = 0 <--- we compute only the half-plane y >= 0
  defn compute_half(cross_product, %{h: h, w: w, max_iter: max_iter}) do
    top_left_x = -2; top_left_y = 1.2; bottom_right_x = 0.6; bottom_right_y = 0;
    defining_points = {top_left_x, top_left_y, bottom_right_x, bottom_right_y}

    Pixel.map(cross_product,{h,w}, defining_points)
    |> Orbit.number(imax: max_iter)
    |> Colour.rgb3(max_iter)
    |> Nx.devectorize()
    |> Nx.as_type(:u8)
  end

  #--------------------------------------------------------------------------------
  # Fix the number of rows and spawn concurrently and compute half (for full image)
  #--------------------------------------------------------------------------------

  def paral_sym(%{h: h, w: w} = opts) do
    half_h = div(h,2)

    rows_per_task = 30
    half_new = half_h - rem(half_h, rows_per_task)
    nb_tasks = div(half_new, rows_per_task)
    opts = Map.merge(opts, %{h: half_new })

      tasks_half_tensor =
        for task_id <- 0.. nb_tasks-1 do
          Task.async(fn ->
            iota_rows =
              Nx.iota({rows_per_task}, type: :u16)
              |> Nx.add(task_id * rows_per_task)
              |> Nx.vectorize(:rows)
            # full width
            iota_cols =
              Nx.iota({w}, type: :u16)
              |> Nx.vectorize(:cols)
            cross_product = Nx.stack([iota_rows, iota_cols])

            Nx.Defn.jit_apply(fn t ->
              computed_rows = compute_half(t, opts)
              {task_id, computed_rows}
              end,
              [cross_product]
            )
          end)
        end

      half_tensor =
        Task.await_many(tasks_half_tensor, :infinity)
        |> Enum.sort(fn {x,_}, {y,_} -> Nx.to_number(x)<= Nx.to_number(y) end)
        |> Enum.map(&(elem(&1, 1)))
        |> Nx.stack()
        |> Nx.devectorize()


      flah_tensor =
        Nx.tensor(half_tensor, names: [:cpus, :rows, :cols, :rgbs])
        |> Nx.reverse(axes: [:cpus, :rows])

      Nx.stack([half_tensor, flah_tensor]) |> Nx.reshape({half_new * 2, w, 3})
  end

end

We get the expected performance boost. We are able to increase the resolution.

> For fun, we changed the colour palette.

h = 800; w = 1001

StreamMandelbrot.paral_sym( %{h: h, w: w, max_iter: 150})
|> Kino.Image.new()

ExZoom in and out

In this module, we can click on a point of the image to zoom in and dive into the fractal.

We pass binaries between the browser and the Livebook, which is used as a server. This is possible thanks to Kino.JS.Live, which supports binary data.

Given some bounds of the 2D plane within our canvas, we pass “quantized” coordinates of the 2D plane. For each (i,j) of the canvas, we compute in the browser the “real” coordinates (x,y) and pass them as a binary to the Livebook. For example, with 600 x 800 points as floats (8 bytes per float for f64), we pass 600x800x2x8 = 7_680_000 bytes = 7.68MB of data.

> Send the binary within an ArrayBuffer container, and allocate a Float64Array view.

With Nx.from_binary(binary, :f64), we can enter the Nx world. We parallelise the computations by running Task.async when we chunk the rows, and reassemble the resulting tensors by task_id.

We then send the resulting tensor back to the browser with Nx.to_binary. The size becomes 600x800x4 = 1_920_000 bytes = 1.92MB, because we are now sending u8 integers representing the R,G,B,A values of each point.

We don’t want to send the data as a string (data:image, Base64) as this increases the weight and thus slows down the rendering.

Since the binary is already RGBA, we can display it in the canvas by building an ImageData.

> Use the Uint8Array container and allocate a Uint8ClampedArray view. Then use the ImageData from the Canvas API and window.createImageBitmap to fill in the canvas.
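
The payload sizes are simple arithmetic on the canvas dimensions. The hypothetical helper below computes them for both float widths and both colour layouts, so the trade-offs can be compared.

```elixir
defmodule PayloadSketch do
  # Bytes sent from the browser: two floats (x and y) per pixel.
  def coords_bytes(h, w, bytes_per_float), do: h * w * 2 * bytes_per_float

  # Bytes sent back to the browser: one u8 per colour channel and per pixel.
  def colour_bytes(h, w, channels), do: h * w * channels
end

PayloadSketch.coords_bytes(600, 800, 4)  # f32 coordinates: 3_840_000
PayloadSketch.coords_bytes(600, 800, 8)  # f64 coordinates: 7_680_000
PayloadSketch.colour_bytes(600, 800, 3)  # RGB colours:     1_440_000
PayloadSketch.colour_bytes(600, 800, 4)  # RGBA colours:    1_920_000
```

RGBA costs one extra byte per pixel, but it is the layout that ImageData expects, so the browser does not have to convert anything before rendering.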

defmodule ExZoom do
  import Nx.Defn

  @h 600
  @w  800
  @imax 250
  @zoom_ratio 5
  @rows_per_task 20

  # Nx computations ------------------------------------
  defn compute(slice, %{max_iter: max_iter}) do
    Orbit.number(slice, imax: max_iter)
    |> Colour.rgb4(max_iter)
    |> Nx.devectorize()
    |> Nx.as_type(:u8)
  end

  defn to_complex(t), do: Nx.complex(t[0], t[1])

  # we want to "complexify" the points thus we need to vectorize both dimensions
  defn complexify(t, opts) do
    h = opts[:h]
    w = opts[:w]
    Nx.reshape(t,{h, w, 2})
    |> Nx.vectorize([:r, :c])
    |> to_complex()
  end
  # <--------------------------------------------------

  def consume(binary, %{h: h, w: w, max_iter: max_iter}) do

    t = Nx.from_binary(binary, :f64) |> complexify(h: h, w: w) 

    # the height is truncated as you cannot stack tensor of different shapes easily
    nb_tasks = div(h, @rows_per_task)
    opts =  %{max_iter: max_iter}

    computations =
      for task_id <- 0..nb_tasks-1 do
        Task.async(fn ->
          # we shift the start index by the number of rows already consumed
          start_row = task_id * @rows_per_task
          slice =
            Nx.devectorize(t)
            |> Nx.slice([start_row, 0], [@rows_per_task, w])
            |> Nx.vectorize([:row, :col])

          # apply our Nx pipeline to the slice
          Nx.Defn.jit_apply(fn t ->
            computed_rows = compute(t, opts)
            {task_id, computed_rows}
            end,
            [slice]
          )
          end)
      end

    binary =
      Task.await_many(computations, :infinity)
      # re-order as per the task_id we returned
      |> Enum.sort(fn {x,_}, {y,_} -> Nx.to_number(x)<= Nx.to_number(y) end)
      |> Enum.map(&(elem(&1, 1)))
      |> Nx.stack()
      # send to the browser as binary
      |> Nx.to_binary()
    # the browser needs to know the nb of rows as it may be truncated
    {nb_tasks * @rows_per_task, binary}
  end


  # Kino.JS.Live ------------------------------------------------------
  use Kino.JS
  use Kino.JS.Live

  def canvas(h, w) do
    """
    <canvas id="myCanvas" width="#{w}" height="#{h}"></canvas>
    """
  end

  def start(), do: Kino.JS.Live.new(__MODULE__, canvas(@h, @w))

  @impl true
  def init(html, ctx) do
    ctx = assign(ctx, %{max_iter: @imax, h: @h, w: @w})
    {:ok, assign(ctx, html: html)}
  end

  @impl true
  def handle_connect(ctx) do
    broadcast_event(ctx, "zoom_ratio", %{"zoom_ratio" => @zoom_ratio})
    {:ok, ctx.assigns.html, ctx}
  end

  @impl true
  def handle_event("clicked", {:binary,info, buffer}, %{assigns: assigns} = ctx) do
    %{h: h, w: w, max_iter: max_iter} = assigns
    # dbg(byte_size(buffer))

    case info do
      %{"boundary" => %{"max_x" => right, "max_y" => top, "min_x" => left, "min_y" => bottom}} ->
        IO.inspect({right-left, top-bottom}, label: "zoom dim: ")
      _ ->
        IO.inspect({3, 1}, label: "zoom dim: ")
    end

    {nb_rows, binary} = consume(buffer, %{h: h, w: w, max_iter: max_iter})
    broadcast_event(ctx,"new", {:binary, %{nb_rows: nb_rows}, binary})
    {:noreply, ctx}
  end


  asset "main.js" do
    """
    export function init(ctx, html) {
      ctx.root.innerHTML = html;

      const canvas = document.getElementById("myCanvas");
      const canvasCtx = canvas.getContext("2d");
      const canvasWidth = canvas.width, canvasHeight = canvas.height;

      // initialize the zoom factor
      let zoomRatio = undefined;
      ctx.handleEvent("zoom_ratio", ({zoom_ratio}) => zoomRatio = zoom_ratio);

      // initial bounds
      let boundary = {min_x: -2, max_y: 1, max_x: 1, min_y: -1};

      // Function to convert canvas coordinate to complex coordinate
      function toCartesian(canvasX, canvasY, bounds) {
        const {min_x, max_x, min_y, max_y} = bounds;
        const x = min_x + (canvasX / (canvasWidth-1)) * (max_x - min_x);
        const y = max_y - (canvasY / (canvasHeight-1)) * (max_y - min_y);
        return { x, y };
      }

      function recalculateBounds(canvasX, canvasY, bounds, options = {}) {
        // Get the clicked point's corresponding 2D plane coordinates
        const { x: newCenterX, y: newCenterY } = toCartesian(canvasX, canvasY, bounds)
        const { zoomIn = true} = options;

        const {min_x, max_x, min_y, max_y} = bounds;
        // Compute the width and height of the new zoomed rectangle
        const oldWidth = max_x - min_x;
        const oldHeight = max_y - min_y;

        const factor = zoomIn ? 1/zoomRatio : zoomRatio;
        const newWidth = oldWidth * factor;
        const newHeight = oldHeight * factor;

        return {
          min_x: newCenterX - newWidth / 2,
          max_x: newCenterX + newWidth / 2,
          min_y: newCenterY - newHeight / 2,
          max_y: newCenterY + newHeight / 2
        };
      }

      // received from server -----------------------------------
      ctx.handleEvent("new", ([{nb_rows}, binary]) => {
        //const len = binary.byteLength;
        const imageData = new ImageData(
          new Uint8ClampedArray(binary),
          canvasWidth,
          nb_rows
        );


        // we send RGBA, 4 bytes so we can immediately render
        // to canvas with "createImageBitmap"
        createImageBitmap(imageData).then(bitmap =>
          canvasCtx.drawImage(bitmap, 0, 0)
        );
      });

      // push from browser -------------------------------------
      canvas.addEventListener("click", (e) => {
        //if (!e.ctrlKey) return; // <-- click + ctrl pressed to zoom it
        const rect = canvas.getBoundingClientRect();
        const canvasX = e.clientX - rect.left;
        const canvasY = e.clientY - rect.top;
        const newBounds = recalculateBounds(canvasX, canvasY, boundary, {zoomIn: !e.shiftKey});

        //-- send data to the server -->
        Object.assign(boundary, newBounds);
        const buffer = getPoints(boundary)
        ctx.pushEvent("clicked", [{boundary: newBounds}, buffer])
      });

      // compute all cartesians coordinates in the current canvas
      function getPoints(bounds) {
        const floatArray = new Float64Array(canvasWidth * canvasHeight * 2); // 2 floats per pixel, 8 bytes each

        let index = 0;
        for (let n = 0; n < canvasWidth * canvasHeight; n++) {
          const { x, y } = toCartesian(
              n % canvasWidth,              // i
              Math.floor(n / canvasWidth),  // j
              bounds
          );
          floatArray[n * 2] = x;
          floatArray[n * 2 + 1] = y;
        }
        return floatArray.buffer;
      }

      // first render---------------
      const buffer = getPoints(boundary)
      ctx.pushEvent("clicked", [{}, buffer])
    }
    """
  end
end

You can CLICK to zoom into the point of your choice. The zoom factor is 5 (you can change it via the @zoom_ratio module attribute).

You can zoom out with SHIFT+CLICK.

ExZoom.start()
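
The zoom itself is computed in the browser by the JavaScript function recalculateBounds. For readers more at ease with Elixir, here is a hypothetical transcription of the same logic. `ZoomSketch` is not used by the notebook, and it takes the canvas size and the zoom ratio as explicit arguments.

```elixir
defmodule ZoomSketch do
  # Hypothetical Elixir transcription of the browser-side logic. Bounds are
  # a map with :min_x, :max_x, :min_y and :max_y; the canvas is h x w pixels.
  def recalculate({canvas_x, canvas_y}, bounds, {h, w}, zoom_ratio, zoom_in? \\ true) do
    %{min_x: min_x, max_x: max_x, min_y: min_y, max_y: max_y} = bounds
    width = max_x - min_x
    height = max_y - min_y

    # plane coordinates of the clicked pixel: this is the new centre
    cx = min_x + canvas_x / (w - 1) * width
    cy = max_y - canvas_y / (h - 1) * height

    # zooming in shrinks the window, zooming out enlarges it
    factor = if zoom_in?, do: 1 / zoom_ratio, else: zoom_ratio
    new_w = width * factor
    new_h = height * factor

    %{
      min_x: cx - new_w / 2,
      max_x: cx + new_w / 2,
      min_y: cy - new_h / 2,
      max_y: cy + new_h / 2
    }
  end
end

bounds = %{min_x: -2, max_x: 2, min_y: -1, max_y: 1}
# click in the middle of a 3 x 5 canvas, zoom in by a factor of 2
ZoomSketch.recalculate({2, 1}, bounds, {3, 5}, 2)
# %{min_x: -1.0, max_x: 1.0, min_y: -0.5, max_y: 0.5}
```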

Zoom in with Zig

We can provide extra speed by using Zig code in Elixir within a Livebook.

The Livebook orchestrates between the client Javascript and the Zig snippet.

Zigler has remarkable documentation.

!! You may need to have Zig installed on your machine.

In the Livebook, we add the dependencies (in the first cell):

Mix.install([{:zigler, "~> 0.13.3"}, {:zig_get, "~> 0.13.1"}])

With Zigler, you can inline Zig code as below.

The code below runs the same algorithm and spawns OS threads (one per band of rows) for concurrency.

> we use the beam memory allocator from the library.

We return the data from Zig as a binary to Elixir that will pass it to Javascript.

To get extra speed and less bandwidth consumption, the Livebook will pass it to the browser into an ImageData container and render into the canvas via window.createImageBitmap.

You can CLICK into the image and explore with the power of Zig (SHIFT+CLICK to zoom out). We use a factor of 5.

We pass the boundary of the canvas (in 2D-plane coordinates) to Zig.

The Zig NIF computes by spawning OS threads (as many as there are CPU cores).

It will return the RGBA values for each point of the canvas as a binary.
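
The way the rows are shared between the threads can be sketched in Elixir. `RowSplitSketch` is a hypothetical helper that mirrors the arithmetic of the Zig code below: every worker gets the same number of rows, and the last one also takes the remainder.

```elixir
defmodule RowSplitSketch do
  # Returns one {start_row, row_count} pair per worker. Every worker gets
  # div(rows, workers) rows; the last one also takes the remainder.
  def bands(rows, workers) do
    base = div(rows, workers)
    extra = rem(rows, workers)

    for id <- 0..(workers - 1) do
      count = if id == workers - 1, do: base + extra, else: base
      {id * base, count}
    end
  end
end

RowSplitSketch.bands(10, 4)
# [{0, 2}, {2, 2}, {4, 2}, {6, 4}]
```

With this scheme the last band can hold up to `workers - 1` extra rows, which is negligible when the image has many more rows than there are cores.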

defmodule ZigZoom do
  use Zig, otp_app: :zigler,
   # nifs: [..., generate: [:threaded]]
    release_mode: :fast


  ~Z"""
  const beam = @import("beam");
  const std = @import("std");

  const Point = struct { x: f64, y: f64 };
  const point_size = @sizeOf(Point);
  const Bounds = struct { topLeft: Point, bottomRight: Point, cols: usize, rows: usize };


  /// nif: generate/5 Threaded
  pub fn generate(
      rows: usize,
      cols: usize,
      imax: usize,
      topLeft: [2]f64,
      bottomRight: [2]f64
  ) !beam.term {
    const colours = try beam.allocator.alloc(u8, rows * cols * 4);
    defer beam.allocator.free(colours);

    const cpus = try std.Thread.getCpuCount();
    var threads = try beam.allocator.alloc(std.Thread, cpus);
    defer beam.allocator.free(threads);

    const base_rows_per_thread = rows / cpus;
    const remainder_rows = rows % cpus;

    const bounds = Bounds{
        .topLeft = Point{ .x = topLeft[0], .y = topLeft[1] },
        .bottomRight = Point{ .x = bottomRight[0], .y = bottomRight[1] },
        .cols = cols,
        .rows = rows,
    };

    for (0..cpus) |thread_id| {
      const start_row = thread_id * base_rows_per_thread;
      const extra_rows = if (thread_id == cpus - 1) remainder_rows else 0;
      const rows_per_thread = base_rows_per_thread + extra_rows;

      const args = .{colours, start_row, rows_per_thread, imax, bounds};
      threads[thread_id] = try std.Thread.spawn(.{}, processThread, args);
    }

    for (threads[0..cpus]) |thread| {
      thread.join();
    }

    return beam.make(colours, .{ .as = .binary });
  }

  // mutate the "colours" slice
  fn processThread(
      colours: []u8,
      start_row: usize,
      rows_per_thread: usize,
      imax: usize,
      bounds: Bounds
  ) void {
    var idx: usize = start_row;
    const end = start_row + rows_per_thread;

    while (idx < end) : (idx += 1) {
      for (0..bounds.cols) |j| {
        const point = mapPixel(.{ idx, j }, bounds);
        const iterNumber = iterationNumber(point, imax);
        const colour = createRgba(iterNumber, imax);
        const colour_idx = (idx * bounds.cols + j) * 4;
        @memcpy(colours[colour_idx..colour_idx + 4], &colour);
      }
    }
  }

  fn iterationNumber(p: Point, imax: usize) ?usize {
    if (p.x > 0.6 or p.x < -2.1) return null;
    if (p.y > 1.2 or p.y < -1.2) return null;
    // period-2 bulb (disk of radius 1/4 centred at -1)
    if ((p.x + 1) * (p.x + 1) + p.y * p.y < 0.0625) return null;

    var x2: f64 = 0;
    var y2: f64 = 0;
    var w: f64 = 0;

    for (0..imax) |j| {
      if (x2 + y2 > 4) return j;
      const x: f64 = x2 - y2 + p.x;
      const y: f64 = w - x2 - y2 + p.y;
      x2 = x * x;
      y2 = y * y;
      w = (x + y) * (x + y);
    }
    return null;
  }

  fn createRgba(iter: ?usize, imax: usize) [4]u8 {
    // If it didn't escape, return black
    if (iter == null) return [_]u8{ 0, 0, 0, 255 };

    if (iter.? < imax and iter.? > 0) {
      const i = iter.? % 16;
      return switch (i) {
        0 => [_]u8{ 66, 30, 15, 255 },
        1 => [_]u8{ 25, 7, 26, 255 },
        2 => [_]u8{ 9, 1, 47, 255 },
        3 => [_]u8{ 4, 4, 73, 255 },
        4 => [_]u8{ 0, 7, 100, 255 },
        5 => [_]u8{ 12, 44, 138, 255 },
        6 => [_]u8{ 24, 82, 177, 255 },
        7 => [_]u8{ 57, 125, 209, 255 },
        8 => [_]u8{ 134, 181, 229, 255 },
        9 => [_]u8{ 211, 236, 248, 255 },
        10 => [_]u8{ 241, 233, 191, 255 },
        11 => [_]u8{ 248, 201, 95, 255 },
        12 => [_]u8{ 255, 170, 0, 255 },
        13 => [_]u8{ 204, 128, 0, 255 },
        14 => [_]u8{ 153, 87, 0, 255 },
        15 => [_]u8{ 106, 52, 3, 255 },
        else => [_]u8{ 0, 0, 0, 255 },
      };
    }
    return [_]u8{ 0, 0, 0, 255 };
  }

  // canvas-(i,j) are rows,cols and become (x,y) 2D-coordinates within the bounds
  fn mapPixel(pixel: [2]usize, ctx: Bounds) Point {
    const x: f64 = ctx.topLeft.x + @as(f64, @floatFromInt(pixel[1])) / @as(f64,  @floatFromInt(ctx.cols)) * (ctx.bottomRight.x - ctx.topLeft.x);
    const y: f64 = ctx.topLeft.y - @as(f64, @floatFromInt(pixel[0])) / @as(f64,  @floatFromInt(ctx.rows)) * (ctx.topLeft.y - ctx.bottomRight.y);
    return Point{ .x = x, .y = y };
  }
  """


  use Kino.JS
  use Kino.JS.Live

  @h 1000
  @w 1200
  @imax 500
  @factor 5

  # pass the zoom factor via a dataset this time
  def canvas(h, w, factor) do
    """
    <canvas id="myCanvas" width="#{w}" height="#{h}" data-zoom="#{factor}"></canvas>
    """
  end

  def start(), do: Kino.JS.Live.new(__MODULE__, canvas(@h, @w, @factor))

  @impl true
  def init(html, ctx) do
    ctx = assign(ctx, %{max_iter: @imax, h: @h, w: @w})
    {:ok, assign(ctx, html: html)}
  end

  @impl true
  def handle_connect(ctx) do
    {:ok, ctx.assigns.html, ctx}
  end

  @impl true
  def handle_event("clicked", %{"boundary" => %{"tl" => tl, "br"=> br}}, ctx) do
    %{assigns: %{h: h, w: w, max_iter: max_iter}} = ctx
    IO.inspect(calc_dims(tl,br), label: "dims: ")
    colours = ZigZoom.generate(h, w, max_iter, tl, br)
    broadcast_event(ctx,"new", {:binary, %{}, colours})
    {:noreply, ctx}
  end

  defp calc_dims(tl, br) do
    {Enum.at(br,0)-Enum.at(tl,0), Enum.at(tl,1)-Enum.at(br,1)}
  end


  asset "main.js" do
    """
    export function init(ctx, html) {
      ctx.root.innerHTML = html;

      const canvas = document.getElementById("myCanvas");
      const canvasCtx = canvas.getContext("2d");
      const canvasWidth = canvas.width, canvasHeight = canvas.height;
      const zoomRatio = canvas.dataset.zoom;

      // initial bounds
      let boundary = {tl: [-2.01, 1.01], br: [1.01,-1.01]}

      // Calculate new bounds from the clicked point with the zoom factor
      function recalculateBounds(canvasX, canvasY, bounds, options = {}) {
        const {tl, br} = bounds;
        const { zoomIn = true} = options;
        const oldWidth = br[0] - tl[0];
        const oldHeight = tl[1] - br[1];
        const new_center_x = tl[0] + (canvasX / (canvasWidth-1)) * (oldWidth);
        const new_center_y = tl[1] - (canvasY / (canvasHeight-1)) * (oldHeight);

        const factor = zoomIn ? 1/zoomRatio : zoomRatio;
        const newWidth = oldWidth * factor;
        const newHeight = oldHeight * factor;

        const new_boundary = {
          tl: [new_center_x - newWidth/2, new_center_y + newHeight/2],
          br: [new_center_x + newWidth/2, new_center_y - newHeight/2]
        }

        Object.assign(boundary, new_boundary);
        return new_boundary;
      }

      // received from server -----------------------------------
      ctx.handleEvent("new", ([_, binary]) => {
        const imageData = new ImageData(
          new Uint8ClampedArray(binary),
          canvasWidth,
          canvasHeight
        );

        // Render to canvas
        canvasCtx.clearRect(0,0,canvasWidth, canvasHeight);
        createImageBitmap(imageData).then(bitmap =>
          canvasCtx.drawImage(bitmap, 0, 0)
        );
      });


      // push from browser -------------------------------------
      canvas.addEventListener("click", (e) => {
        //if (!e.ctrlKey) return; // <-- click + ctrl pressed to zoom it
        const rect = canvas.getBoundingClientRect();
        console.log(rect)
        const canvasX = e.clientX - rect.left;
        const canvasY = e.clientY - rect.top;
        console.log(e.clientX, e.clientY, rect.left, rect.top, canvasX, canvasY)
        const newBounds = recalculateBounds(canvasX, canvasY, boundary, {zoomIn: !e.shiftKey});

        //-- send data to the server -->
        ctx.pushEvent("clicked", {boundary: newBounds})
      });

      // first render---------------
      ctx.pushEvent("clicked", {boundary})
    }
    """
  end
end
ZigZoom.start()
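
The loop in iterationNumber above does not compute 2*x*y directly. It keeps the squares x2 and y2, which are needed anyway for the escape test, together with w = (x + y)^2, and relies on the identity (x + y)^2 - x^2 - y^2 = 2*x*y. The hypothetical module below checks on one step that both forms agree.

```elixir
defmodule StepSketch do
  # Textbook step: z -> z*z + c with z = x + i*y and c = cx + i*cy.
  def naive({x, y}, {cx, cy}), do: {x * x - y * y + cx, 2 * x * y + cy}

  # Same step computed from the squares alone, as in the Zig loop:
  # (x + y)^2 - x^2 - y^2 = 2*x*y.
  def from_squares({x, y}, {cx, cy}) do
    x2 = x * x
    y2 = y * y
    w = (x + y) * (x + y)
    {x2 - y2 + cx, w - x2 - y2 + cy}
  end
end

StepSketch.naive({3, 2}, {1, -1})         # {6, 11}
StepSketch.from_squares({3, 2}, {1, -1})  # {6, 11}
```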

Static image with Zig and Nx and Kino

The image is still calculated by Zig, but we use Nx to consume the binary data and Kino to render it (as a Base64 string, which is less performant).

defmodule Zigit do
  use Zig, otp_app: :zigler,
    #nifs: [..., generate_mandelbrot: [:threaded]],
    release_mode: :fast

  ~Z"""
    const beam = @import("beam");
    const std = @import("std");
    const Cx = std.math.Complex(f64);

    const topLeft = Cx{ .re = -2.1, .im = 1.2 };
    const bottomRight = Cx{ .re = 0.6, .im = -1.2 };
    const w = bottomRight.re - topLeft.re;
    const h = bottomRight.im - topLeft.im;

    const Context = struct {res_x: usize, res_y: usize, imax: usize};

    /// nif: generate_mandelbrot/3 Threaded
    pub fn generate_mandelbrot(res_x: usize, res_y: usize, max_iter: usize) !beam.term {
        const pixels = try beam.allocator.alloc(u8, res_x * res_y * 3);
        defer beam.allocator.free(pixels);

        const resolution = Context{ .res_x = res_x, .res_y = res_y, .imax = max_iter };

        const res = try createBands(pixels, resolution);
        return beam.make(res, .{ .as = .binary });
    }

    // <--- threaded version
    fn createBands(pixels: []u8, ctx: Context) ![]u8 {
        const cpus = try std.Thread.getCpuCount();
        const threads = try beam.allocator.alloc(std.Thread, cpus);
        defer beam.allocator.free(threads);

        // half of the total rows (plus the middle row when res_y is odd)
        const rows_to_process = ctx.res_y / 2 + ctx.res_y % 2;
        // one band per cpu, rounded up so that every row is covered
        const rows_per_band = (rows_to_process + cpus - 1) / cpus;

        // number of threads actually spawned: only these may be joined
        var spawned: usize = 0;
        for (0..cpus) |cpu_count| {
            const start_row = cpu_count * rows_per_band;

            // Stop if there are no rows to process
            if (start_row >= rows_to_process) break;

            const end_row = @min(start_row + rows_per_band, rows_to_process);
            const args = .{ ctx, pixels, start_row, end_row };
            threads[spawned] = try std.Thread.spawn(.{}, processRows, args);
            spawned += 1;
        }
        for (threads[0..spawned]) |thread| {
            thread.join();
        }

        return pixels;
    }

    fn processRows(ctx: Context, pixels: []u8, start_row: usize, end_row: usize) void {
        for (start_row..end_row) |current_row| {
            processRow(ctx, pixels, current_row);
        }
    }

    fn processRow(ctx: Context, pixels: []u8, row_id: usize) void {
        // Calculate the symmetric row
        const sym_row_id = ctx.res_y - 1 - row_id;

        if (row_id <= sym_row_id) {
            // loop over columns
            for (0..ctx.res_x) |col_id| {
                const c = mapPixel(.{ @as(usize, @intCast(row_id)), @as(usize, @intCast(col_id)) }, ctx);
                const iter = iterationNumber(c, ctx.imax);
                const colour = createRgb(iter, ctx.imax);

                const p_idx = (row_id * ctx.res_x + col_id) * 3;
                // copy the 3 RGB bytes (the slice end is exclusive)
                @memcpy(pixels[p_idx .. p_idx + 3], &colour);

                // Process the symmetric row (if it's different from current row)
                if (row_id != sym_row_id) {
                    const sym_p_idx = (sym_row_id * ctx.res_x + col_id) * 3;
                    @memcpy(pixels[sym_p_idx .. sym_p_idx + 3], &colour);
                }
            }
        }
    }

    fn mapPixel(pixel: [2]usize, ctx: Context) Cx {
        const px_width = ctx.res_x - 1;
        const px_height = ctx.res_y - 1;
        const scale_x = w / @as(f64, @floatFromInt(px_width));
        const scale_y = h / @as(f64, @floatFromInt(px_height));

        const re = topLeft.re + scale_x * @as(f64, @floatFromInt(pixel[1]));
        const im = topLeft.im + scale_y * @as(f64, @floatFromInt(pixel[0]));
        return Cx{ .re = re, .im = im };
    }

    fn iterationNumber(c: Cx, imax: usize) ?usize {
        if (c.re > 0.6 or c.re < -2.1) return 0;
        if (c.im > 1.2 or c.im < -1.2) return 0;

        // period-2 bulb: disk of radius 1/4 centred at -1, its points never escape
        if ((c.re + 1) * (c.re + 1) + c.im * c.im < 0.0625) return null;

        var z = Cx{ .re = 0.0, .im = 0.0 };
        for (0..imax) |j| {
            if (sqnorm(z) > 4) return j;
            z = Cx.mul(z, z).add(c);
        }
        return null;
    }

    fn sqnorm(z: Cx) f64 {
        return z.re * z.re + z.im * z.im;
    }

    fn createRgb(iter: ?usize, imax: usize) [3]u8 {
        // If it didn't escape, return black
        if (iter == null) return [_]u8{ 0, 0, 0 };

        // Normalize time to [0,1[ now that we know it isn't "null"
        const normalized = @as(f64, @floatFromInt(iter.?)) / @as(f64, @floatFromInt(imax));

        if (normalized < 0.5) {
            const scaled = normalized * 2;
            return [_]u8{ @as(u8, @intFromFloat(255 * (1 - scaled))), @as(u8, @intFromFloat(255.0 * (1 - scaled / 2))), @as(u8, @intFromFloat(127 + 128 * scaled)) };
        } else {
            const scaled = (normalized - 0.5) * 2.0;
            return [_]u8{ 0, @as(u8, @intFromFloat(127 * (1 - scaled / 2))), @as(u8, @intFromFloat(255 * (1 - scaled))) };
        }
    }


  """
end
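
The band arithmetic in `createBands` can be checked in isolation. The sketch below reproduces it in JavaScript (the language of the notebook's browser hook), with the CPU count passed as an argument instead of being detected. The function name `bands` is illustrative only.

```javascript
// Sketch of the row partitioning: returns [startRow, endRow) pairs, one per band.
function bands(resY, cpus) {
  // Only the top half (plus the middle row when resY is odd) is computed;
  // the bottom half is filled by mirroring each row.
  const rowsToProcess = Math.floor(resY / 2) + (resY % 2);
  // Ceiling division so that every row lands in some band.
  const rowsPerBand = Math.floor((rowsToProcess + cpus - 1) / cpus);

  const out = [];
  for (let i = 0; i < cpus; i++) {
    const start = i * rowsPerBand;
    if (start >= rowsToProcess) break; // fewer bands than CPUs
    out.push([start, Math.min(start + rowsPerBand, rowsToProcess)]);
  }
  return out;
}

console.log(bands(1200, 8).length); // 8 bands of 75 rows each
console.log(bands(10, 8));          // [[0,1],[1,2],[2,3],[3,4],[4,5]]: only 5 bands
```

With a small image and many cores, the loop stops early, so the number of bands can be smaller than the number of CPUs.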

We run the Zig code. It returns a binary that Nx consumes and that Kino displays as an image.

Drawing this image of 1.8M pixels (1500 × 1200) takes a few milliseconds. Although the image is sent as Base64, it is still fast.

h = 1200
w = 1500
max_iter = 300

Zigit.generate_mandelbrot(w, h, max_iter)
|> Nx.from_binary(:u8)
|> Nx.reshape({h, w, 3})
|> Kino.Image.new()
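
The reshape to `{h, w, 3}` works because the binary is row-major with three bytes per pixel, which is how `processRow` indexes the buffer. A small JavaScript sketch of that indexing follows; the helper names `offset` and `mirrorRow` are illustrative and not part of the notebook.

```javascript
// Byte offset of the RGB triple of pixel (row, col) in a row-major buffer.
const offset = (row, col, width) => (row * width + col) * 3;
// Row that receives the same colours, by symmetry about the real axis.
const mirrorRow = (row, height) => height - 1 - row;

const width = 1500;
const height = 1200;
const totalBytes = width * height * 3;

console.log(totalBytes);                           // 5400000
console.log(offset(1, 0, width));                  // 4500: one full row of 1500 pixels
console.log(offset(height - 1, width - 1, width)); // 5399997: the last triple
console.log(mirrorRow(0, height));                 // 1199
```

The total of `h * w * 3` bytes is exactly the number of elements `Nx.reshape/2` expects for the shape `{h, w, 3}`.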