I have been using this tutorial to learn decision tree learning, and am now trying to understand how it works with higher dimensional datasets.
Currently my regressor predicts a Z value for an (x,y) pair that you pass to it.
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from mpl_toolkits import mplot3d

# Each row: label, x feature, y feature, target (Z) value
dataset = np.array(
    [['Asset Flip', 100, 100, 1000],
     ['Text Based', 500, 300, 3000],
     ['Visual Novel', 1500, 500, 5000],
     ['2D Pixel Art', 3500, 300, 8000],
     ['2D Vector Art', 5000, 900, 6500],
     ['Strategy', 6000, 600, 7000],
     ['First Person Shooter', 8000, 500, 15000],
     ['Simulator', 9500, 400, 20000],
     ['Racing', 12000, 300, 21000],
     ['RPG', 14000, 150, 25000],
     ['Sandbox', 15500, 200, 27000],
     ['Open-World', 16500, 500, 30000],
     ['MMOFPS', 25000, 600, 52000],
     ['MMORPG', 30000, 700, 80000]
    ])

# The mixed array is stored as strings, so convert the numeric columns back to int
X = dataset[:, 1:3].astype(int)
y = dataset[:, 3].astype(int)

regressor = DecisionTreeRegressor(random_state=0)
regressor.fit(X, y)
```
I want to use a 3D graph to visualise it, but I have struggled with the difference between the way regressor.predict() expects its inputs (one row per (x, y) pair) and the way functions like matplotlib's wireframe plots expect theirs (2D grids of X, Y and Z values). As a result, I have not been able to make them work together.
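The mismatch is easiest to see by printing array shapes. DecisionTreeRegressor.predict() takes a 2D array of shape (n_samples, n_features), whereas np.meshgrid returns two grids of shape (ny, nx), and plot_wireframe wants a Z grid of that same shape. Here is a minimal sketch that reuses the numeric columns from the dataset above. The grid size of 10 and the names grid_points and zz are illustrative choices, not anything required by either library:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Numeric columns from the dataset above: two features (x, y) and the target
X = np.array([[100, 100], [500, 300], [1500, 500], [3500, 300],
              [5000, 900], [6000, 600], [8000, 500], [9500, 400],
              [12000, 300], [14000, 150], [15500, 200], [16500, 500],
              [25000, 600], [30000, 700]])
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

regressor = DecisionTreeRegressor(random_state=0)
regressor.fit(X, y)

# meshgrid returns two 2D grids, each of shape (10, 10)
xx, yy = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), 10),
                     np.linspace(X[:, 1].min(), X[:, 1].max(), 10))
print(xx.shape, yy.shape)  # (10, 10) (10, 10)

# predict() needs one row per (x, y) point, so flatten both grids into two columns
grid_points = np.column_stack((xx.ravel(), yy.ravel()))
print(grid_points.shape)  # (100, 2)

# predict() returns a flat vector; reshape it back to the grid shape for plotting
zz = regressor.predict(grid_points).reshape(xx.shape)
print(zz.shape)  # (10, 10)
```

The ravel/reshape pair is what bridges the two conventions: both use the same (row-major) ordering, so each prediction lands back at the grid position of the point it was computed for.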
Answer
Try this. I do not have all the packages installed, so I tested it on Google Colab. Let me know if this is what you expected.
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

# Assumes X and regressor from the question's code above
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# to just see the prediction results of your data
#ax.scatter(X[:, 0], X[:, 1], regressor.predict(X), c='g')

samples = 10
xx, yy = np.meshgrid(np.linspace(min(X[:, 0]), max(X[:, 0]), samples),
                     np.linspace(min(X[:, 1]), max(X[:, 1]), samples))

# to see the decision boundaries (not the right word for a decision tree regressor, I think)
# predict() needs an (n_points, 2) array, so stack the flattened grids as two columns,
# then reshape the flat predictions back to the grid shape for the wireframe
ax.plot_wireframe(xx, yy, regressor.predict(np.hstack((xx.reshape(-1, 1), yy.reshape(-1, 1)))).reshape(xx.shape))

ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
ax.set_zlabel('z-axis (predictions)')
plt.show()
```
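As a variation, a finer grid drawn with plot_surface makes the step-like, piecewise-constant shape of a decision tree's predictions easier to see, and scattering the training points on top shows that a fully grown tree (the default settings) passes through each of them. This is a sketch, not part of the answer above: the 50-sample grid, the viridis colormap, the non-interactive Agg backend, and the output filename tree_surface.png are all illustrative assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the script also runs headless
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older matplotlib)
from sklearn.tree import DecisionTreeRegressor

# Numeric columns from the question's dataset: features (x, y) and target
X = np.array([[100, 100], [500, 300], [1500, 500], [3500, 300],
              [5000, 900], [6000, 600], [8000, 500], [9500, 400],
              [12000, 300], [14000, 150], [15500, 200], [16500, 500],
              [25000, 600], [30000, 700]])
y = np.array([1000, 3000, 5000, 8000, 6500, 7000, 15000, 20000,
              21000, 25000, 27000, 30000, 52000, 80000])

regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# Finer grid than the answer's 10 samples, so the flat "steps" are visible
samples = 50
xx, yy = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), samples),
                     np.linspace(X[:, 1].min(), X[:, 1].max(), samples))
zz = regressor.predict(np.column_stack((xx.ravel(), yy.ravel()))).reshape(xx.shape)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xx, yy, zz, cmap='viridis', alpha=0.6)
# Training points in red; with default settings the tree reproduces them exactly
ax.scatter(X[:, 0], X[:, 1], y, c='r')
ax.set_xlabel('x-axis')
ax.set_ylabel('y-axis')
ax.set_zlabel('z-axis (predictions)')
fig.savefig('tree_surface.png')  # hypothetical output filename
```

In an interactive session such as Colab you can drop the matplotlib.use("Agg") line and call plt.show() instead of savefig.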