Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Hierarchical clustering

notebooks/hierarchical_clustering.livemd

Hierarchical clustering

app_root = Path.join(__DIR__, "..")

# Install Scholar from this repository, with Nx and EXLA taken from the Nx
# GitHub repository. EXLA is listed explicitly because the Setup cell configures
# it as the default backend and compiler; listing it here makes it available
# even if Scholar only declares EXLA as an optional dependency
# (NOTE(review): confirm against Scholar's mix.exs).
Mix.install(
  [
    {:kino, "~> 0.10.0"},
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
    {:scholar, path: app_root}
  ],
  config_path: Path.join(app_root, "config/config.exs"),
  lockfile: Path.join(app_root, "mix.lock")
)

Setup

Let’s configure EXLA as our default backend (where our tensors are stored) and as our default compiler (which compiles Scholar's numerical code) for the whole notebook, including all branched sections:

# Store tensors on the EXLA backend and compile numerical definitions with EXLA,
# globally, so every cell (including branched sections) picks these defaults up.
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options([compiler: EXLA])

Introduction

defmodule Scholar.Kino.CanvasDendrogram do
  @moduledoc """
  A Kino.JS output that draws a dendrogram on an HTML canvas.

  `new/1` takes a map with two keys:

    * `:dendrogram` - a JSON-serializable map with `clades` (a list of
      `[a, b]` pairs of merged node indices, where leaves are numbered
      `0..num_points-1` and the `i`-th clade becomes node `num_points + i`),
      `dissimilarities` (the merge height of each clade, one number per clade)
      and `num_points` (the number of leaves)

    * `:canvas` - a map with the canvas `width` and `height`
  """
  use Kino.JS

  @doc """
  Builds the Kino.JS output for `graph`; see the module documentation for its shape.
  """
  def new(graph), do: Kino.JS.new(__MODULE__, graph)

  asset "main.js" do
    """
    // Returns a linear map from `from_range` onto `to_range`. A degenerate
    // source range (min === max, e.g. a single leaf or every merge at height 0)
    // maps every input to the middle of `to_range` instead of producing NaN
    // from a division by zero.
    function scale_fun(from_range, to_range) {
      let [from_min, from_max] = from_range;
      let [to_min, to_max] = to_range;
      if (from_max === from_min) {
        return function(_x) {
          return (to_min + to_max) / 2;
        };
      }
      return function(x) {
        return to_min + (x - from_min) / (from_max - from_min) * (to_max - to_min);
      };
    }

    // Rendered height of `x` with the context's current font.
    function text_height(ctx, x) {
      let text_metrics = ctx.measureText(x);
      return text_metrics.actualBoundingBoxAscent + text_metrics.actualBoundingBoxDescent;
    }

    // Rendered width of `x` with the context's current font.
    function text_width(ctx, x) {
      let text_metrics = ctx.measureText(x);
      return text_metrics.actualBoundingBoxRight + text_metrics.actualBoundingBoxLeft;
    }

    // Draws the frame, integer ticks and labels, leaves and clade links of
    // `dendrogram` onto the 2D context `ctx`, laid out for a canvas of
    // `canvas_params.width` x `canvas_params.height`.
    function draw(ctx, canvas_params, dendrogram) {
      let height = canvas_params.height;
      let width = canvas_params.width;

      let clades = dendrogram.clades;
      let dissimilarities = dendrogram.dissimilarities;
      let num_leaves = dendrogram.num_points;
      // Use the largest merge height (not just the last one) so the y-axis
      // covers every clade even if heights are not sorted; with no merges at
      // all (a single point) the axis only needs to reach 0.
      let max_dissimilarity = dissimilarities.length > 0 ? Math.max(...dissimilarities) : 0;

      let x_min = 0;
      let x_max = num_leaves - 1;
      let y_min = 0;
      let y_max = max_dissimilarity;

      // One tick per leaf on x, one tick per integer height on y.
      let x_tick_labels = [...Array(x_max + 1).keys()];
      let y_tick_labels = [...Array(Math.floor(y_max) + 1).keys()];
      let x_tick_label_height = Math.max(...x_tick_labels.map((l) => text_height(ctx, l)));
      let y_tick_label_width = Math.max(...y_tick_labels.map((l) => text_width(ctx, l)));
      let x_tick_area_height = 10;
      let y_tick_area_width = 10;

      // Reserve room on the left and bottom for ticks and their labels.
      let margin = 10;
      let plot_left = margin + y_tick_label_width + y_tick_area_width;
      let plot_top = margin;
      let plot_height = height - 2*margin - x_tick_label_height - x_tick_area_height;
      let plot_width = width - 2*margin - y_tick_label_width - y_tick_area_width;

      // Pad the data range by 5% on each side so points do not sit on the frame.
      let x_range = x_max - x_min;
      let y_range = y_max - y_min;
      let data_margin = 0.1;
      let x_data_min = x_min - x_range * data_margin / 2;
      let x_data_max = x_max + x_range * data_margin / 2;
      let y_data_min = y_min - y_range * data_margin / 2;
      let y_data_max = y_max + y_range * data_margin / 2;
      let scale_x = scale_fun([x_data_min, x_data_max], [plot_left, plot_left + plot_width]);
      // Canvas y grows downwards, so larger heights map to smaller pixel rows.
      let scale_y = scale_fun([y_data_min, y_data_max], [plot_top + plot_height, plot_top]);

      // Axes (a frame around the plot area)
      ctx.beginPath();
      ctx.moveTo(plot_left, plot_top);
      ctx.lineTo(plot_left, plot_top + plot_height);
      ctx.lineTo(plot_left + plot_width, plot_top + plot_height);
      ctx.lineTo(plot_left + plot_width, plot_top);
      ctx.lineTo(plot_left, plot_top);
      ctx.stroke();
      ctx.closePath();

      // x-ticks
      for (let x of x_tick_labels) {
        ctx.beginPath();
        ctx.moveTo(scale_x(x), scale_y(y_data_min));
        ctx.lineTo(scale_x(x), scale_y(y_data_min) + x_tick_area_height / 2);
        ctx.stroke();
        ctx.closePath();
      }

      // y-ticks
      for (let y of y_tick_labels) {
        ctx.beginPath();
        ctx.moveTo(scale_x(x_data_min), scale_y(y));
        ctx.lineTo(scale_x(x_data_min) - y_tick_area_width / 2, scale_y(y));
        ctx.stroke();
        ctx.closePath();
      }

      // x-tick labels
      ctx.textAlign = "center";
      for (let x of x_tick_labels) {
        ctx.strokeText(x, scale_x(x), scale_y(y_data_min) + x_tick_area_height + x_tick_label_height);
      }

      // y-tick labels
      ctx.textBaseline = "middle";
      ctx.textAlign = "end";
      for (let y of y_tick_labels) {
        ctx.strokeText(y, scale_x(x_data_min) - y_tick_area_width, scale_y(y));
      }

      // Leaves, placed left to right in index order at height 0.
      // NOTE(review): if a clade joins leaves that are not adjacent in index
      // order, its links can cross other links.
      let coords = new Map();
      for (let i = 0; i < num_leaves; i++) {
        let x = scale_x(i);
        let y = scale_y(0);

        ctx.beginPath();
        ctx.arc(x, y, 5, 0, Math.PI * 2);
        ctx.fill();

        coords.set(i, [x, y]);
      }

      // Clades: the i-th clade joins nodes a and b at height d and becomes
      // node num_leaves + i, drawn midway between its children.
      for (let i = 0; i < clades.length; i++) {
        let [a, b] = clades[i];
        let c = i + num_leaves;
        let d = dissimilarities[i];

        let [ax, ay] = coords.get(a);
        let [bx, by] = coords.get(b);
        let cx = (ax + bx) / 2;
        let cy = scale_y(d);

        // U-shaped link: up from a, across, and down to b.
        ctx.beginPath();
        ctx.moveTo(ax, ay);
        ctx.lineTo(ax, cy);
        ctx.lineTo(bx, cy);
        ctx.lineTo(bx, by);
        ctx.stroke();
        ctx.closePath();

        ctx.beginPath();
        ctx.arc(cx, cy, 5, 0, Math.PI * 2);
        ctx.fill();

        // Merged children are no longer roots; the new clade replaces them.
        coords.delete(a);
        coords.delete(b);
        coords.set(c, [cx, cy]);
      }
    }

    export function init(ctx, input) {
      let dendrogram = input.dendrogram;
      let canvas_params = input.canvas;
      let canvas_el_id = "dendrogram-plot";

      // Create the canvas to draw on, sized from the caller's parameters.
      // (Previously this template was empty, so no canvas existed and the
      // lookup below returned null.)
      ctx.root.innerHTML =
        `<canvas id="${canvas_el_id}" width="${canvas_params.width}" height="${canvas_params.height}"></canvas>`;

      // Look the canvas up within this output's root element.
      let canvas_el = ctx.root.querySelector("canvas");

      // Check for canvas support
      if (canvas_el && canvas_el.getContext) {
        let canvas_ctx = canvas_el.getContext("2d");
        draw(canvas_ctx, canvas_params, dendrogram);
      }
    }
    """
  end
end
# 5 | 0 1   3 4
# 4 | 2       5
# 3 |
# 2 | 6
# 1 | 7 8
# 0 +-+-+-+-+-+
#   0 1 2 3 4 5

# Tensor form of the data sketched above
data = Nx.tensor([[1, 5], [2, 5], [1, 4], [4, 5], [5, 5], [5, 4], [1, 2], [1, 1], [2, 1]])

# Build model from data
model = Scholar.Cluster.Hierarchical.fit(data, dissimilarity: :euclidean, linkage: :average)

# Make a JSON-serializable "dendrogram": every tensor field of the model struct
# becomes a plain (nested) list; all other fields pass through untouched.
dendrogram =
  for {field, value} <- Map.from_struct(model), into: %{} do
    case value do
      %Nx.Tensor{} -> {field, Nx.to_list(value)}
      _other -> {field, value}
    end
  end

# Plot
Scholar.Kino.CanvasDendrogram.new(%{
  dendrogram: dendrogram,
  canvas: %{width: 400, height: 400}
})