Plotting in EXGBoost
Mix.install([
{:exgboost, "~> 0.5"},
{:scidata, "~> 0.1"},
{:kino_vega_lite, "~> 0.1"}
])
# This assumes you launch this livebook from its location in the exgboost/notebooks folder
Introduction
Much of the utility of decision trees comes from their intuitiveness and ability to inform decisions outside of the confines of a black-box model. A decision tree can be easily translated to a series of actions that can be taken on behalf of the stakeholder to achieve the desired outcome. This makes them especially useful in business decisions, where people might still want to have the final say but be as informed as possible. Additionally, tabular data is still quite popular in the business domain, which conforms to the required input for decision trees.
Decision trees can be used for both regression and classification tasks, but classification tends to be what is most associated with decision trees.
This notebook will go over some of the details of the EXGBoost.Plotting module, including using preconfigured styles, custom styling, as well as customizing the entire visualization.
Plotting APIs
There are 2 main APIs exposed to control plotting in EXGBoost:
- Top-level API (EXGBoost.plot_tree/2)
  - Using predefined styles
  - Defining custom styles
  - Mix of the first 2
- EXGBoost.Plotting module API
  - Use the Vega data spec defined in EXGBoost.Plotting.get_data_spec/2
  - Define your own Vega spec using the data from either EXGBoost.Plotting.to_tabular/1 or some other means
We will walk through each of these in detail.
Regardless of which API you choose to use, it is helpful to understand how the plotting module works (although the higher-level the API you choose to work with, the less important it becomes).
Implementation Details
The plotting functionality provided in EXGBoost is powered by the Vega JavaScript library and the Elixir VegaLite library, which provides the piping to interop with the JavaScript libraries. We do not actually make much use of the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.
Vega is a plotting library built on top of the very powerful D3 JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a reduced schema compared to the full Vega spec. EXGBoost.Plotting
leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API.
For these reasons, unfortunately we could not just implement plotting for EXGBoost as a composable Vega-Lite pipeline. This makes working dynamically with the spec a bit more unwieldy, but much care was taken to still make the high-level plotting API extensible, and if needed you can go straight to defining your own JSON spec.
Setup Data
We will still be using the Iris dataset for this notebook, but if you want more details about the process of training and evaluating a model please check out the Iris Classification with Gradient Boosting
notebook.
So let’s proceed by setting up the Iris dataset.
{x, y} = Scidata.Iris.download()
data = Enum.zip(x, y) |> Enum.shuffle()
{train, test} = Enum.split(data, ceil(length(data) * 0.8))
{x_train, y_train} = Enum.unzip(train)
{x_test, y_test} = Enum.unzip(test)
x_train = Nx.tensor(x_train)
y_train = Nx.tensor(y_train)
x_test = Nx.tensor(x_test)
y_test = Nx.tensor(y_test)
Train Your Booster
Now go ahead and train your booster. We will use early_stopping_rounds: 1
because we’re not interested in the accuracy of the booster for this demonstration (Note that we need to set evals
to use early stopping).
You will notice that EXGBoost
also provides an implementation for Kino.Render
so that EXGBoost.Booster
s are rendered as a plot by default.
booster =
EXGBoost.train(
x_train,
y_train,
num_class: 3,
objective: :multi_softprob,
num_boost_rounds: 10,
evals: [{x_train, y_train, "training"}],
verbose_eval: false,
early_stopping_rounds: 1
)
You’ll notice that the plot doesn’t display any labels for the features in the splits, and instead only shows features labelled as “f2” etc. If you provide feature labels during training, your plot will show the splits using the feature labels.
booster =
EXGBoost.train(x_train, y_train,
num_class: 3,
objective: :multi_softprob,
num_boost_rounds: 10,
evals: [{x_train, y_train, "training"}],
verbose_eval: false,
feature_name: ["sepal length", "sepal width", "petal length", "petal width"],
early_stopping_rounds: 1
)
Top-Level API
EXGBoost.plot_tree/2
is the quickest way to customize the output of the plot.
This API uses Vega Mark
s to describe the plot. Each of the following Mark
options accepts any of the valid keys from their respective Mark
type as described in the Vega documentation.
Please note that these are passed as a Keyword, and as such the keys must be atoms rather than strings as the Vega docs show. Valid options for this API are snake_cased atoms as opposed to the camelCased strings the Vega docs describe, so if you wish to pass "fontSize" as the Vega docs show, you would instead pass it as font_size: in this API.
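As a rough illustration of that naming rule, the sketch below converts a few Vega property names into the atom form this API expects. It uses only Macro.underscore/1 and String.to_atom/1 from Elixir's standard library. Whether EXGBoost performs the reverse conversion in this exact way internally is an assumption, and the list of property names is just an example.

```elixir
# Property names as written in the Vega docs (camelCased strings)
vega_keys = ["fontSize", "strokeWidth", "fill"]

# Macro.underscore/1 turns camelCase into snake_case; String.to_atom/1 makes the key an atom
option_keys =
  Enum.map(vega_keys, fn key ->
    key |> Macro.underscore() |> String.to_atom()
  end)

IO.inspect(option_keys)
# prints [:font_size, :stroke_width, :fill]
```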
The plot is composed of the following parts:
- Top-level keys: Options controlling parts of the plot outside of direct control of a Mark, such as :padding, :autosize, etc. Accepts any Vega top-level key in addition to several specific to this API (such as :style and :depth).
- :leaves: Mark specifying the leaf nodes of the tree
- :splits: Mark specifying the split (or inner / decision) nodes of the tree
- :yes
- :no
EXGBoost.plot_tree/2
defaults to outputting a VegaLite
struct. If you pass the :path
option it will save to a file instead.
If you want to add any marks to the underlying plot you will have to use the lower-level EXGBoost.Plotting
API, as the top-level API is only capable of customizing these marks.
Top-Level Keys
EXGBoost supports changing the direction of the plots through the :rankdir option. Available directions are [:tb, :bt, :lr, :rl], with top-to-bottom (:tb) being the default.
EXGBoost.plot_tree(booster, rankdir: :bt)
By default, plotting only shows one (the first) tree, but seeing as a Booster
is really an ensemble of trees you can choose which tree to plot through the :index
option, or set to nil
to have a dropdown box to select the tree.
EXGBoost.plot_tree(booster, rankdir: :lr, index: nil)
You’ll also notice that the plot is interactive, with support for scrolling, zooming, and collapsing sections of the tree. If you click on a split node you will toggle the visibility of its descendants, and the rest of the tree will fill the canvas.
You can also use the :depth option to programmatically set the max depth to display in the tree:
EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3)
One way to affect the canvas size is by controlling the padding.
You can add padding to all sides by specifying an integer for the :padding option:
EXGBoost.plot_tree(booster, rankdir: :rl, index: 4, depth: 3, padding: 50)
Or specify padding for each side:
EXGBoost.plot_tree(booster,
rankdir: :lr,
index: 4,
depth: 3,
padding: [top: 5, bottom: 25, left: 50, right: 10]
)
You can also specify the canvas size using the :width
and :height
options:
EXGBoost.plot_tree(booster,
rankdir: :lr,
index: 4,
depth: 3,
width: 500,
height: 500
)
But do note that changing the padding of a canvas does change the size, even if you specify the size using :height
and :width
EXGBoost.plot_tree(booster,
rankdir: :lr,
index: 4,
depth: 3,
width: 500,
height: 500,
padding: 10
)
You can change the dimensions of all nodes through the :node_height
and :node_width
options:
EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3, node_width: 60, node_height: 60)
Or change the space between nodes using the :space_between
option.
Note that the size of the accompanying nodes and/or text will change to accommodate the new :space_between option while trying to maintain the canvas size.
EXGBoost.plot_tree(
booster,
rankdir: :lr,
index: 4,
depth: 3,
space_between: [nodes: 200]
)
So if you want to add space between nodes without changing the size of the nodes, you might need to manually adjust the canvas size:
EXGBoost.plot_tree(
booster,
rankdir: :lr,
index: 4,
depth: 3,
space_between: [nodes: 200],
height: 800
)
EXGBoost.plot_tree(
booster,
rankdir: :lr,
index: 4,
depth: 3,
space_between: [levels: 200]
)
Mark Options
The options controlling the appearance of individual marks all conform to a similar API. You can refer to the options and pre-defined defaults for a subset of the allowed options, but you can also pass other options so long as they are allowed by the Vega Mark spec (as defined in the Vega documentation).
EXGBoost.plot_tree(
booster,
rankdir: :bt,
index: 4,
depth: 3,
space_between: [levels: 200],
yes: [
text: [font_size: 18, fill: :teal]
],
no: [
text: [font_size: 20]
],
node_width: 100
)
Most marks accept an :opacity
option that you can use to effectively hide the mark:
EXGBoost.plot_tree(
booster,
rankdir: :lr,
index: 4,
depth: 3,
splits: [
text: [opacity: 0],
rect: [opacity: 0],
children: [opacity: 1]
]
)
And text
marks accept normal text options such as :fill
, :font_size
, and :font
:
EXGBoost.plot_tree(
booster,
node_width: 250,
splits: [
text: [font: "Helvetica Neue", font_size: 20, fill: "orange"]
],
space_between: [levels: 20]
)
Styles
There is a set of pre-configured settings for the top-level API that you may optionally use. You can refer to the EXGBoost.Plotting.Styles docs to see a gallery of each style in action. You can specify a style with the :style option in EXGBoost.plot_tree/2.
You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, but you are free to specify any other allowed keys and they will be merged with the style. Any options passed explicitly take precedence over the style options.
For example, let’s look at the :solarized_dark
style:
EXGBoost.Plotting.solarized_dark() |> Keyword.take([:background, :height]) |> IO.inspect()
EXGBoost.plot_tree(booster, style: :solarized_dark)
You can see that it defines a background color of #002b36
but does not restrict what the height must be.
EXGBoost.plot_tree(booster, style: :solarized_dark, background: "white", height: 200)
We specified both :background and :height here, and the background specified in the option supersedes the one from the style.
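The precedence rule can be pictured with a plain Keyword.merge/2, where values from the second list win on conflicting keys. This is only a sketch of the idea, not EXGBoost's actual merging code: the style list here is a hypothetical stand-in (only the #002b36 background comes from the text above, and the :padding entry is invented for illustration), and the library may merge nested mark options more deeply than a flat merge does.

```elixir
# Hypothetical stand-in for a style; :padding is an invented entry for illustration
style = [background: "#002b36", padding: 30]

# Options passed explicitly by the caller
explicit = [background: "white", height: 200]

# Keyword.merge/2 keeps the value from the second list when a key appears in both
merged = Keyword.merge(style, explicit)

IO.inspect(Keyword.get(merged, :background))
IO.inspect(Keyword.get(merged, :height))
IO.inspect(Keyword.get(merged, :padding))
```

The explicit background replaces the one from the style, the height is added because the style does not set it, and the untouched style entry is carried through.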
You can also always get the style specification as a Keyword
which can be passed to EXGBoost.plot_tree/2
manually, making any needed changes yourself, like so:
custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white")
EXGBoost.plot_tree(booster, style: custom_style)
You can also programmatically check which styles are available:
EXGBoost.Plotting.get_styles()
Configuration
You can also set defaults for the top-level API using an Application
configuration for EXGBoost
under the :plotting
key. Since the defaults are collected from your configuration file at compile-time, anything you set during runtime, even if you set it to the Application environment, will not be registered as defaults.
For example, if you just want to change the default pre-configured style you can do:
Mix.install([
{:exgboost, path: Path.join(__DIR__, ".."), env: :dev},
],
config:
[
exgboost: [
plotting: [
style: :solarized_dark,
]]
],
lockfile: :exgboost)
You can also make one-off changes to any of the settings with this method. In effect, this turns into a default custom style. Just make sure to set style: nil to ensure that the style option doesn’t supersede any of your settings. Here’s an example of that:
default_style =
[
style: nil,
background: "#3f3f3f",
leaves: [
# Foreground
text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "normal"],
# Comment
rect: [fill: "#7f9f7f", stroke: "#7f9f7f"]
],
splits: [
# Foreground
text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "bold"],
# Comment
rect: [fill: "#7f9f7f", stroke: "#7f9f7f"],
# Selection
children: [fill: "#2b2b2b", stroke: "#2b2b2b"]
],
yes: [
# Green
text: [fill: "#7f9f7f"],
# Selection
path: [stroke: "#2b2b2b"]
],
no: [
# Red
text: [fill: "#cc9393"],
# Selection
path: [stroke: "#2b2b2b"]
]
]
Mix.install([
{:exgboost, path: Path.join(__DIR__, ".."), env: :dev},
],
config:
[
exgboost: [
plotting: default_style,
]
]
)
NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.
At any point, you can check what your default settings are by using EXGBoost.Plotting.get_defaults/0
EXGBoost.Plotting.get_defaults()
Low-Level API
If you find yourself needing more granular control over your plots, you can reach for the EXGBoost.Plotting
module. This module houses the EXGBoost.Plotting.plot/2
function, which is what is used under the hood from the EXGBoost.plot_tree/2
top-level API. This module also has the get_data_spec/2
function, as well as the to_tabular/1
function, both of which can be used to specify your own Vega specification. Lastly, the module also houses all of the pre-configured styles, which are 0-arity functions which output the Keyword
s containing their respective style’s options that can be passed to the plotting APIs.
Let’s briefly go over the to_tabular/1
and get_data_spec/2
functions:
The to_tabular/1
function is used to convert a Booster
, which is formatted as a tree structure, to a tabular format which can be ingested specifically by the Vega Stratify transform. It returns a list of “nodes”, which are just Map
s with info about each node in the tree.
EXGBoost.Plotting.to_tabular(booster) |> hd
You can use this function if you want to have complete control over the visualization, and just want a bit of a head start with respect to data transformation for converting the Booster
into a more digestible format.
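To make the idea of a tabular tree concrete, here is a minimal sketch in plain Elixir that flattens a small nested tree into a list of maps, each carrying its own id and its parent's id. That shape is what a stratify-style transform needs to rebuild the hierarchy. The module name TabularSketch, the field names (:id, :parent_id, :label, :children) and the toy tree are all hypothetical; the actual keys returned by to_tabular/1 may differ.

```elixir
defmodule TabularSketch do
  # Flattens a nested tree into rows; each row records the id of its parent (nil for the root)
  def flatten(node, parent_id \\ nil) do
    children = Map.get(node, :children, [])
    row = node |> Map.delete(:children) |> Map.put(:parent_id, parent_id)
    [row | Enum.flat_map(children, &flatten(&1, node.id))]
  end
end

# Toy tree with one root split, one nested split and three leaves
tree = %{
  id: 0,
  label: "split",
  children: [
    %{id: 1, label: "leaf"},
    %{
      id: 2,
      label: "split",
      children: [%{id: 3, label: "leaf"}, %{id: 4, label: "leaf"}]
    }
  ]
}

rows = TabularSketch.flatten(tree)
IO.inspect(rows)
```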
The get_data_spec/2 function is used if you want to use the provided Vega data specification. This is for those who want to focus only on implementing their own Vega Marks, and want to leverage the data transformation pipeline that powers the top-level API.
The data transformation used is the following pipeline:
to_tabular/1
-> Filter (by tree index) -> Stratify -> Tree
EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt)
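For orientation, a Vega data entry that chains filter, stratify and tree transforms could look roughly like the map below. This is a hand-written sketch, not the output of get_data_spec/2: the dataset name, the field names ("tree_id", "id", "parent_id"), the filter expression and the layout size are all assumptions made for illustration. Only the transform types and their parameter names (expr, key, parentKey, method, size) come from the Vega transform reference.

```elixir
# Hypothetical field names; the real ones are whatever to_tabular/1 emits
data_spec = %{
  "name" => "tree",
  # Rows such as those produced by to_tabular/1 would go here
  "values" => [],
  "transform" => [
    # Keep only the rows belonging to the selected tree
    %{"type" => "filter", "expr" => "datum.tree_id === 0"},
    # Rebuild the hierarchy from each row's id and parent id
    %{"type" => "stratify", "key" => "id", "parentKey" => "parent_id"},
    # Compute x/y layout coordinates for every node
    %{"type" => "tree", "method" => "tidy", "size" => [400, 300]}
  ]
}

IO.inspect(Enum.map(data_spec["transform"], & &1["type"]))
```

The order matters: the filter has to run before stratify so that only one tree's nodes are stratified, and the tree layout needs the hierarchy that stratify produces.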
The Vega fields which are not included with get_data_spec/2
and are included in plot/2
are:
You can make a completely valid plot using only the data from get_data_spec/2 and adding the marks you need.