Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Compiled Decision Trees Benchmark

notebooks/compiled_benchmarks.livemd

Compiled Decision Trees Benchmark

Mix.install([
  # Scidata.Iris.download/0 supplies the Iris dataset used for training and testing.
  {:scidata, "~> 0.1"},
  # EXGBoost provides train/3, predict/2 and compile/2 used throughout the notebook.
  {:exgboost, "~> 0.4"},
  # NOTE(review): not called directly in this notebook; presumably the backend that
  # EXGBoost.compile/2 relies on for tree-to-tensor conversion — confirm.
  {:mockingjay, github: "acalejos/mockingjay"},
  # override: true makes this Nx requirement take precedence over conflicting
  # Nx requirements declared by other dependencies.
  {:nx, "~> 0.5", override: true},
  # EXLA supplies EXLA.jit/1, the EXLA compiler and EXLA.Backend.
  {:exla, "~> 0.5"},
  # Scholar supplies the classification accuracy metric.
  {:scholar, "~> 0.2"},
  # Benchee drives the runtime benchmarks.
  {:benchee, "~> 1.0"}
])

Setup Dataset

# Seed the process RNG (which Enum.shuffle/1 draws from) so the train/test split,
# and therefore the trained model and the reported accuracies, are reproducible
# across runs. Change `seed` to sample a different split.
seed = 42
:rand.seed(:exsss, seed)

{x, y} = Scidata.Iris.download()

# Pair each feature row with its label so the two stay aligned through the shuffle.
data = Enum.zip(x, y) |> Enum.shuffle()

# 80/20 train/test split; ceil/1 turns the float product into an integer count.
{train, test} = Enum.split(data, ceil(length(data) * 0.8))
{x_train, y_train} = Enum.unzip(train)
{x_test, y_test} = Enum.unzip(test)

x_train = Nx.tensor(x_train)
y_train = Nx.tensor(y_train)

x_test = Nx.tensor(x_test)
y_test = Nx.tensor(y_test)

Gather Model / Prediction Functions

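EXGBoost.compile/2 converts a trained Booster model into a set of tensor operations that can run on any Nx backend. Its `:strategy` option selects how the trees are lowered to tensor operations (`:gemm`, `:tree_traversal`, or `:perfect_tree_traversal`). Below, each strategy is compiled under both the binary and EXLA backends, and also wrapped with `EXLA.jit/1`.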

# Get Baseline Model (XGBoost C API)
# Train one 3-class booster (objective :multi_softprob) on the training split.
# It is both the "Base" prediction path and the input to every compile call below.
model = EXGBoost.train(x_train, y_train, num_class: 3, objective: :multi_softprob)
# Get Compiled Models w/ Binary Backend
# Nx.Defn.default_options/1 and Nx.default_backend/1 change process-level defaults,
# so statement order in this cell matters. The intent (per these section headers) is
# that the first three compiled models are built under the binary backend and the
# last three under EXLA.
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
Nx.default_backend(Nx.BinaryBackend)
# Each strategy is compiled once and additionally wrapped with EXLA.jit/1 (a shortcut
# for Nx.Defn.jit/2 with the EXLA compiler) to obtain its JIT variant.
gemm_predict = EXGBoost.compile(model, strategy: :gemm)
gemm_jit_exla = EXLA.jit(gemm_predict)
tree_trav_predict = EXGBoost.compile(model, strategy: :tree_traversal)
tree_trav_jit_exla = EXLA.jit(tree_trav_predict)
ptt_predict = EXGBoost.compile(model, strategy: :perfect_tree_traversal)
ptt_jit_exla = EXLA.jit(ptt_predict)
# Get Compiled Models w/ EXLA Backend
# Switch the defaults to EXLA and compile the same three strategies again.
# NOTE(review): these defaults are not restored in this cell, so code that runs
# afterwards in the same process sees EXLA as compiler/backend until they are changed.
Nx.Defn.default_options(compiler: EXLA)
Nx.default_backend(EXLA.Backend)
gemm_exla = EXGBoost.compile(model, strategy: :gemm)
tree_trav_exla = EXGBoost.compile(model, strategy: :tree_traversal)
ptt_exla = EXGBoost.compile(model, strategy: :perfect_tree_traversal)

# Every prediction path under comparison, keyed by the label reported in the
# benchmark and accuracy results. Each value is a one-argument function that takes
# the input feature tensor; the baseline partially applies `model` to the native
# EXGBoost.predict/2, the rest forward to the compiled/JIT functions.
funcs = %{
  "Base" => &EXGBoost.predict(model, &1),
  "Compiled -- GEMM Strategy -- Binary Backend" => &gemm_predict.(&1),
  "Compiled -- Tree Traversal Strategy -- Binary Backend" => &tree_trav_predict.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- Binary Backend" => &ptt_predict.(&1),
  "Compiled -- GEMM Strategy -- EXLA Backend" => &gemm_exla.(&1),
  "Compiled -- Tree Traversal Strategy -- EXLA Backend" => &tree_trav_exla.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend" => &ptt_exla.(&1),
  "Compiled -- GEMM Strategy -- EXLA Backend (JIT)" => &gemm_jit_exla.(&1),
  "Compiled -- Tree Traversal Strategy -- EXLA Backend (JIT)" => &tree_trav_jit_exla.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend (JIT)" => &ptt_jit_exla.(&1)
}

Run Time Benchmarks

# Benchee jobs must be zero-arity functions that Benchee invokes repeatedly.
# Wrap each prediction function in a closure over x_train so the prediction itself
# is what gets timed. Calling it here (v.(x_train)) would hand Benchee a
# precomputed result tensor instead of a job, and nothing would be benchmarked.
benches = Map.new(funcs, fn {name, pred_fn} -> {name, fn -> pred_fn.(x_train) end} end)

Benchee.run(benches,
  time: 10,
  memory_time: 2,
  warmup: 5
)

Compare Accuracies

# Score every prediction path on the held-out split, after resetting the
# process-level Nx defaults to the Evaluator compiler and binary backend.
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
Nx.default_backend(Nx.BinaryBackend)

# For each path: predict on x_test, take the arg-max over the last axis as the
# predicted class, score it against y_test, and convert the scalar to a number.
accuracies =
  Map.new(funcs, fn {name, pred_fn} ->
    predicted_classes = pred_fn.(x_test) |> Nx.argmax(axis: -1)
    score = Scholar.Metrics.Classification.accuracy(y_test, predicted_classes)
    {name, Nx.to_number(score)}
  end)