Powered by AppSignal & Oban Pro

Mandelbrot orbits with Nx and Kino.JS.Live

livebook/orbits.livemd

Mandelbrot orbits with Nx and Kino.JS.Live

# Install the notebook's dependencies and make EXLA the default Nx backend.
# NOTE(review): :kino is not listed explicitly although the notebook uses
# Kino.JS / Kino.JS.Live; presumably it is pulled in transitively by
# :kino_vega_lite — confirm before removing that dependency.
Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:exla, "~> 0.9.1"},
    {:kino_vega_lite, "~> 0.1.11"},
  ],
  # Tensors are allocated on EXLA.Backend unless a backend is given explicitly.
  config: [nx: [default_backend: EXLA.Backend]]
)
# Nx.Defn.global_default_options(compiler: EXLA, client: :host)

Orbit number

We return a tensor containing the successive iterates (points) of a given point $c$ under the map:

$p: z \mapsto z^2 +c$

starting at $p(0)=c$.

We allocate a tensor $t_0$ of length $n$ (the maximum number of iterates we compute), initialised as $t_0=[c,0,\dots]$.

At each step:

  • while the squared norm of the iterate is less than 4 and max iterates is not reached, we add the iterate to the tensor,
  • else we stop.
defmodule Orbit do
  @moduledoc """
  Computes the orbit of a complex point `c` under the quadratic map
  `z -> z * z + c`, as numerical definitions (`defn`).
  """

  import Nx.Defn

  @doc "One step of the quadratic map: returns `z * z + c`."
  defn p(z, c), do: z * z + c

  @doc "Squared modulus of `z`, computed as the real part of `z * conj(z)`."
  defn sq_norm(z), do: Nx.real(z * Nx.conjugate(z))

  @doc """
  Iterates `p/2` starting from `c` and collects the iterates in a tensor.

  ## Options

    * `:n` - length of the orbit tensor, i.e. the maximum number of points
      stored (must be a positive integer).

  ## Returns

  A tuple `{count, t, c, n}` where:

    * `t` is a tensor of length `n` with `t[0] = c` and `t[k + 1] = p(t[k], c)`
      for every step performed; slots that were never reached stay `0`;
    * `count` is the index of the first stored iterate whose squared norm is
      not below `4`, or `n` when every stored iterate stays below `4`;
    * `c` and `n` are passed through unchanged.
  """
  defn calc(c, opts) do
    n = opts[:n]

    t0 = Nx.broadcast(0, {n}) |> Nx.put_slice([0], Nx.reshape(c, {1}))

    # Stop at the last slot (i == n - 1): writing t[i + 1] from there would
    # scatter one element past the end of the tensor.
    {i, t, c, n} =
      while {i = 0, t = t0, c, n}, i < n - 1 and sq_norm(t[i]) < 4 do
        t_next = Nx.indexed_put(t, Nx.stack([i + 1]), p(t[i], c))
        {i + 1, t_next, c, n}
      end

    # Keep the historical contract: when the loop ran to the last slot and
    # that iterate is still bounded, report `n` (no escaping iterate found).
    still_bounded = i == n - 1 and sq_norm(t[i]) < 4

    {i + still_bounded, t, c, n}
  end
end

Plotting orbits

Given a point $c$, we return a tuple `{iteration_number, [[x_0, y_0], [x_1, y_1], ...]}` that Kino.JS.Live sends to the JavaScript client, which plots it.

defmodule Plot do
  @moduledoc """
  Turns the orbit computed by `Orbit.calc/2` into plain lists that can be
  sent to the JavaScript client.
  """

  import Nx.Defn
  import Nx.Constants, only: [i: 0]

  @doc """
  Builds the complex point `cx + i * cy` and computes its orbit.

  `opts` is forwarded to `Orbit.calc/2`. Returns `{count, t}`, the first two
  elements of the tuple returned by `Orbit.calc/2`.
  """
  defn points(cx, cy, opts) do
    c = cx + i() * cy
    {t_nb, t, _, _} = Orbit.calc(c, opts)
    {t_nb, t}
  end

  @doc """
  Computes the orbit of `cx + i * cy` with at most `imax` stored points.

  Returns `{nb, data}` where `data` is the list of `[re, im]` pairs of the
  first `nb + 1` orbit points and `nb` is the count reported by `points/3`
  minus one. When the count is `0`, `data` is `[]` and `nb` is `-1`.
  """
  def to_js(cx, cy, imax) do
    {t_nb, t} = points(cx, cy, n: imax)

    count = Nx.to_number(t_nb)

    # Take the prefix from the list rather than slicing with `0..(count - 1)`:
    # for count == 0 that range would be `0..-1`, which is a descending range
    # and not an empty selection.
    data =
      t
      |> Nx.to_list()
      |> Enum.take(count)
      |> Enum.map(fn z -> [Complex.real(z), Complex.imag(z)] end)

    {count - 1, data}
  end
