I have this decision tree, and I would like to extract every branch from it. The image shows only a portion of the tree; the original tree is much bigger and doesn’t fit well in a single image.
I’m not trying to print the rules of the tree like:

```
Rules used to predict sample 1400:
decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)
```
or like:

```
The binary tree structure has 7 nodes and has the following tree structure:
node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.
```
What I’m trying to achieve is something like:

```
branch 0: x[4] <= 96.429, x[3] <= 96.429, class=B, gini_score=0.5
branch 1: x[4] <= 96.429, x[3] > 96.429, class=B, gini_score=0.021
branch 2: x[4] > 96.429, x[5] <= 0.283, class=A, gini_score=0.092
branch 3: x[4] > 96.429, x[5] > 0.283, class=A, gini_score=0.01
```
Basically, I want every branch from the root to a leaf node (the full path), together with its class and Gini score. How can I achieve this?
Answer
Using the iris dataset example from the scikit-learn docs, we follow these steps.
1. Generate an example decision tree

Code taken from the docs:

```python
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)
```
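Before extracting branches it can help to look at the whole fitted tree as text. A minimal sketch, assuming scikit-learn 0.21 or later (where `sklearn.tree.export_text` is available); each leaf shows up as one `class: ...` line, so the number of such lines should match the number of branches we expect to extract.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

# Indented text rendering of the tree; leaves are printed as "|--- class: <label>"
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
print("leaves:", clf.get_n_leaves())
```

This only prints the tree; it does not return the branches as data, which is what the steps below do.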
2. Retrieve branch paths

First, we retrieve the following arrays from the fitted tree:

```python
n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
impurity = clf.tree_.impurity
value = clf.tree_.value
```
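These arrays are indexed by node id. The leaf test used in the next step relies on the fact that a leaf has no children, so `children_left` and `children_right` both hold `-1` there, while `feature` and `threshold` hold the placeholder `-2`. A small sketch that checks this on the same example tree:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

tree_ = clf.tree_
children_left = tree_.children_left
children_right = tree_.children_right

# A node is a leaf when both child pointers are equal (both are -1)
is_leaf = children_left == children_right
leaf_ids = np.flatnonzero(is_leaf)

# Leaves carry no split, so their feature index and threshold are placeholders
print("leaf ids:", leaf_ids)
print("feature at leaves:", tree_.feature[leaf_ids])
print("threshold at leaves:", tree_.threshold[leaf_ids])
```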
Inside `retrieve_branches` we first mark which nodes are leaves, then iterate over the nodes from the root down. Each split node extends every path that currently ends at it with its two children, and when we reach a leaf node we return its branch path with a `yield` statement. This works because scikit-learn always gives a child a larger node id than its parent, so by the time node `i` is visited, the path ending at `i` already exists.

```python
def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Retrieve decision tree branches"""

    # A node is a leaf when both of its children are -1
    is_leaves_list = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]
    # Store the branch paths
    paths = []

    for i in range(number_nodes):
        if is_leaves_list[i]:
            # Search the leaf node among the last nodes of the current paths
            end_node = [path[-1] for path in paths]

            # If a path ends at this leaf, remove it from the list and yield it
            if i in end_node:
                output = paths.pop(np.argwhere(i == np.array(end_node))[0][0])
                yield output

        else:
            # Origin and end nodes
            origin, end_l, end_r = i, children_left_list[i], children_right_list[i]

            # Extend every path that ends at this node with its left and right child
            for index, path in enumerate(paths):
                if origin == path[-1]:
                    paths[index] = path + [end_l]
                    paths.append(path + [end_r])

            # Initialize the paths in the first iteration (root node)
            if i == 0:
                paths.append([i, children_left_list[i]])
                paths.append([i, children_right_list[i]])
```
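The same root-to-leaf paths can also be collected with a plain recursive depth-first walk, which follows the child pointers directly instead of relying on node-id order. A minimal sketch; `iter_branches` is a hypothetical helper name, not part of scikit-learn:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def iter_branches(children_left, children_right, node=0, path=None):
    """Yield every root-to-leaf path as a list of node ids, left subtree first."""
    path = [int(node)] if path is None else path + [int(node)]
    left, right = children_left[node], children_right[node]
    if left == right:  # leaf: both children are -1
        yield path
        return
    yield from iter_branches(children_left, children_right, left, path)
    yield from iter_branches(children_left, children_right, right, path)


iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

branches = list(iter_branches(clf.tree_.children_left, clf.tree_.children_right))
print(branches)
```

Recursion depth equals tree depth, which is small for trees like this one; for very deep trees an explicit stack would avoid Python's recursion limit.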
To call `retrieve_branches`, just pass `n_nodes`, `children_left` and `children_right`; the list that stores and updates the branch paths is created inside the function. The result is shown below:

```python
all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches
>>> [[0, 1], [0, 2, 3, 5], [0, 2, 3, 6, 7], [0, 2, 3, 6, 8], [0, 2, 4, 9], [0, 2, 4, 10]]
```
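With the branches in hand, you may also want to know which branch a particular sample follows. A sketch using scikit-learn's `decision_path` (a sparse indicator of visited nodes per sample) and `apply` (the leaf id per sample). It assumes, as above, that node ids increase from parent to child, so sorting the visited ids gives the root-to-leaf order; `branch_of` is a hypothetical helper:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

# CSR indicator matrix: row i has a nonzero for every node sample i passes through
node_indicator = clf.decision_path(X_test)
leaf_ids = clf.apply(X_test)  # leaf reached by each test sample


def branch_of(sample_idx):
    """Root-to-leaf node ids for one sample (child ids are larger than parent ids)."""
    start, stop = node_indicator.indptr[sample_idx], node_indicator.indptr[sample_idx + 1]
    return sorted(int(n) for n in node_indicator.indices[start:stop])


print(branch_of(0), "ends at leaf", leaf_ids[0])
```

The returned list can be matched against `all_branches` to get the branch index for that sample.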
3. Path, value and Gini by branch

The rules can be obtained from the `feature` and `threshold` arrays of `clf.tree_`, while the Gini impurity (`clf.tree_.impurity`) and the class values (`clf.tree_.value`) are read at the leaf node.
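The leaf's class is not stored directly; it is the class with the largest entry in `value[leaf]`, which has shape `(n_outputs, n_classes)`. A small sketch (with a hypothetical `leaf_summary` helper) that reads class and impurity per leaf; taking the argmax works whether `value` holds raw counts or per-class fractions, since both give the same ordering:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)


def leaf_summary(clf):
    """Map each leaf id to (majority class label, impurity) read from clf.tree_."""
    tree_ = clf.tree_
    summary = {}
    for leaf in np.flatnonzero(tree_.children_left == tree_.children_right):
        # value[leaf][0] is the per-class vector of the single output
        class_idx = int(np.argmax(tree_.value[leaf][0]))
        summary[int(leaf)] = (clf.classes_[class_idx], float(tree_.impurity[leaf]))
    return summary


for leaf, (cls, gini) in leaf_summary(clf).items():
    print(f"leaf {leaf}: class={cls}, gini={gini:.3f}")
```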
```python
for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini {impurity[leaf_index]} at leaf node {branch[-1]}')
    print(f'Value: {value[leaf_index]}')
    print(f"Decision Rules: {[f'if X[:, {feature[elem]}] <= {threshold[elem]}' for elem in branch]}")
    print(f"---------------------------------------------------------------------------------------\n")
```

```
>>>
Branch: 0, Path: [0, 1]
Gini 0.0 at leaf node 1
Value: [[37. 0. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 1, Path: [0, 2, 3, 5]
Gini 0.0 at leaf node 5
Value: [[ 0. 32. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 2, Path: [0, 2, 3, 6, 7]
Gini 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 3, Path: [0, 2, 3, 6, 8]
Gini 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 4, Path: [0, 2, 4, 9]
Gini 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------

Branch: 5, Path: [0, 2, 4, 10]
Gini 0.0 at leaf node 10
Value: [[ 0. 0. 35.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
```

Note that this loop writes `<=` for every node on the path, even where the branch actually goes to the right child (where the condition is `>`), and that the last rule of each branch comes from the leaf itself, whose feature and threshold are both the placeholder `-2`.
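To get output in the format asked for in the question, the direction of each step has to be taken into account: scikit-learn sends samples with `X[:, feature] <= threshold` to the left child and the rest to the right. A minimal sketch under these assumptions: `describe_branches` is a hypothetical helper, the tree uses the default `criterion="gini"` (otherwise `impurity` holds entropy or log-loss rather than Gini), and class labels are taken from `clf.classes_` (0, 1, 2 for iris; map them through `iris.target_names` if you want names).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def describe_branches(clf, precision=3):
    """Return one 'branch i: x[f] <= t, ..., class=c, gini_score=g' line per leaf."""
    tree_ = clf.tree_
    left, right = tree_.children_left, tree_.children_right
    lines = []

    def walk(node, conditions):
        if left[node] == right[node]:  # leaf: emit the accumulated path
            cls = clf.classes_[int(np.argmax(tree_.value[node][0]))]
            gini = round(float(tree_.impurity[node]), precision)
            lines.append(", ".join(conditions + [f"class={cls}", f"gini_score={gini}"]))
            return
        f = int(tree_.feature[node])
        t = round(float(tree_.threshold[node]), precision)
        walk(left[node], conditions + [f"x[{f}] <= {t}"])  # left child: x <= threshold
        walk(right[node], conditions + [f"x[{f}] > {t}"])  # right child: x > threshold

    walk(0, [])
    return [f"branch {i}: {line}" for i, line in enumerate(lines)]


iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0).fit(X_train, y_train)

for line in describe_branches(clf):
    print(line)
```

The leaf's own placeholder split is never printed, because the walk stops as soon as it reaches a leaf.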