Hierarchical clustering
# The notebook lives one directory below the project root it depends on.
app_root = Path.join(__DIR__, "..")

# Packages this notebook needs; Scholar itself is loaded from the local checkout.
notebook_deps = [
  {:kino, "~> 0.10.0"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:scholar, path: app_root}
]

# Reuse the project's own configuration and lockfile for the install.
install_opts = [
  config_path: Path.join(app_root, "config/config.exs"),
  lockfile: Path.join(app_root, "mix.lock")
]

Mix.install(notebook_deps, install_opts)
Setup
Let’s configure EXLA as our default backend (where our tensors are stored) and compiler (which compiles Scholar code) across the notebook and all branched sections:
# Make EXLA the default backend (tensor storage) for the whole notebook,
# not only for the current process.
# NOTE(review): :exla is not listed in the Mix.install dependencies visible in
# this notebook — presumably it comes in through the Scholar project; confirm.
Nx.global_default_backend(EXLA.Backend)
# Make EXLA the default compiler for numerical definitions as well.
Nx.Defn.global_default_options(compiler: EXLA)
Introduction
defmodule Scholar.Kino.CanvasDendrogram do
  @moduledoc """
  A `Kino.JS` widget that draws a dendrogram on an HTML canvas.

  The widget expects a JSON-serializable map of the form:

      %{
        dendrogram: %{clades: [[a, b], ...], dissimilarities: [d, ...], num_points: n},
        canvas: %{width: w, height: h}
      }

  where `clades` and `dissimilarities` have one entry per merge and
  `width`/`height` are the canvas size in pixels.
  """

  use Kino.JS

  @doc """
  Builds the widget from `graph`, the map described in the module documentation.
  """
  @spec new(map()) :: Kino.JS.t()
  def new(graph), do: Kino.JS.new(__MODULE__, graph)

  asset "main.js" do
    """
    // Returns a function mapping values linearly from `from_range` onto `to_range`.
    function scale_fun(from_range, to_range) {
      let [from_min, from_max] = from_range;
      let [to_min, to_max] = to_range;

      return function(x) {
        return to_min + (x - from_min) / (from_max - from_min) * (to_max - to_min)
      }
    }

    // Rendered height of the text `x` with the context's current font.
    function text_height(ctx, x) {
      let text_metrics = ctx.measureText(x);
      return text_metrics.actualBoundingBoxAscent + text_metrics.actualBoundingBoxDescent;
    }

    // Rendered width of the text `x` with the context's current font.
    function text_width(ctx, x) {
      let text_metrics = ctx.measureText(x);
      return text_metrics.actualBoundingBoxRight + text_metrics.actualBoundingBoxLeft;
    }

    // Draws the axes, ticks, leaves and clades of `dendrogram` on the 2D context `ctx`.
    function draw(ctx, canvas_params, dendrogram) {
      let height = canvas_params.height;
      let width = canvas_params.width;

      let clades = dendrogram.clades
      let dissimilarities = dendrogram.dissimilarities
      let num_leaves = dendrogram.num_points

      // The last dissimilarity is used as the maximum; this assumes merges are
      // listed in order of increasing dissimilarity.
      let max_dissimilarity = dissimilarities[dissimilarities.length - 1];

      let x_min = 0;
      let x_max = num_leaves - 1;
      let y_min = 0;
      let y_max = max_dissimilarity;

      // One x tick per leaf index, one y tick per whole unit of dissimilarity.
      let x_tick_labels = [...Array(x_max + 1).keys()];
      let y_tick_labels = [...Array(Math.floor(y_max) + 1).keys()];

      let x_tick_label_height = Math.max(...x_tick_labels.map((l) => text_height(ctx, l)));
      let y_tick_label_width = Math.max(...y_tick_labels.map((l) => text_width(ctx, l)));

      let x_tick_area_height = 10;
      let y_tick_area_width = 10;
      let margin = 10;

      // The plot area is what remains after reserving room for margins,
      // tick marks and tick labels.
      let plot_left = margin + y_tick_label_width + y_tick_area_width;
      let plot_top = margin;
      let plot_height = height - 2*margin - x_tick_label_height - x_tick_area_height;
      let plot_width = width - 2*margin - y_tick_label_width - y_tick_area_width;

      let x_range = x_max - x_min;
      let y_range = y_max - y_min;

      // Pad the data range so points do not sit on the plot border.
      let data_margin = 0.1;
      let x_data_min = x_min - x_range * data_margin / 2;
      let x_data_max = x_max + x_range * data_margin / 2;
      let y_data_min = y_min - y_range * data_margin / 2;
      let y_data_max = y_max + y_range * data_margin / 2;

      // The y scale is flipped because canvas y coordinates grow downwards.
      let scale_x = scale_fun([x_data_min, x_data_max], [plot_left, plot_left + plot_width]);
      let scale_y = scale_fun([y_data_min, y_data_max], [plot_top + plot_height, plot_top]);

      // Axes
      ctx.beginPath();
      ctx.moveTo(plot_left, plot_top);
      ctx.lineTo(plot_left, plot_top + plot_height);
      ctx.lineTo(plot_left + plot_width, plot_top + plot_height);
      ctx.lineTo(plot_left + plot_width, plot_top);
      ctx.lineTo(plot_left, plot_top);
      ctx.stroke();
      ctx.closePath();

      // x-ticks
      for(let x of x_tick_labels) {
        ctx.beginPath();
        ctx.moveTo(scale_x(x), scale_y(y_data_min));
        ctx.lineTo(scale_x(x), scale_y(y_data_min) + x_tick_area_height / 2);
        ctx.stroke();
        ctx.closePath();
      }

      // y-ticks
      for(let y of y_tick_labels) {
        ctx.beginPath();
        ctx.moveTo(scale_x(x_data_min), scale_y(y));
        ctx.lineTo(scale_x(x_data_min) - y_tick_area_width / 2, scale_y(y));
        ctx.stroke();
        ctx.closePath();
      }

      // x-tick labels
      ctx.textAlign = "center";
      for(let x of x_tick_labels) {
        ctx.strokeText(x, scale_x(x), scale_y(y_data_min) + x_tick_area_height + x_tick_label_height);
      }

      // y-tick labels
      ctx.textBaseline = "middle";
      ctx.textAlign = "end";
      for(let y of y_tick_labels) {
        ctx.strokeText(y, scale_x(x_data_min) - y_tick_area_width, scale_y(y));
      }

      // Leaves: node id -> canvas coordinates of the top of that (sub)tree.
      let coords = new Map();
      for (let i = 0; i < num_leaves; i++) {
        let x = scale_x(i);
        let y = scale_y(0);
        ctx.beginPath();
        ctx.arc(x, y, 5, 0, Math.PI * 2);
        ctx.fill();
        coords.set(i, [x, y]);
      }

      // Clades: merge i joins nodes a and b into a new node with id i + num_leaves,
      // drawn at the height of that merge's dissimilarity.
      for (let i = 0; i < clades.length; i++) {
        let [a, b] = clades[i]
        let c = i + num_leaves
        let d = dissimilarities[i]
        let [ax, ay] = coords.get(a);
        let [bx, by] = coords.get(b);
        let cx = (ax + bx) / 2;
        let cy = scale_y(d);

        ctx.beginPath();
        ctx.moveTo(ax, ay);
        ctx.lineTo(ax, cy);
        ctx.lineTo(bx, cy);
        ctx.lineTo(bx, by);
        ctx.stroke();
        ctx.closePath();

        ctx.beginPath();
        ctx.arc(cx, cy, 5, 0, Math.PI * 2);
        ctx.fill();

        // The merged children are replaced by their parent.
        coords.delete(a);
        coords.delete(b);
        coords.set(c, [cx, cy]);
      }
    }

    // Kino.JS entry point: creates the canvas element and draws the dendrogram on it.
    export function init(ctx, input) {
      let dendrogram = input.dendrogram
      let canvas_params = input.canvas
      let canvas_el_id = "dendrogram-plot";

      // The canvas must exist in the DOM before it can be looked up below, and
      // its size must match the dimensions the drawing code lays out against.
      ctx.root.innerHTML =
        `
          <canvas id="${canvas_el_id}" width="${canvas_params.width}" height="${canvas_params.height}"></canvas>
        `;

      let canvas_el = document.getElementById(canvas_el_id);

      // Check that the element exists and that canvas is supported
      if (canvas_el && canvas_el.getContext) {
        let canvas_ctx = canvas_el.getContext("2d");
        draw(canvas_ctx, canvas_params, dendrogram);
      }
    }
    """
  end
end
# Nine 2-D points; each number in the sketch is the point's row index in `data`.
#
# 5 | 0 1   3 4
# 4 | 2       5
# 3 |
# 2 | 6
# 1 | 7 8
# 0 +-+-+-+-+-+
#   0 1 2 3 4 5

# The sketched points as an [x, y] tensor
data =
  Nx.tensor([
    [1, 5],
    [2, 5],
    [1, 4],
    [4, 5],
    [5, 5],
    [5, 4],
    [1, 2],
    [1, 1],
    [2, 1]
  ])

# Cluster the points with average linkage over Euclidean distances
model = Scholar.Cluster.Hierarchical.fit(data, dissimilarity: :euclidean, linkage: :average)

# Turn the fitted struct into a plain map, with tensors converted to lists,
# so it can be sent to the JavaScript widget
dendrogram =
  for {field, value} <- Map.from_struct(model), into: %{} do
    case value do
      %Nx.Tensor{} = tensor -> {field, Nx.to_list(tensor)}
      other -> {field, other}
    end
  end

# Render the dendrogram on a 400x400 canvas
canvas = %{width: 400, height: 400}
Scholar.Kino.CanvasDendrogram.new(%{dendrogram: dendrogram, canvas: canvas})