end

The module below is a Kino.JS.Live module; it lets the browser interact with the Livebook runtime, which acts as the server.

When you click on the 2D plane, it will plot the orbit of this point.

More precisely, when you click on the plane:

  • the browser sends the coordinates to the Kino.JS.Live server (after converting canvas coordinates to coordinates in the 2D plane),
  • the Kino.JS.Live server calls Plot.to_js, which calculates the orbit (via Orbit.calc),
  • then Kino.JS.Live sends the result to the browser, which plots it with a little animation.

The client-server communication uses the primitives handle_event and broadcast_event on the server side, and pushEvent and handleEvent on the client side.

# Kino.JS.Live widget that shows a canvas of the complex plane.
# A click on the canvas is pushed to the server as a "clicked" event; the server
# computes the orbit with Plot.to_js/3 and broadcasts the result back to the
# clients as "points" and "number" events, which the JavaScript asset draws.
defmodule LiveOrbit do
  use Kino.JS
  use Kino.JS.Live

  # Returns the HTML markup for the widget: two spans (clicked coordinates and
  # iteration number) and an 800x600 canvas. The element ids are looked up by
  # the JavaScript in "main.js" below.
  def canvas() do
    """
    <p>Clicked point: <span id="coordinates"></span>, &nbsp Iteration number: <span id="number"></span></p>
    <canvas id="myCanvas" width="800" height="600"></canvas>
    """
  end

  # Creates the widget, passing the markup from canvas/0 as the init argument.
  def run(), do: Kino.JS.Live.new(__MODULE__, canvas())

  # Stores the markup and the maximum number of orbit points (imax: 150)
  # in the server-side assigns.
  @impl true
  def init(html, ctx) do
    ctx = assign(ctx, %{imax: 150})
    {:ok, assign(ctx, html: html)}
  end

  # Hands the stored markup to a connecting client; it arrives as the second
  # argument of the JavaScript `init(ctx, html)` function.
  @impl true
  def handle_connect(ctx) do
    {:ok, ctx.assigns.html, ctx}
  end

  #-------------- received from client
  # Handles a click: `x` and `y` are the complex-plane coordinates computed by
  # the client. The orbit is computed with the `imax` stored in init/2.
  @impl true
  def handle_event("clicked",  %{"x"=> x, "y" => y} = _evt, ctx) do
    %{assigns: %{imax: imax}} = ctx
    # data is a list of [re, im] pairs, nb the iteration number to display
    {nb, data} = Plot.to_js(x, y, imax)
    # Nx.to_list(data) |> dbg()
    
    # send data to the client
    broadcast_event(ctx, "points", %{data: data})
    broadcast_event(ctx, "number", %{nb: nb})
    
    {:noreply, ctx}
  end

  # Client-side code: draws the axes and grid, pushes "clicked" events on
  # canvas clicks and animates the points received in "points" events.
  # The heredoc below is the JavaScript source shipped to the browser.
  asset "main.js" do
    """
    export function init(ctx, html) {
      ctx.root.innerHTML = html;

      const canvas = document.getElementById("myCanvas");
      const canvasCtx = canvas.getContext("2d");
      const coordsElement = document.getElementById("coordinates");
      const iterationNumber = document.getElementById("number")
  
      // Canvas dimensions
      const canvasWidth = canvas.width;
      const canvasHeight = canvas.height;
  
      // Coordinate plane bounds
      const minX = -1.5;
      const maxX = 1;
      const minY = -1;
      const maxY = 1;

      // Function to convert canvas coordinate to complex coordinate
      function toComplexCoord(canvasX, canvasY) {
        const x = minX + ((canvasX / canvasWidth) * (maxX - minX));
        const y = maxY - ((canvasY / canvasHeight) * (maxY - minY));
        return { x, y };
      }
    
      // Function to convert compelx coordinate to canvas coordinate
      function toCanvasCoord(x, y) {
        const canvasX = ((x - minX) / (maxX - minX)) * canvasWidth;
        const canvasY = canvasHeight - ((y - minY) / (maxY - minY)) * canvasHeight;
        return { x: canvasX, y: canvasY };
      }
    
      // Draw x-axis
      canvasCtx.beginPath();
      canvasCtx.moveTo(0, canvasHeight * (Math.abs(maxY) / (Math.abs(maxY) + Math.abs(minY))));
      canvasCtx.lineTo(canvasWidth, canvasHeight * (Math.abs(maxY) / (Math.abs(maxY) + Math.abs(minY))));
      canvasCtx.strokeStyle = 'black';
      canvasCtx.stroke();
  
      // Draw y-axis
      canvasCtx.beginPath();
      canvasCtx.moveTo(canvasWidth * (Math.abs(minX) / (Math.abs(minX) + Math.abs(maxX))), 0);
      canvasCtx.lineTo(canvasWidth * (Math.abs(minX) / (Math.abs(minX) + Math.abs(maxX))), canvasHeight);
      canvasCtx.strokeStyle = 'black';
      canvasCtx.stroke();

      // Draw grid dynamically based on min/max X/Y
      const gridStep = 0.2;
      canvasCtx.strokeStyle = '#ccc';
      canvasCtx.setLineDash([2, 4]);
    
      for (let x = minX; x <= maxX; x += gridStep) {
        const { x: canvasX } = toCanvasCoord(x, 0);
        canvasCtx.beginPath();
        canvasCtx.moveTo(canvasX, 0);
        canvasCtx.lineTo(canvasX, canvasHeight);
        canvasCtx.stroke();
      }
      for (let y = minY; y <= maxY; y += gridStep) {
        const { y: canvasY } = toCanvasCoord(0, y);
        canvasCtx.beginPath();
        canvasCtx.moveTo(0, canvasY);
        canvasCtx.lineTo(canvasWidth, canvasY);
        canvasCtx.stroke();
      }
    
      canvasCtx.setLineDash([]);


      //---------- client event: send click position --------------------
      canvas.addEventListener("click", (e) => {
        const rect = canvas.getBoundingClientRect();
        const x = e.clientX - rect.left;
        const y = e.clientY - rect.top;
        const { x: complexX, y: complexY } = toComplexCoord(x, y);
        coordsElement.textContent = `${complexX.toFixed(2)}, ${complexY.toFixed(2)}`;
    
        //-- send data to the server -->
        ctx.pushEvent("clicked", {x: Number(complexX), y: Number(complexY)})
      });

      //----------- callbacks from server after computations--------------
      ctx.handleEvent("number", ({nb}) => {iterationNumber.textContent = nb});
      ctx.handleEvent("points", drawPointsWithAnimation)

      // Draw each point with animation
      function drawPointsWithAnimation({data: points}) {
        console.log(points)
        let index = 0;
        // loop with "setInterval"
        const intervalId = setInterval(() => {
          if (index >= points.length) {
            clearInterval(intervalId); // Stop the animation when all points are drawn
            return;
          }
          animatePoint(points, index)
          index++;
        }, 100); 
      }

      function animatePoint(points, index) {
        const currPoint = toCanvasCoord(points[index][0], points[index][1])

        // Draw the point as a small circle
        canvasCtx.beginPath();
        const radius = index === 0 ? 6 : 3;
        canvasCtx.arc(currPoint.x, currPoint.y, radius, 0, 2 * Math.PI);
        canvasCtx.fillStyle = index === 0 ? "green" : 'red';
        canvasCtx.fill();
    
        // Draw a dotted line connecting to the previous point, if exists
        if (index > 0) {
          const prevPoint = toCanvasCoord(points[index - 1][0], points[index - 1][1]);
          canvasCtx.beginPath();
          canvasCtx.moveTo(prevPoint.x, prevPoint.y);
          canvasCtx.lineTo(currPoint.x, currPoint.y);
          canvasCtx.strokeStyle = 'blue';
          canvasCtx.setLineDash([5, 5]); // Creates a dotted line
          canvasCtx.lineWidth = 1;
          canvasCtx.stroke();
          canvasCtx.setLineDash([]); // Reset line dash
        }
      }
    }
    """
  end
end

You can visualize the orbit of any point by clicking on the plot below.

# Instantiate the LiveOrbit widget; the value returned by run/0 is this cell's result.
LiveOrbit.run()