Inference-time Wrappers: BestOfN, Refine, Ensemble, MultiChainComparison

Mix.install(
  [
    {:dsxir, path: Path.expand("../..", __DIR__)},
    {:sycophant, "~> 0.4"},
    {:kino, "~> 0.19"}
  ]
)

Overview

Compilation (Dsxir.compile/5) spends budget before deployment to pick better prompts and demos. Inference-time wrappers spend budget per call instead: they trade extra LM calls at request time for a more reliable answer. You reach for them when a single shot is too unreliable and you have a way to tell good answers from bad ones.

dsxir ships four, in two shapes:

| Wrapper | Shape | What it does | Needs |
|---|---|---|---|
| BestOfN | called from forward/2 | sample the program n times, keep the best | a reward function |
| Refine | called from forward/2 | BestOfN + reflect on failures and retry with advice | a reward function |
| Ensemble | called from forward/2 | run N programs concurrently, reduce the results | a reducer (e.g. majority vote) |
| MultiChainComparison | declared predictor | compare M pre-generated reasoning chains into one answer | the chains, as :completions |

The first three are not declared predictors. You call them directly from a forward/2 body, the same way you would call Dsxir.Predictor.Parallel. MultiChainComparison is the odd one out: it is a declared predictor, so it goes in a predictor declaration and is invoked through call/3.

When run from a checkout of dsxir, the Mix.install/1 call above resolves the library from the parent directory. If you launch this livebook from elsewhere, replace the path: option with a Hex version requirement for dsxir.

Configuring the LM

api_key_input = Kino.Input.password("OPENAI_API_KEY")

Most of these wrappers need diversity across attempts, so we run at a non-zero temperature. A single helper builds the LM config from the password input; pass temperature: per section as needed.

lm = fn temperature ->
  api_key = Kino.Input.read(api_key_input)

  [
    lm:
      {Dsxir.LM.Sycophant,
       [model: "openai:gpt-4o-mini", api_key: api_key, temperature: temperature]},
    adapter: Dsxir.Adapter.Chat
  ]
end

A task that rewards a second look

We need a task where one shot is genuinely unreliable but a correct answer is mechanically checkable — otherwise the wrappers have nothing to optimize toward. The bat-and-ball question from the Cognitive Reflection Test fits perfectly: the answer is $0.05, but the fast intuitive answer ($0.10) is wrong, and models trip over it often enough to make repeated sampling pay off.
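The trap is easy to check mechanically, and that is what makes the task a good fit. The minimal standalone sketch below works in integer cents. It brute-forces the ball price and shows what the intuitive answer would make the pair cost. The BatAndBall module is hypothetical and not part of dsxir.

```elixir
# Hypothetical helper, not part of dsxir: brute-force the bat-and-ball puzzle.
# Prices are integer cents to avoid floating-point rounding.
defmodule BatAndBall do
  @total 110
  @difference 100

  # Combined price when the ball costs `ball` cents and the bat costs $1.00 more.
  def total(ball), do: ball + (ball + @difference)

  # Every ball price, in cents, for which the combined price is exactly $1.10.
  def solutions, do: Enum.filter(0..@total, fn ball -> total(ball) == @total end)
end

BatAndBall.solutions()
# => [5]    the ball costs $0.05

BatAndBall.total(10)
# => 120    the intuitive $0.10 would make the pair cost $1.20
```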

defmodule Crt.Question do
  use Dsxir.Signature

  signature do
    instruction """
    Solve the arithmetic word problem. Many such problems have an intuitive
    answer that is wrong on closer inspection — check your arithmetic before
    committing.
    """

    input :question, :string

    output :answer, :string,
      desc: "The final answer as a bare dollar amount, e.g. $0.05 (no quotes)."
  end
end

A Chain-of-Thought program over that signature. Because ChainOfThought prepends a :reasoning output, each prediction carries both the worked steps (pred.fields.reasoning) and the final pred.fields.answer — we use both later.

defmodule Crt.Reason do
  use Dsxir.Module

  predictor :solve, Dsxir.Predictor.ChainOfThought, signature: Crt.Question

  def forward(prog, %{question: q}) do
    call(prog, :solve, %{question: q})
  end
end

The reward function is the whole point: it is what lets BestOfN and Refine recognize a good attempt. It takes the program inputs and the produced %Dsxir.Prediction{} and returns a number. Here we strip the formatting off the answer and score an exact match against the gold value.

defmodule Crt.Reward do
  @accepted ~w(0.05 .05 5cents 5cent)

  @spec correct?(map(), Dsxir.Prediction.t()) :: float()
  def correct?(_inputs, %Dsxir.Prediction{fields: %{answer: answer}}) do
    if normalize(answer) in @accepted, do: 1.0, else: 0.0
  end

  defp normalize(answer) when is_binary(answer) do
    answer
    |> String.downcase()
    |> String.replace(["$", " ", "\"", "'"], "")
  end
end
question =
  "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. " <>
    "How much does the ball cost?"
"A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?"
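The reward only ever sees strings, so check its normalization against the formattings you expect before you spend any LM calls. The hypothetical CrtNormalize module below applies the same lowercase-and-strip rules as Crt.Reward, but as a standalone function. That way it runs without dsxir or an LM. It is a test aid, not a replacement for the reward.

```elixir
# Hypothetical test aid, not part of dsxir: the same lowercase-and-strip rules
# as Crt.Reward, as a standalone function on plain strings.
defmodule CrtNormalize do
  @accepted ~w(0.05 .05 5cents 5cent)

  def normalize(answer) when is_binary(answer) do
    answer
    |> String.downcase()
    |> String.replace(["$", " ", "\"", "'"], "")
  end

  # 1.0 for an accepted spelling of $0.05, 0.0 otherwise.
  def score(answer), do: if(normalize(answer) in @accepted, do: 1.0, else: 0.0)
end

for answer <- ["$0.05", "\"$0.05\"", "5 Cents", "$.05", "$0.10", "$0.050"] do
  {answer, CrtNormalize.score(answer)}
end
# => [{"$0.05", 1.0}, {"\"$0.05\"", 1.0}, {"5 Cents", 1.0},
#     {"$.05", 1.0}, {"$0.10", 0.0}, {"$0.050", 0.0}]
```

The last case shows a limitation: "$0.050" is numerically correct but scores 0.0. If your model produces variants like that, widen @accepted or compare parsed numbers instead.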

BestOfN

Dsxir.Predictor.BestOfN.run/4 runs the program up to :n times with diverse sampling, scores each result with the reward function, and returns the highest-scoring {program, prediction}. Pass :threshold to stop early the moment an attempt clears the bar — once we see a correct answer, there is no point spending the remaining calls.

Dsxir.context(lm.(1.0), fn ->
  {_program, prediction} =
    Dsxir.Predictor.BestOfN.run(
      Dsxir.Program.new(Crt.Reason),
      %{question: question},
      &Crt.Reward.correct?/2,
      n: 5,
      threshold: 1.0
    )

  %{answer: prediction[:answer], reasoning: prediction[:reasoning]}
end)
%{
  reasoning: "Let's denote the cost of the ball as \\( x \\) dollars. According to the problem, the bat costs $1.00 more than the ball, which we can express as \\( x + 1.00 \\) dollars.\n\nThe total cost of the bat and the ball together is $1.10. Therefore, we can set up the following equation:\n\n\\[\nx + (x + 1.00) = 1.10\n\\]\n\nSimplifying this:\n\n\\[\n2x + 1.00 = 1.10\n\\]\n\nNow, we subtract 1.00 from both sides:\n\n\\[\n2x = 1.10 - 1.00\n\\]\n\\[\n2x = 0.10\n\\]\n\nNext, we divide both sides by 2:\n\n\\[\nx = \\frac{0.10}{2}\n\\]\n\\[\nx = 0.05\n\\]\n\nThus, the ball costs $0.05.\n\nNow, let's verify:\n- The ball costs $0.05.\n- The bat costs \\( 0.05 + 1.00 = 1.05 \\).\n- The total cost is \\( 0.05 + 1.05 = 1.10 \\) which is correct.\n\nTherefore, the ball costs $0.05.",
  answer: "$0.05"
}

What happened: BestOfN drew up to five independent samples at temperature: 1.0. The first attempt to score 1.0 short-circuits the rest. If none clears the threshold, you get the best of the five. The recognized opts:

  • :n (required) — number of attempts.
  • :threshold — stop early once reward >= threshold. Default nil (run all n, return the best).
  • :fail_count — raising attempts tolerated before the error re-raises. Default n.
  • :temperature — per-attempt sampling temperature. Default 1.0.
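A minimal sketch of the selection loop those options describe makes the control flow concrete. It rests on several assumptions and is not dsxir's source:

- BestOfNSketch is hypothetical.
- A deterministic sample function stands in for sampling.
- :fail_count and :temperature are not modeled.

```elixir
# Hypothetical sketch of best-of-N selection, NOT dsxir's source.
# `sample` stands in for one program run (given the attempt index);
# `reward` scores its result. Returns {best, best_reward, attempts_used}.
defmodule BestOfNSketch do
  def run(sample, reward, n, threshold \\ nil) when n >= 1 do
    Enum.reduce_while(1..n, {nil, nil, 0}, fn i, {best, best_reward, _used} ->
      prediction = sample.(i)
      score = reward.(prediction)

      # Replace the incumbent only on a strictly higher reward (ties keep the earlier one).
      {best, best_reward} =
        if best_reward == nil or score > best_reward,
          do: {prediction, score},
          else: {best, best_reward}

      acc = {best, best_reward, i}

      # Stop early as soon as one attempt clears the threshold.
      if threshold != nil and score >= threshold,
        do: {:halt, acc},
        else: {:cont, acc}
    end)
  end
end

answers = {"$0.10", "$0.10", "$0.05", "$0.10", "$0.05"}
sample = fn i -> elem(answers, i - 1) end
reward = fn answer -> if answer == "$0.05", do: 1.0, else: 0.0 end

BestOfNSketch.run(sample, reward, 5, 1.0)
# => {"$0.05", 1.0, 3}    stopped after the third attempt

BestOfNSketch.run(sample, reward, 2)
# => {"$0.10", 0.0, 2}    no threshold and no correct answer: best of two
```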

BestOfN is the right tool when you can score an answer cheaply but cannot fix it — verifiable formats, constraint satisfaction, anything with a unit test or a regex.

Refine

Dsxir.Predictor.Refine.run/4 is BestOfN plus reflection. After a sub-threshold attempt it runs an internal OfferFeedback predictor over the captured trace and injects per-predictor corrective advice into the next attempt through the Dsxir.Settings :hints channel. So instead of five independent guesses, each retry is told what went wrong last time.

Dsxir.context(lm.(1.0), fn ->
  {_program, prediction} =
    Dsxir.Predictor.Refine.run(
      Dsxir.Program.new(Crt.Reason),
      %{question: question},
      &Crt.Reward.correct?/2,
      n: 3,
      threshold: 1.0
    )

  %{answer: prediction[:answer], reasoning: prediction[:reasoning]}
end)
%{
  reasoning: "Let's denote the cost of the ball as \\(x\\). According to the problem:\n\n1. The bat costs $1.00 more than the ball, which means the bat costs \\(x + 1.00\\).\n2. Together, the bat and ball cost $1.10.\n\nSetting up the equation based on this information, we have:\n\\[\nx + (x + 1.00) = 1.10\n\\]\n\nNow simplify the equation:\n\\[\n2x + 1.00 = 1.10\n\\]\n\nNext, subtract $1.00 from both sides:\n\\[\n2x = 1.10 - 1.00\n\\]\n\\[\n2x = 0.10\n\\]\n\nNow, divide both sides by 2 to solve for \\(x\\):\n\\[\nx = 0.05\n\\]\n\nThus, the ball costs $0.05, and we confirm that the bat, costing $1.00 more, is $1.05. Together they sum to $1.10, confirming our solution is correct.",
  answer: "$0.05"
}

Same call signature and opts as BestOfN, plus one extra:

  • :feedback_lm — a {module, config} tuple overriding the LM used for the OfferFeedback call. Use a stronger model to critique a cheaper generator.

With threshold: nil, every attempt except the last is followed by a feedback call, so each attempt after the first builds on the previous one’s critique. Refine costs more per attempt than BestOfN, because the feedback call is an extra LM round-trip. It earns its keep when attempts are expensive and a blind re-roll is unlikely to stumble onto the fix: the model needs to be told about the $1.00-difference trap, not just asked again.
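What separates Refine from BestOfN is that each failed attempt feeds advice into the next one. The minimal sketch below shows that loop under stated assumptions:

- RefineSketch is hypothetical.
- The feedback function stands in for the OfferFeedback LM call.
- The hint is passed as a plain argument rather than through the Dsxir.Settings :hints channel.

```elixir
# Hypothetical sketch of a refine loop, NOT dsxir's implementation.
# `attempt` receives the previous hint (nil on the first try); `feedback`
# turns a failed prediction into advice, standing in for dsxir's OfferFeedback
# LM call. Returns {best, best_reward}.
defmodule RefineSketch do
  def run(attempt, reward, feedback, n, threshold) when n >= 1 do
    {best, best_reward, _hint} =
      Enum.reduce_while(1..n, {nil, nil, nil}, fn i, {best, best_reward, hint} ->
        prediction = attempt.(hint)
        score = reward.(prediction)

        {best, best_reward} =
          if best_reward == nil or score > best_reward,
            do: {prediction, score},
            else: {best, best_reward}

        cond do
          threshold != nil and score >= threshold -> {:halt, {best, best_reward, hint}}
          # The last attempt gets no critique: there is no next attempt to use it.
          i == n -> {:halt, {best, best_reward, hint}}
          # Otherwise pay for a critique and thread it into the next attempt.
          true -> {:cont, {best, best_reward, feedback.(prediction)}}
        end
      end)

    {best, best_reward}
  end
end

# A stub "model" that only gets it right once it has been given advice.
attempt = fn
  nil -> "$0.10"
  _hint -> "$0.05"
end

reward = fn answer -> if answer == "$0.05", do: 1.0, else: 0.0 end
feedback = fn answer -> "#{answer} makes the pair cost $1.20; the bat is $1.00 more than the ball." end

RefineSketch.run(attempt, reward, feedback, 3, 1.0)
# => {"$0.05", 1.0}
```

A blind re-roll of the same stub would never get past $0.10, because it is always called with nil. That is exactly the case where paying for a critique helps.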

Ensemble

Dsxir.Predictor.Ensemble.run/3 takes a list of programs, runs them concurrently over the same inputs, and reduces the survivors. Unlike the reward-sampling wrappers, it does not score attempts — it combines them. The built-in Dsxir.Predictor.Ensemble.majority/2 reducer votes; ties break by first occurrence.
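Majority vote with first-occurrence tie-breaking takes only a few lines. The sketch below is a hypothetical stand-in for Dsxir.Predictor.Ensemble.majority/2, whose real signature and return value may differ. It works on plain maps and returns the first prediction that carries the winning value.

```elixir
# Hypothetical stand-in for majority voting over one field, NOT
# Dsxir.Predictor.Ensemble.majority/2. Predictions are plain maps; the list
# must be non-empty.
defmodule MajoritySketch do
  def majority([_ | _] = predictions, field) do
    values = Enum.map(predictions, &Map.fetch!(&1, field))
    counts = Enum.frequencies(values)
    top = counts |> Map.values() |> Enum.max()

    # Ties break by first occurrence: scan in the original order and take the
    # first value that reaches the top count.
    winner = Enum.find(values, fn value -> counts[value] == top end)
    Enum.find(predictions, fn prediction -> Map.fetch!(prediction, field) == winner end)
  end
end

MajoritySketch.majority([%{answer: "$0.10"}, %{answer: "$0.05"}, %{answer: "$0.05"}], :answer)
# => %{answer: "$0.05"}

MajoritySketch.majority([%{answer: "$0.10"}, %{answer: "$0.05"}], :answer)
# => %{answer: "$0.10"}    1-1 tie, first occurrence wins
```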

Here the “ensemble” is the same Chain-of-Thought program sampled five times. At temperature: 1.0 each run reasons independently, and the majority answer is more robust than any single roll.

Dsxir.context(lm.(1.0), fn ->
  members = List.duplicate(Dsxir.Program.new(Crt.Reason), 5)

  prediction =
    Dsxir.Predictor.Ensemble.run(
      members,
      %{question: question},
      reduce_fn: &Dsxir.Predictor.Ensemble.majority(&1, field: :answer)
    )

  prediction[:answer]
end)
"$0.05"

Members fan out under Dsxir.TaskSupervisor, each replaying the caller’s settings snapshot. The recognized opts:

  • :reduce_fn ([Dsxir.Prediction.t()] -> Dsxir.Prediction.t()) reduces the member predictions to one. When nil (default), run/3 returns the raw prediction list; with a reducer, it returns the single reduced prediction.
  • :size — sample this many programs per call (default: all).
  • :max_concurrency — default System.schedulers_online().
  • :timeout — per-member timeout in ms (default 30_000).

Member failures are tolerated and surfaced via telemetry; the reduction runs over the survivors. Only when every member fails does run/3 raise Dsxir.Errors.Framework.PredictorError with reason: :all_failed.
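The failure semantics above map naturally onto Task.async_stream/3: crashed or timed-out members are dropped, and the call raises only when none survive. The minimal sketch below assumes members are zero-arity functions. FanOutSketch is hypothetical, and it raises a plain RuntimeError where dsxir raises Dsxir.Errors.Framework.PredictorError.

```elixir
# Hypothetical sketch of a fault-tolerant fan-out, NOT dsxir's Ensemble.
# Members are zero-arity functions. Crashed or timed-out members are dropped;
# the call raises only when no member survives.
defmodule FanOutSketch do
  def run(members, opts \\ []) do
    max_concurrency = Keyword.get(opts, :max_concurrency, System.schedulers_online())
    timeout = Keyword.get(opts, :timeout, 30_000)

    survivors =
      members
      |> Task.async_stream(&safe_call/1,
        max_concurrency: max_concurrency,
        timeout: timeout,
        on_timeout: :kill_task
      )
      |> Enum.flat_map(fn
        {:ok, {:ok, result}} -> [result]
        # The member raised; safe_call/1 caught the exception.
        {:ok, {:error, _exception}} -> []
        # The member timed out and was killed (on_timeout: :kill_task).
        {:exit, _reason} -> []
      end)

    if survivors == [], do: raise("all members failed"), else: survivors
  end

  # Rescue inside the task: async_stream links tasks to the caller, so an
  # unrescued raise would take the caller down too.
  defp safe_call(member) do
    {:ok, member.()}
  rescue
    exception -> {:error, exception}
  end
end

members = [
  fn -> "$0.05" end,
  fn -> raise "rate limited" end,
  fn -> Process.sleep(200); "$0.10" end,
  fn -> "$0.05" end
]

FanOutSketch.run(members, timeout: 100)
# => ["$0.05", "$0.05"]
```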

Two honest caveats. First, the return shape differs from the rest of the predictor namespace: a bare %Dsxir.Prediction{} (with a reducer) or a [%Dsxir.Prediction{}] (without one), never a {program, prediction} tuple — there is no single program identity across N members. Second, voting only helps when the members are right more often than not. If the model is systematically pulled toward the wrong $0.10 answer, majority vote will faithfully return the wrong answer. Ensemble reduces variance, not bias — pair it with diverse members (different demos, different models, different temperatures) when bias is the worry.
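The “right more often than not” condition can be quantified under a simplifying assumption. Suppose each member is independently right with probability p, and each vote is either right or wrong. A strict majority of an odd number n of members is then right with probability given by the upper tail of the binomial distribution. A small sketch (VoteMath is hypothetical):

```elixir
# Hypothetical helper, not part of dsxir. Probability that a strict majority of
# n independent members is right when each is right with probability p.
# Assumes binary right/wrong votes, which ignores answers splitting across
# several wrong values.
defmodule VoteMath do
  def majority_correct(n, p) when is_integer(n) and n > 0 and rem(n, 2) == 1 do
    k_min = div(n, 2) + 1

    for k <- k_min..n, reduce: 0.0 do
      acc -> acc + binomial(n, k) * :math.pow(p, k) * :math.pow(1 - p, n - k)
    end
  end

  defp binomial(n, k), do: div(factorial(n), factorial(k) * factorial(n - k))

  defp factorial(0), do: 1
  defp factorial(n) when n > 0, do: n * factorial(n - 1)
end

VoteMath.majority_correct(5, 0.7)
# => about 0.837: members usually right, voting helps

VoteMath.majority_correct(5, 0.4)
# => about 0.317: members usually wrong, voting makes it worse
```

Below p = 0.5 the vote does worse than a single member, which is the systematic-bias case described above. Samples from one model are positively correlated, so a majority behaves more like a single member than this model predicts. These numbers therefore likely overstate the effect of voting in both directions.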

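Without a reducer you get the raw list of predictions and can inspect the spread yourself. Member failures are tolerated, so the list holds only the survivors and can be shorter than the member count. The run captured below returned three answers from five members: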

Dsxir.context(lm.(1.0), fn ->
  members = List.duplicate(Dsxir.Program.new(Crt.Reason), 5)

  Dsxir.Predictor.Ensemble.run(members, %{question: question})
  |> Enum.map(& &1[:answer])
end)
["$0.05", "$0.05", "$0.05"]
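Enum.frequencies/1 is enough to turn the raw list into a tally. The share of members that agree with the top answer then shows how contested the vote was. The list literal below is a hypothetical mixed run, not the captured output above, chosen to show a non-unanimous tally.

```elixir
# A hypothetical mixed run standing in for the answers list returned above.
answers = ["$0.05", "$0.10", "$0.05"]

spread = Enum.frequencies(answers)
# => %{"$0.05" => 2, "$0.10" => 1}

{top_answer, votes} = Enum.max_by(spread, fn {_answer, count} -> count end)
agreement = votes / length(answers)
# top_answer is "$0.05"; agreement is 2/3, about 0.67
```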

MultiChainComparison

The first three wrappers all generate attempts. MultiChainComparison does not: you hand it M reasoning chains you already produced, and it makes one comparison call that integrates them into a single best answer. It is a declared predictor, so it lives in a predictor declaration and is called through call/3 — with the chains passed under the :completions input key.

First, generate the chains. Any source works; here we sample the Chain-of-Thought program four times, keeping each full prediction (so both :reasoning and :answer ride along).

completions =
  Dsxir.context(lm.(1.0), fn ->
    reason_prog = Dsxir.Program.new(Crt.Reason)

    for _ <- 1..4 do
      {_prog, pred} = Crt.Reason.forward(reason_prog, %{question: question})
      pred
    end
  end)

Enum.map(completions, & &1[:answer])
["$0.05", "$0.05", "$0.05", "$0.05"]

Now the comparison module. The predictor uses the base signature (Crt.Question); MultiChainComparison augments it internally with one reasoning_attempt_i input per chain plus a prepended rationale output. M is derived from the number of completions you pass, not declared.

defmodule Crt.Compare do
  use Dsxir.Module

  predictor :compare, Dsxir.Predictor.MultiChainComparison,
    signature: Crt.Question

  def forward(prog, %{question: q, completions: completions}) do
    call(prog, :compare, %{question: q, completions: completions})
  end
end
Dsxir.context(lm.(0.0), fn ->
  prog = Dsxir.Program.new(Crt.Compare)

  {_prog, prediction} =
    Crt.Compare.forward(prog, %{question: question, completions: completions})

  %{answer: prediction[:answer], rationale: prediction[:rationale]}
end)
%{
  answer: "$0.05",
  rationale: "All four reasoning attempts correctly define the cost of the ball as \\( x \\) and the cost of the bat as \\( x + 1.00 \\). They all set up the equation based on the total cost of $1.10 and simplify it correctly to find \\( x = 0.05 \\). Each attempt verifies the solution by checking the total cost, confirming that the calculations are accurate. Therefore, the consistent conclusion across all attempts is that the ball costs $0.05."
}

Each completion renders into the prompt as a candidate attempt, roughly “I’m trying to <reasoning>. I’m not sure but my prediction is <answer>”. The LM is then instructed to compare the attempts and produce the single most consistent answer. The extra rationale output captures why it picked what it picked.
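The way chains turn into inputs can be sketched as a pure function. ChainRender is a hypothetical module modeled on the description above:

- It produces one reasoning_attempt_i key per completion.
- M is taken from the length of the list.
- The “I’m trying to” clause is omitted when a completion has no :reasoning.

dsxir's actual template and key handling may differ.

```elixir
# Hypothetical sketch, NOT dsxir's implementation: render M completions into
# numbered inputs, one reasoning_attempt_i key per completion. The wording is
# modeled on the template described above and may differ from dsxir's.
defmodule ChainRender do
  def render(completions, output_field \\ :answer) do
    completions
    |> Enum.with_index(1)
    |> Map.new(fn {completion, i} ->
      # M is never declared: it is simply the number of completions passed in.
      {:"reasoning_attempt_#{i}", attempt_text(completion, output_field)}
    end)
  end

  defp attempt_text(completion, output_field) do
    prediction = Map.fetch!(completion, output_field)

    case Map.get(completion, :reasoning) do
      # No reasoning: drop the "I'm trying to" clause.
      nil -> "I'm not sure but my prediction is #{prediction}"
      reasoning -> "I'm trying to #{reasoning}. I'm not sure but my prediction is #{prediction}"
    end
  end
end

ChainRender.render([
  %{reasoning: "x + (x + 1.00) = 1.10, so x = 0.05", answer: "$0.05"},
  %{answer: "$0.10"}
])
# => %{
#      reasoning_attempt_1:
#        "I'm trying to x + (x + 1.00) = 1.10, so x = 0.05. I'm not sure but my prediction is $0.05",
#      reasoning_attempt_2: "I'm not sure but my prediction is $0.10"
#    }
```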

Note the contrast with Ensemble. Voting is mechanical and cannot side with a correct minority, but it is cheap (no extra LM call to reduce). MultiChainComparison spends one more LM call to let the model reason about the disagreement, so it can notice that three chains made the same arithmetic slip and side with the fourth. Run the comparison call itself at temperature: 0.0, since you want a decisive integration, not another diverse sample.

A completion can be a %Dsxir.Prediction{} or a plain map carrying the output field(s) and an optional :reasoning. A chain with no :reasoning simply renders without the “I’m trying to” clause:

Dsxir.context(lm.(0.0), fn ->
  prog = Dsxir.Program.new(Crt.Compare)

  bare = [
    %{answer: "$0.05"},
    %{answer: "$0.10"},
    %{answer: "$0.05"}
  ]

  {_prog, prediction} =
    Crt.Compare.forward(prog, %{question: question, completions: bare})

  prediction[:answer]
end)
"$0.05"

Watching the wrappers work

Each wrapper emits telemetry under [:dsxir, :predictor, <wrapper>, ...]. Attach a handler before a call to see the attempts as they happen. These are the same events you would forward to your observability pipeline in production.

ref =
  :telemetry_test.attach_event_handlers(self(), [
    [:dsxir, :predictor, :best_of_n, :attempt],
    [:dsxir, :predictor, :best_of_n, :stop]
  ])

Dsxir.context(lm.(1.0), fn ->
  Dsxir.Predictor.BestOfN.run(
    Dsxir.Program.new(Crt.Reason),
    %{question: question},
    &Crt.Reward.correct?/2,
    n: 5,
    threshold: 1.0
  )
end)

drain = fn drain, acc ->
  receive do
    {event, ^ref, measurements, _meta} ->
      drain.(drain, [{List.last(event), measurements} | acc])
  after
    0 -> Enum.reverse(acc)
  end
end

events = drain.(drain, [])
:telemetry.detach(ref)
events
[attempt: %{reward: 1.0}, stop: %{best_reward: 1.0}]

The event names per wrapper:

  • BestOfN: [:dsxir, :predictor, :best_of_n, :attempt] (one per sample) and [:dsxir, :predictor, :best_of_n, :stop] (carries best_reward).
  • Refine: [:dsxir, :predictor, :refine, :attempt] and :stop.
  • Ensemble: [:dsxir, :predictor, :ensemble, :member] (one per program, with ok?/failed?) and [:dsxir, :predictor, :ensemble, :stop] (with members_run and successes).

Every event merges the active Dsxir.Settings :metadata, so a Dsxir.context([metadata: %{tenant_id: ...}], ...) block tags each attempt — exactly what you want for per-tenant cost dashboards.

Choosing between them

  • Can you score an answer but not fix it? BestOfN. Cheapest path to reliability when a verifier exists.
  • Are attempts expensive and a blind retry unlikely to help? Refine. Pay for a critique so the next attempt is informed, not random.
  • Want to combine independent runs and reduce variance? Ensemble. Remember it reduces variance, not bias — diversify the members.
  • Already have several reasoning chains and want the model to reconcile them? MultiChainComparison. One extra call to reason about the disagreement rather than vote on it.

These are not mutually exclusive. A common production shape is to run Ensemble over several ChainOfThought rolls to gather diverse chains, then feed those chains to MultiChainComparison for a reasoned final answer. You can also put BestOfN on top, wrapping the whole forward/2 against a verifier.