Source code for the Branches algorithm, described in Branches: A Fast Dynamic Programming and Branch & Bound Algorithm for Optimal Decision Trees.
We recommend creating a conda virtual environment from our .yml file as follows:
conda env create -f dependencies.yml
conda activate branches
To visualize Decision Trees, we need the svgling package, which is not currently available through conda, so we install it with pip:
pip install svgling
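Before running the examples, it can help to confirm that the environment actually exposes the packages used in this README. The snippet below is a minimal sketch using only the standard library. The helper name missing_packages is hypothetical, and the import names listed are assumptions based on the packages mentioned here, not a copy of dependencies.yml.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the names from `names` that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    # Import names assumed from the packages this README mentions
    required = ["numpy", "sklearn", "nltk", "svgling"]
    missing = missing_packages(required)
    print("Missing packages:", missing if missing else "none")
```

If svgling shows up as missing, the pip command above was most likely run outside the activated branches environment.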
.
├── data # Data used for benchmarking
├── src # Source files
│ ├── branch_ordinal.py # Source file for classification problems with ordinally encoded data
│ ├── branch_binary.py # Source file for binary classification problems with binary data
│ ├── branch_binary_multi.py # Source file for classification problems with binary data
│ ├── branches.py # Source file for the Branches algorithm
│ └── tutorial.ipynb # Tutorial .ipynb notebook
├── trees # SVG files of optimal decision trees
├── Tables.png # PNG file containing a summary of empirical comparisons
├── LICENSE
└── README.md
The MONK's Problems are standard datasets for benchmarking Optimal Decision Trees algorithms. We use the first of these problems to illustrate how to use Branches.
import numpy as np
import svgling
from branches import *
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, LabelEncoder
# Reading the data
data = np.genfromtxt('data/monks-1.train', delimiter=' ', dtype=int)
data = data[:, :-1] # Drop the last column, which contains only ids.
data = data[:, ::-1] # Reorder the columns to put the predicted variable Y at the end.
# Ordinal Encoding of the data
encoder = OrdinalEncoder()
encoder.fit(data)
data = encoder.transform(data).astype(int)
# Running Branches
alg = Branches(data)
alg.solve(lambd=0.01)
# Printing the accuracy, number of branches and number of splits
branches, splits = alg.lattice.infer()
print('Number of branches :', len(branches))
print('Number of splits :', splits)
print('Accuracy :', ((alg.predict(data[:, :-1]) == data[:, -1]).sum())/alg.n_total)
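To make the preprocessing steps above concrete, here is a small pure-Python sketch of the column reordering and the ordinal encoding on a toy table. It assumes the default behaviour of OrdinalEncoder, which maps each column's distinct values to 0..k-1 in sorted order. The helper ordinal_encode and the toy values are illustrative only and are not part of Branches.

```python
def ordinal_encode(rows):
    """Map each column's distinct values to 0..k-1 in sorted order."""
    n_cols = len(rows[0])
    # One lookup table per column: value -> rank among that column's sorted unique values
    tables = []
    for j in range(n_cols):
        categories = sorted({row[j] for row in rows})
        tables.append({value: code for code, value in enumerate(categories)})
    return [[tables[j][row[j]] for j in range(n_cols)] for row in rows]


# Toy rows laid out like the raw file: label first, then features (values are made up)
raw = [
    [1, 1, 3, 2],
    [0, 2, 1, 4],
    [1, 3, 3, 2],
]

# Reverse the column order so the label ends up last, as data[:, ::-1] does
reordered = [row[::-1] for row in raw]

encoded = ordinal_encode(reordered)
print(encoded)  # [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 2, 1]]
```

After encoding, every column holds consecutive integer codes starting at 0, and the label sits in the last column, which is the layout that the example above passes to Branches(data).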
Using the nltk and svgling packages, we can plot the optimal Decision Tree via the code below.
tree = alg.plot_tree(show_classes=False)
svgling.draw_tree(tree)
This figure does not show the predictions at the leaves. To show them, set show_classes=True in the plot_tree method.
tree = alg.plot_tree(show_classes=True)
svgling.draw_tree(tree)
Some nodes have identical subtrees, which makes this representation somewhat redundant. To make it more compact, set compact=True in the plot_tree method.
tree = alg.plot_tree(show_classes=True, compact=True)
svgling.draw_tree(tree)
The tutorial notebook src/tutorial.ipynb contains more examples of how to use Branches, in particular its micro-optimisation techniques, which allow for significant computational gains.
Branches optimises the regularised accuracy