
How to retrieve the full branch path leading to each leaf node of a sklearn Decision Tree?

I have this decision tree, from which I would like to extract every branch. The image shows only a portion of the tree, since the original tree is much bigger and doesn’t fit well in a single image.

[Image: a portion of the decision tree]

I’m not trying to print the rules used to predict a single sample, like:

Rules used to predict sample 1400:

decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)

or like:

The binary tree structure has 7 nodes and has the following tree structure:

node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
    node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
    node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.

What I’m trying to achieve is something like:

branch 0: x[4] <= 96.429, x[3] <= 96.429, class=B, gini_score=0.5
branch 1: x[4] <= 96.429, x[3] > 96.429, class=B, gini_score=0.021
branch 2: x[4] > 96.429, x[5] <= 0.283, class=A, gini_score=0.092
branch 3: x[4] > 96.429, x[5] > 0.283, class=A, gini_score=0.01

Basically, I’m trying to obtain every branch from the root to each leaf node (the full path), together with the class and the Gini score. How can I achieve this?


Answer

Using the iris dataset example from the sklearn docs, we follow these steps.

1. Generate an example decision tree

Code adapted from the sklearn docs:

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)

2. Retrieve the branch paths

First, we retrieve the following arrays from the fitted tree structure clf.tree_:

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
impurity = clf.tree_.impurity
value = clf.tree_.value
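To make the layout of these arrays concrete, here is a small sketch in plain Python. The values are hypothetical: they are hand-copied (thresholds rounded) from the example tree whose paths and rules are printed further down, not read from sklearn. Each array is indexed by node id; a leaf has children_left[i] == children_right[i] == -1 (sklearn's TREE_LEAF), and its feature is -2 (TREE_UNDEFINED). The check at the end asserts, for this example, that every parent id is smaller than its children's ids, which is what lets retrieve_branches visit nodes in plain index order.

```python
# Hypothetical arrays hand-copied from the iris example tree (thresholds rounded);
# position i describes node i, mirroring the layout of clf.tree_.
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
children_right = [2, -1, 4, 6, 10, -1, 8, -1, -1, -1, -1]
feature = [3, -2, 2, 3, 2, -2, 1, -2, -2, -2, -2]  # -2 means "no split" (leaf)
threshold = [0.8, -2.0, 4.95, 1.65, 5.05, -2.0, 3.1, -2.0, -2.0, -2.0, -2.0]

TREE_LEAF = -1  # child id used for "no child"

# A node is a leaf when it has no children on either side
leaves = [i for i, (cl, cr) in enumerate(zip(children_left, children_right))
          if cl == TREE_LEAF and cr == TREE_LEAF]

# In this example every child id is larger than its parent id, so visiting
# nodes in index order always reaches a parent before its children.
parent_before_child = all(
    parent < child
    for parent, (cl, cr) in enumerate(zip(children_left, children_right))
    for child in (cl, cr)
    if child != TREE_LEAF
)

print(leaves)               # [1, 5, 7, 8, 9, 10]
print(parent_before_child)  # True
```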

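Inside retrieve_branches, we first mark which nodes are leaves, then visit the nodes in index order starting from the root (node 0). For each split node, every stored path that ends at that node is extended with its left child, and a copy of it is extended with its right child. When we reach a leaf node that ends a stored path, we yield that complete branch path.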

def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Retrieve decision tree branches"""
    
    # A node is a leaf when both child ids are equal (both are -1)
    is_leaves_list = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]
    
    # Store the branches paths
    paths = []
    
    for i in range(number_nodes):
        if is_leaves_list[i]:
            # Search leaf node in previous paths
            end_node = [path[-1] for path in paths]

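            # If this leaf ends one of the stored paths, remove and yield that path
            if i in end_node:
                output = paths.pop(np.argwhere(i == np.array(end_node))[0][0])
                # Plain ints print cleanly (NumPy 2 shows np.int64(...) inside lists)
                yield [int(node) for node in output]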

        else:
            
            # Origin and end nodes
            origin, end_l, end_r = i, children_left_list[i], children_right_list[i]

            # Iterate over previous paths to add nodes
            for index, path in enumerate(paths):
                if origin == path[-1]:
                    paths[index] = path + [end_l]
                    paths.append(path + [end_r])

            # Initialize the two root paths in the first iteration
            if i == 0:
                paths.append([i, children_left_list[i]])
                paths.append([i, children_right_list[i]])

To call retrieve_branches, pass n_nodes, children_left and children_right; the list that stores and updates the partial branch paths is created inside the function. The final result is shown below.

all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches
>>> 
[[0, 1],
 [0, 2, 3, 5],
 [0, 2, 3, 6, 7],
 [0, 2, 3, 6, 8],
 [0, 2, 4, 9],
 [0, 2, 4, 10]]
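The retrieve_branches loop depends on each parent being visited before its children. As an alternative (a swapped-in technique, not the approach used in this answer), a recursive depth-first walk from the root yields the same root-to-leaf paths without relying on node ordering. iter_branches is a hypothetical helper name, and the two arrays are hand-copied from this example rather than read from clf.tree_; with a fitted model you would pass clf.tree_.children_left and clf.tree_.children_right instead.

```python
def iter_branches(children_left, children_right, node=0, path=None):
    """Yield every root-to-leaf path as a list of node ids (depth-first)."""
    path = (path or []) + [node]  # new list, so sibling paths never share state
    left, right = children_left[node], children_right[node]
    if left == right:  # both are -1 at a leaf
        yield path
        return
    yield from iter_branches(children_left, children_right, left, path)
    yield from iter_branches(children_left, children_right, right, path)


# Hypothetical arrays hand-copied from the iris example tree
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
children_right = [2, -1, 4, 6, 10, -1, 8, -1, -1, -1, -1]

branches = list(iter_branches(children_left, children_right))
print(branches)
# [[0, 1], [0, 2, 3, 5], [0, 2, 3, 6, 7], [0, 2, 3, 6, 8], [0, 2, 4, 9], [0, 2, 4, 10]]
```

Because the left child is visited first, the paths come out in left-to-right order; the recursion depth equals the depth of the tree, which is small for typical decision trees.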

3. Path, value and Gini impurity by branch

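The decision rules can be obtained from the feature and threshold arrays of clf.tree_: at each split node on the path, going to the left child means X[:, feature] <= threshold, and going to the right child means X[:, feature] > threshold. The Gini impurity (clf.tree_.impurity) and the class values (clf.tree_.value) are read at the leaf node. Depending on the scikit-learn version, clf.tree_.value holds per-class sample counts (as in the output below) or, in newer releases, per-class fractions.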
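The Gini impurity reported at each node follows the usual definition in terms of the class proportions p_k among the K classes at that node. As a worked check against the output below, leaf 9 has value [0, 1, 3], i.e. 4 samples:

```latex
G = 1 - \sum_{k=1}^{K} p_k^2,
\qquad
G_{9} = 1 - \left(\tfrac{0}{4}\right)^2 - \left(\tfrac{1}{4}\right)^2 - \left(\tfrac{3}{4}\right)^2
      = 1 - \tfrac{1}{16} - \tfrac{9}{16} = \tfrac{6}{16} = 0.375
```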

for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    # One rule per split node on the path: the left child means "<=", the right child means ">"
    rules = [
        f"if X[:, {feature[parent]}] {'<=' if child == children_left[parent] else '>'} {threshold[parent]}"
        for parent, child in zip(branch[:-1], branch[1:])
    ]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini {impurity[leaf_index]} at leaf node {leaf_index}')
    print(f'Value: {value[leaf_index]}')
    print(f'Decision Rules: {rules}')
    print("---------------------------------------------------------------------------------------\n")
>>>
Branch: 0, Path: [0, 1]
Gini 0.0 at leaf node 1
Value: [[37.  0.  0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929']
---------------------------------------------------------------------------------------

Branch: 1, Path: [0, 2, 3, 5]
Gini 0.0 at leaf node 5
Value: [[ 0. 32.  0.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869']
---------------------------------------------------------------------------------------

Branch: 2, Path: [0, 2, 3, 6, 7]
Gini 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] > 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858']
---------------------------------------------------------------------------------------

Branch: 3, Path: [0, 2, 3, 6, 8]
Gini 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] > 1.6500000357627869', 'if X[:, 1] > 3.100000023841858']
---------------------------------------------------------------------------------------

Branch: 4, Path: [0, 2, 4, 9]
Gini 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] > 4.950000047683716', 'if X[:, 2] <= 5.049999952316284']
---------------------------------------------------------------------------------------

Branch: 5, Path: [0, 2, 4, 10]
Gini 0.0 at leaf node 10
Value: [[ 0.  0. 35.]]
Decision Rules: ['if X[:, 3] > 0.800000011920929', 'if X[:, 2] > 4.950000047683716', 'if X[:, 2] > 5.049999952316284']
---------------------------------------------------------------------------------------
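To get one line per branch in the format asked for in the question (conditions, class and Gini score), the same pieces can be combined as in the sketch below. It rests on stated assumptions: the arrays, leaf counts and branch list are hand-copied from the example output above (thresholds rounded), class_names assumes the iris label order of iris.target_names, and describe_branch and gini are hypothetical helper names. With a fitted model you would read children_left, feature, threshold and value from clf.tree_ and could take impurity[leaf] directly instead of recomputing it.

```python
# Hypothetical data hand-copied from the iris example tree (thresholds rounded)
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
feature = [3, -2, 2, 3, 2, -2, 1, -2, -2, -2, -2]
threshold = [0.8, -2.0, 4.95, 1.65, 5.05, -2.0, 3.1, -2.0, -2.0, -2.0, -2.0]
leaf_counts = {1: [37, 0, 0], 5: [0, 32, 0], 7: [0, 0, 3],
               8: [0, 1, 0], 9: [0, 1, 3], 10: [0, 0, 35]}
class_names = ["setosa", "versicolor", "virginica"]  # assumed iris label order

branches = [[0, 1], [0, 2, 3, 5], [0, 2, 3, 6, 7],
            [0, 2, 3, 6, 8], [0, 2, 4, 9], [0, 2, 4, 10]]


def gini(counts):
    """Gini impurity 1 - sum(p_k^2) from per-class counts (or fractions)."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def describe_branch(index, branch):
    """Format one root-to-leaf path as 'branch i: cond, ..., class=..., gini_score=...'."""
    conditions = []
    for parent, child in zip(branch[:-1], branch[1:]):
        # Moving to the left child means the split test was true (<=)
        op = "<=" if child == children_left[parent] else ">"
        conditions.append(f"x[{feature[parent]}] {op} {threshold[parent]:.3f}")
    counts = leaf_counts[branch[-1]]
    predicted = class_names[counts.index(max(counts))]  # majority class at the leaf
    return (f"branch {index}: " + ", ".join(conditions)
            + f", class={predicted}, gini_score={gini(counts):.3f}")


lines = [describe_branch(i, b) for i, b in enumerate(branches)]
for line in lines:
    print(line)
```

Picking the majority class with max/index and computing the Gini score this way work the same whether the leaf value holds counts or fractions, since both give the same proportions.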

User contributions licensed under: CC BY-SA