
Colored E-Graphs: Context-Scoped Rewrites

Mix.install([
  {:quail, path: "/Users/quinn/dev/beam_box/quail"}
])

Overview

Quail now supports relational colored e-graphs — the Hou approach to context-scoped equalities. This lets you express conditional rewrites (like branch-specific simplifications) as standard Datalog rules scoped by context, without modifying the base e-graph.

This notebook walks through:

  1. The problem — why base e-graphs can’t express “x == 0 only in this branch”
  2. Contexts — the lattice that scopes equalities
  3. Contextual Datalog rules — rules that match and merge within a context
  4. Per-context extraction — pulling the cheapest term from a context-specific view

The Problem: Global Merges

Imagine optimizing add(x, y) where in one branch we know x == 0 and in another y == 0. A standard e-graph merge is global — once x equals num(0), it equals num(0) everywhere. There’s no way to say “only in this branch.”

alias Quail.EGraph

eg = EGraph.new()
{x, eg} = EGraph.add_term(eg, {:var, :x})
{_y, eg} = EGraph.add_term(eg, {:var, :y})
{zero, eg} = EGraph.add_term(eg, {:num, 0})
{root, eg} = EGraph.add_term(eg, {:add, {:var, :x}, {:var, :y}})

# Global merge: x == num(0) everywhere.
{_, eg} = EGraph.merge(eg, x, zero)
eg = EGraph.rebuild(eg)

{:ok, term} = Quail.Extract.extract(eg, root)
IO.puts("After global merge of x with num(0):")
IO.inspect(term, label: "  extracted")
IO.puts("\nThe original 'x' is gone — it's num(0) in EVERY context now.")

The Solution: Contextual Equalities

With colored e-graphs, we use Database.assert_equal_in/4 to scope an equality to a context. The base e-graph is never modified. Contexts form a lattice in which :base is the bottom element, visible everywhere, and each {:assume, label} context is an independent assumption.

alias Quail.{Context, Database}

db = Quail.new()
{x, db} = Quail.add_term(db, {:var, :x})
{y, db} = Quail.add_term(db, {:var, :y})
{zero, db} = Quail.add_term(db, {:num, 0})
{root, db} = Quail.add_term(db, {:add, {:var, :x}, {:var, :y}})

# Create two branch contexts.
ctx_x0 = Context.assume(:x_is_zero)
ctx_y0 = Context.assume(:y_is_zero)

# Assert equalities scoped to their contexts.
db = Database.assert_equal_in(db, x, zero, ctx_x0)
db = Database.assert_equal_in(db, y, zero, ctx_y0)

# Base union-find is UNAFFECTED.
{cx, _} = Database.find(db, x)
{cz, _} = Database.find(db, zero)
IO.puts("Base: x and zero are #{if cx == cz, do: "SAME", else: "DIFFERENT"} classes")

# But in context x_is_zero, they're equivalent.
equiv = Database.context_equivalent_classes(db, x, ctx_x0)
IO.puts("In ctx_x0: x equivalent to #{MapSet.size(equiv)} classes (x and zero)")

# And in context y_is_zero, x is still just x.
equiv_y = Database.context_equivalent_classes(db, x, ctx_y0)
IO.puts("In ctx_y0: x equivalent to #{MapSet.size(equiv_y)} class (just x)")
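
To make the idea of context-scoped equalities concrete, here is a minimal standard-library sketch that stores equalities as {class_a, class_b, context} tuples and computes the classes equivalent to a given class in one context. The ColoredSketch module and its function names are hypothetical and only mirror the calls above in spirit; this is a toy model, not Quail's implementation.

```elixir
defmodule ColoredSketch do
  # Hypothetical toy model, NOT Quail's implementation.
  # The store is a MapSet of {class_a, class_b, context} tuples.

  def new, do: MapSet.new()

  # Record that `a` and `b` are equal, but only in `ctx`.
  def assert_equal_in(store, a, b, ctx), do: MapSet.put(store, {a, b, ctx})

  # All classes reachable from `class` through equalities visible in `ctx`:
  # those asserted in `ctx` itself or in :base.
  def equivalent_classes(store, class, ctx) do
    edges = for {a, b, c} <- store, c == ctx or c == :base, do: {a, b}
    walk(edges, [class], MapSet.new([class]))
  end

  defp walk(_edges, [], seen), do: seen

  defp walk(edges, [current | rest], seen) do
    # Follow every edge touching `current`, in either direction.
    neighbours =
      for {a, b} <- edges,
          {from, to} <- [{a, b}, {b, a}],
          from == current and not MapSet.member?(seen, to),
          do: to

    walk(edges, rest ++ neighbours, Enum.into(neighbours, seen))
  end
end

store =
  ColoredSketch.new()
  |> ColoredSketch.assert_equal_in(:x, :zero, {:assume, :x_is_zero})
  |> ColoredSketch.assert_equal_in(:y, :zero, {:assume, :y_is_zero})

in_x0 = ColoredSketch.equivalent_classes(store, :x, {:assume, :x_is_zero})
in_y0 = ColoredSketch.equivalent_classes(store, :x, {:assume, :y_is_zero})

MapSet.size(in_x0) # => 2 (:x and :zero)
MapSet.size(in_y0) # => 1 (just :x)
```

Because nothing is ever written to a shared union-find, the two assumptions cannot leak into each other or into the base view.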

Contextual Datalog Rules

The real power comes from Datalog rules scoped by context. A rule with a context: field matches against a virtual view where the contextual equalities hold, and its union_in actions create context-scoped merges.

Here we write add_zero rules that fire in each branch context:

import Quail, only: [v: 1]
alias Quail.Rule

# Build both rules for each context.
# Right zero: add(r, a, b) where num(b, 0) => union_in r with a (fires in ctx_y0)
# Left zero:  add(r, a, b) where num(a, 0) => union_in r with b (fires in ctx_x0)
rules =
  Enum.flat_map([ctx_x0, ctx_y0], fn ctx ->
    [
      Rule.rule(:"add_zero_r_#{inspect(ctx)}",
        context: ctx,
        body: [
          {:add, [v(:r), v(:a), v(:b)]},
          {:num, [v(:b), 0]}
        ],
        actions: [{:union_in, ctx, v(:r), v(:a)}]
      ),
      Rule.rule(:"add_zero_l_#{inspect(ctx)}",
        context: ctx,
        body: [
          {:add, [v(:r), v(:a), v(:b)]},
          {:num, [v(:a), 0]}
        ],
        actions: [{:union_in, ctx, v(:r), v(:b)}]
      )
    ]
  end)

result = Quail.run(db, rules, iter_limit: 10)
db = result.database

IO.puts("Saturation complete in #{result.iterations} iteration(s).")
IO.puts("Stop reason: #{result.stop_reason}")
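
What "matching against a virtual view" might look like can be sketched without Quail at all. The toy matcher below assumes facts are plain tuples and that the context's equivalences are already available as a map from class id to its set of equivalent classes. MatchSketch, the fact layout, and the equiv map are all hypothetical stand-ins, not Quail's query engine.

```elixir
defmodule MatchSketch do
  # Hypothetical toy matcher, NOT Quail's query engine.
  # `facts` is a list of tuples such as {:add, r, a, b} and {:num, class, value}.
  # `equiv` maps a class id to the classes equal to it in one context.

  # Find every {r, b} such that add(r, a, b) holds and some class
  # equivalent to `a` in this context is num(_, 0).
  def add_zero_left(facts, equiv) do
    for {:add, r, a, b} <- facts,
        candidate <- Map.get(equiv, a, [a]),
        {:num, ^candidate, 0} <- facts,
        do: {r, b}
  end
end

facts = [
  {:var, :cx, :x},
  {:var, :cy, :y},
  {:num, :c0, 0},
  {:add, :root, :cx, :cy}
]

# Base view: every class is equivalent only to itself.
base_matches = MatchSketch.add_zero_left(facts, %{})
# => []

# View for the x_is_zero context: :cx and :c0 are equivalent.
x0_view = %{cx: MapSet.new([:cx, :c0]), c0: MapSet.new([:cx, :c0])}
x0_matches = MatchSketch.add_zero_left(facts, x0_view)
# => [{:root, :cy}]
```

The only difference between the two queries is the expansion of the bound variable a to its context-equivalent classes; the facts themselves are identical.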

Per-Context Extraction

Now we can extract the cheapest term from each context’s view of the e-graph. The extract_in_context/3 function builds a temporary virtual union-find that includes the contextual equalities, runs cost-based extraction on it, then discards it. The base e-graph is never modified.

# Base extraction: no simplification possible.
{:ok, base_term} = Quail.extract(db, root)
IO.puts("Base:   #{inspect(base_term)}")

# In ctx_x0: x == 0, so add(0, y) simplifies to y.
{:ok, x0_term} = Quail.extract_in_context(db, root, ctx_x0)
IO.puts("ctx_x0: #{inspect(x0_term)}")

# In ctx_y0: y == 0, so add(x, 0) simplifies to x.
{:ok, y0_term} = Quail.extract_in_context(db, root, ctx_y0)
IO.puts("ctx_y0: #{inspect(y0_term)}")

Each branch got its own simplification. The base expression is untouched.
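
For intuition about the cost-based step, the following is a small bottom-up extractor over a plain map of classes, where cost is the number of nodes in a term. It is a sketch under stated assumptions: ExtractSketch, the {op, children} node shape, and the node-count cost are hypothetical choices, and Quail.Extract may work differently.

```elixir
defmodule ExtractSketch do
  # Hypothetical toy extractor, NOT Quail.Extract.
  # `classes` maps a class id to the list of nodes in that class.
  # A node is {op, child_class_ids}; leaves have an empty child list.
  # The cost of a term is its number of nodes.

  def extract(classes, root) do
    case improve(classes, %{}) do
      %{^root => {_cost, term}} -> {:ok, term}
      _ -> :error
    end
  end

  # Repeat full passes until no class gets a cheaper term.
  defp improve(classes, best) do
    next =
      Enum.reduce(classes, best, fn {id, nodes}, acc ->
        Enum.reduce(nodes, acc, &consider(id, &1, &2))
      end)

    if next == best, do: best, else: improve(classes, next)
  end

  # Keep `node` as the representative of `id` if all of its children
  # already have a term and the total is cheaper than the current best.
  defp consider(id, {op, children}, best) do
    if Enum.all?(children, &Map.has_key?(best, &1)) do
      picked = Enum.map(children, &Map.fetch!(best, &1))
      cost = 1 + (picked |> Enum.map(&elem(&1, 0)) |> Enum.sum())
      term = {op, Enum.map(picked, &elem(&1, 1))}

      case best do
        %{^id => {old, _}} when old <= cost -> best
        _ -> Map.put(best, id, {cost, term})
      end
    else
      best
    end
  end
end

# Base view: four separate classes.
base = %{
  x: [{:x, []}],
  y: [{:y, []}],
  zero: [{:zero, []}],
  root: [{:add, [:x, :y]}]
}

# View for x_is_zero after the add-zero rule fired: x and zero share a
# class, and the add node shares a class with y.
x0_view = %{
  x: [{:x, []}, {:zero, []}],
  y: [{:y, []}, {:add, [:x, :y]}]
}

base_result = ExtractSketch.extract(base, :root)
# => {:ok, {:add, [{:x, []}, {:y, []}]}}

x0_result = ExtractSketch.extract(x0_view, :y)
# => {:ok, {:y, []}}
```

The same extractor runs on both maps; only the grouping of nodes into classes changes between the base view and the context view.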

Contextual Constant Folding

Here’s a more involved example: if we know a == 3 and b == 4 in a specific context, constant folding should produce 7 — but only in that context.

db2 = Quail.new()
{a, db2} = Quail.add_term(db2, {:var, :a})
{b, db2} = Quail.add_term(db2, {:var, :b})
{three, db2} = Quail.add_term(db2, {:num, 3})
{four, db2} = Quail.add_term(db2, {:num, 4})
{root2, db2} = Quail.add_term(db2, {:add, {:var, :a}, {:var, :b}})

ctx_known = Context.assume(:known_values)
db2 = Database.assert_equal_in(db2, a, three, ctx_known)
db2 = Database.assert_equal_in(db2, b, four, ctx_known)

# Constant folding rule scoped to the context.
fold_rules = [
  Rule.rule(:const_add,
    context: ctx_known,
    body: [
      {:add, [v(:r), v(:x), v(:y)]},
      {:num, [v(:x), v(:n)]},
      {:num, [v(:y), v(:m)]}
    ],
    actions: [{:union_in, ctx_known, v(:r), v(:sum_class)}],
    guard: fn %{n: n, m: m} -> %{sum_class: {:add_term, {:num, n + m}}} end
  )
]

result2 = Quail.run(db2, fold_rules, iter_limit: 10)
db2 = result2.database

{:ok, base2} = Quail.extract(db2, root2)
{:ok, folded} = Quail.extract_in_context(db2, root2, ctx_known)

IO.puts("Base:        #{inspect(base2)}")
IO.puts("In context:  #{inspect(folded)}")
IO.puts("\nConstant folding happened ONLY in the context where a==3 and b==4.")

How It Works

The implementation uses three key ideas:

Layer        Mechanism
Storage      ctx_equal relation: {class_a, class_b, context} tuples
Matching     Query.run_in_context/3 expands bound variables to context-equivalent classes
Extraction   Extract.extract_in_context/4 builds a virtual union-find overlay

The context lattice is flat: :base is below everything (visible everywhere), and {:assume, label} values are independent siblings. This covers the common case of branch assumptions in case/if expressions.
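
The flat lattice described above can be written down as a three-clause ordering check. ContextSketch.leq?/2 is a hypothetical helper used only for illustration; the source does not say whether Quail's Context module exposes anything like it.

```elixir
defmodule ContextSketch do
  # Hypothetical helper, NOT part of Quail's Context module.
  # leq?(a, b) is true when facts asserted in `a` are visible from `b`.

  # :base is the bottom element: visible from every context.
  def leq?(:base, _ctx), do: true

  # An assumption is visible only from itself.
  def leq?({:assume, label}, {:assume, label}), do: true

  # Distinct assumptions are independent siblings.
  def leq?(_a, _b), do: false
end

ContextSketch.leq?(:base, {:assume, :x_is_zero})
# => true
ContextSketch.leq?({:assume, :x_is_zero}, {:assume, :y_is_zero})
# => false
```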

The virtual union-find approach means extraction never modifies the real database: it copies the base union-find (an immutable Elixir map, so the copy is just a reference), applies the contextual merges to the copy, and runs standard bottom-up cost extraction.
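
A minimal sketch of the overlay idea, assuming the union-find is a plain map from class id to parent id. OverlaySketch is a hypothetical module; Quail's actual data structure is not shown in this notebook and may differ.

```elixir
defmodule OverlaySketch do
  # Hypothetical toy union-find, NOT Quail's internal structure.
  # The map sends a class id to its parent; ids missing from the map
  # are their own roots.

  def find(uf, id) do
    case Map.get(uf, id, id) do
      ^id -> id
      parent -> find(uf, parent)
    end
  end

  # Returns a NEW map; the map passed in is untouched.
  def union(uf, a, b) do
    {ra, rb} = {find(uf, a), find(uf, b)}
    if ra == rb, do: uf, else: Map.put(uf, ra, rb)
  end

  # Build the per-context view: start from the base map and apply only
  # the merges recorded for that context.
  def view(base_uf, ctx_merges) do
    Enum.reduce(ctx_merges, base_uf, fn {a, b}, uf -> union(uf, a, b) end)
  end
end

base = %{}
x0_view = OverlaySketch.view(base, [{:x, :zero}])

OverlaySketch.find(x0_view, :x) == OverlaySketch.find(x0_view, :zero)
# => true
OverlaySketch.find(base, :x) == OverlaySketch.find(base, :zero)
# => false
```

Since Elixir maps are immutable, "copying" the base costs nothing up front, and dropping the view after extraction is just letting the new map be garbage collected.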