I’m trying to implement a decision tree from scratch, but when I tried testing it on the bluebook dataset used in lesson 7, it’s way too slow.
Here is the code I have written so far. This is the Node class.
class Node():
# the node of the decision tree
# goes left if less than or equal to threshold and right if greater than threshold
# column is the
def __init__(self, average, column=None, threshold=None, error=None, parent=None, left=None, right=None, is_leaf=False):
self.average = average
self.column = column
self.threshold = threshold
self.error = error
self.parent = parent
self.left = left
self.right = right
self.is_leaf = is_leaf
This is the DecisionTree class which creates the tree.
class DecisionTree():
def __init__(self, x, y, min_samples_leaf=1):
#x is the dataframe with independent variables
self.x = x
#y is the dependent variable (price)
self.y = y
#min_samples_leaf is the minimum number of items in the leaf
self.min_samples_leaf = min_samples_leaf
self.create_tree()
def predict(self, df):
current_node = self.root
while not current_node.is_leaf:
if df[current_node.column] <= current_node.threshold:
current_node = current_node.left
else:
current_node = current_node.right
return current_node.average, current_node.error
def create_tree(self):
initial_avg = self.y.mean()
initial_error = math.sqrt(((self.y - initial_avg)**2).sum()/len(self.y))
initial_idx = np.arange(len(self.x))
self.root = Node(average=initial_avg, error=initial_error)
self.split(initial_idx, self.root)
def split(self, idx, parent_node):
#idx are the indexes of the rows that are a part of the node
columns = self.x.columns
parent_is_leaf = True
best_error = parent_node.error
best_column = columns[0]
best_threshold = 0
best_left_error = sys.float_info.max
best_right_error = sys.float_info.max
best_left_idx = array([])
best_right_idx = array([])
#loop through columns
for col in columns:
for value in self.x[col].unique():
left_idx = np.where(self.x[col].iloc[idx] <= value)
right_idx = np.where(self.x[col].iloc[idx] > value)
if len(left_idx) < self.min_samples_leaf or len(right_idx) < self.min_samples_leaf:
continue
left_error = self.find_error(left_idx)
right_error = self.find_error(right_idx)
if left_error < best_error and right_error < best_error:
best_error = (left_error + right_error)/2
best_column = col
best_threshold = value
best_left_error = left_error
best_right_error = right_error
best_left_idx = left_idx
best_right_idx = right_idx
parent_is_leaf = False
if parent_is_leaf:
parent_node.is_leaf = True
else:
parent_node.column = best_column
parent_node.threshold = best_threshold
left_node = Node(average=find_average(best_left_idx), error=best_left_error, parent=parent_node)
right_node = Node(average=find_average(best_right_idx), error=best_right_error, parent=parent_node)
parent_node.left = left_node
parent_node.right = right_node
self.split(best_left_idx, left_node)
self.split(best_right_idx, right_node)
def find_error(self, idx):
#rmse
avg = self.y.iloc[idx].mean()
return math.sqrt(((self.y.iloc[idx] - avg)**2).sum()/len(idx))
def find_average(self, idx):
#average
return self.y.iloc[idx].mean()
The problem is in the split function. Specifically the lines
left_idx = np.where(self.x[col].iloc[idx] <= value)
right_idx = np.where(self.x[col].iloc[idx] > value)
It takes such a long time that when running the code, the recursion doesn’t even happen. Is there a way to speed this up or am I on the wrong path all together?