# Help Implementing Decision Tree from Scratch

I’m trying to implement a decision tree from scratch, but when I test it on the bluebook dataset used in lesson 7, it’s far too slow.

Here is the code I have written so far. This is the Node class.

```python
class Node():
    # the node of the decision tree
    # goes left if less than or equal to threshold and right if greater than threshold
    # column is the name of the column (feature) this node splits on
    def __init__(self, average, column=None, threshold=None, error=None, parent=None, left=None, right=None, is_leaf=False):
        self.average = average
        self.column = column
        self.threshold = threshold
        self.error = error
        self.parent = parent
        self.left = left
        self.right = right
        self.is_leaf = is_leaf
```

This is the DecisionTree class which creates the tree.

```python
import math
import sys

import numpy as np


class DecisionTree():

    def __init__(self, x, y, min_samples_leaf=1):
        # x is the dataframe with independent variables
        self.x = x
        # y is the dependent variable (price)
        self.y = y
        # min_samples_leaf is the minimum number of items in the leaf
        self.min_samples_leaf = min_samples_leaf
        self.create_tree()

    def predict(self, df):
        # df is a single row (e.g. a pandas Series); walk down until a leaf
        current_node = self.root
        while not current_node.is_leaf:
            if df[current_node.column] <= current_node.threshold:
                current_node = current_node.left
            else:
                current_node = current_node.right
        return current_node.average, current_node.error

    def create_tree(self):
        initial_avg = self.y.mean()
        initial_error = math.sqrt(((self.y - initial_avg)**2).sum()/len(self.y))
        initial_idx = np.arange(len(self.x))
        self.root = Node(average=initial_avg, error=initial_error)
        self.split(initial_idx, self.root)

    def split(self, idx, parent_node):
        # idx are the positional indexes of the rows that are a part of the node
        columns = self.x.columns
        parent_is_leaf = True
        best_error = parent_node.error
        best_column = columns[0]
        best_threshold = 0
        best_left_error = sys.float_info.max
        best_right_error = sys.float_info.max
        best_left_idx = np.array([], dtype=int)
        best_right_idx = np.array([], dtype=int)
        # loop through columns, trying only the values present in this node
        for col in columns:
            for value in self.x[col].iloc[idx].unique():
                # np.where returns a tuple, so take [0]; those are positions
                # within the node, so map them back to rows of self.x via idx
                left_idx = idx[np.where(self.x[col].iloc[idx] <= value)[0]]
                right_idx = idx[np.where(self.x[col].iloc[idx] > value)[0]]
                if len(left_idx) < self.min_samples_leaf or len(right_idx) < self.min_samples_leaf:
                    continue
                left_error = self.find_error(left_idx)
                right_error = self.find_error(right_idx)
                if left_error < best_error and right_error < best_error:
                    best_error = (left_error + right_error)/2
                    best_column = col
                    best_threshold = value
                    best_left_error = left_error
                    best_right_error = right_error
                    best_left_idx = left_idx
                    best_right_idx = right_idx
                    parent_is_leaf = False

        if parent_is_leaf:
            parent_node.is_leaf = True
        else:
            parent_node.column = best_column
            parent_node.threshold = best_threshold
            left_node = Node(average=self.find_average(best_left_idx), error=best_left_error, parent=parent_node)
            right_node = Node(average=self.find_average(best_right_idx), error=best_right_error, parent=parent_node)
            parent_node.left = left_node
            parent_node.right = right_node
            self.split(best_left_idx, left_node)
            self.split(best_right_idx, right_node)

    def find_error(self, idx):
        # rmse around the mean of the rows in idx
        avg = self.y.iloc[idx].mean()
        return math.sqrt(((self.y.iloc[idx] - avg)**2).sum()/len(idx))

    def find_average(self, idx):
        # average
        return self.y.iloc[idx].mean()
```

The problem is in the `split` function, specifically these lines:

```python
left_idx = idx[np.where(self.x[col].iloc[idx] <= value)[0]]
right_idx = idx[np.where(self.x[col].iloc[idx] > value)[0]]
```

They take so long that the code never even gets to the recursive calls. Is there a way to speed this up, or am I on the wrong path altogether?

Try to find the operations that you repeat but don’t need to. For example, once you have `left_idx`, `right_idx` is just every other index in the node, so you don’t need to compute it separately. Next, if you sort the values before the loop, you can compute the indexes much more cheaply: with an ascending sort, once a row falls on the left branch of the tree you never need to check it again in that loop, because all the remaining thresholds are larger. Something similar applies to calculating the error; Jeremy goes into some detail on this in one of the lessons.
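Here is a minimal sketch of the sorting idea for a single column, assuming `x_col` is the column as a NumPy array and `idx` holds the node's row positions (`candidate_splits` is a hypothetical helper, not part of the original code). After one sort, each candidate split is just two slices, and the right side is the complement of the left.

```python
import numpy as np


def candidate_splits(x_col, idx):
    """Yield (threshold, left_idx, right_idx) for every split of one column.

    Hypothetical helper (not from the original post): x_col is a 1-D NumPy
    array holding one column of x, idx holds the row positions in the node.
    """
    # Sort the node's rows by this column once, instead of re-comparing
    # every row against every candidate threshold.
    order = idx[np.argsort(x_col[idx], kind="stable")]
    sorted_vals = x_col[order]
    for i in range(len(order) - 1):
        # Equal values must end up on the same side, so only split where
        # the value changes.
        if sorted_vals[i] == sorted_vals[i + 1]:
            continue
        # Rows 0..i are <= the threshold and the rest are greater, so the
        # right side is just the complement (a slice, not a new search).
        yield sorted_vals[i], order[: i + 1], order[i + 1 :]
```

Inside `split`, the inner loop could then become `for value, left_idx, right_idx in candidate_splits(self.x[col].values, idx):`, keeping the `min_samples_leaf` check and the error calculation as they are. It produces the same left/right groups as the `<= value` test, but sorting costs O(n log n) once per column per node instead of a full scan for every candidate value.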


I found a way to get the indexes quickly. Now it’s `find_error()` that is slowing down the program. What is the faster way of computing the error that you mentioned?

I found the link, check this: https://youtu.be/O5F9vR2CNYI?t=2732
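The underlying idea is to avoid recomputing the error from scratch for every candidate threshold. `find_error` returns sqrt(sum((y - mean)**2) / n), which equals sqrt(mean(y**2) - mean(y)**2), so it can be computed from a running count, sum, and sum of squares. Once a column is sorted, moving one row from the right side to the left only updates those three numbers. Below is a minimal sketch of that idea, assuming the node's rows are already sorted by the column; `rmse_from_sums` and `best_split_sorted` are hypothetical names, and the sketch keeps the post's unweighted average of the two errors but drops the comparison against the parent's error.

```python
import math

import numpy as np


def rmse_from_sums(count, total, total_sq):
    """Same value as find_error: sqrt(mean(y**2) - mean(y)**2)."""
    # max(..., 0.0) guards against tiny negative values from rounding
    return math.sqrt(max(total_sq / count - (total / count) ** 2, 0.0))


def best_split_sorted(x_sorted, y_sorted, min_samples_leaf=1):
    """Scan one column whose values are already sorted ascending.

    Hypothetical helper: y_sorted holds the targets in the same order as
    x_sorted. Returns (threshold, left_error, right_error) for the split with
    the lowest average of the two errors, or None if no split is allowed.
    """
    n = len(y_sorted)
    # Start with every row on the right and nothing on the left
    left_cnt, left_sum, left_sq = 0, 0.0, 0.0
    right_cnt = n
    right_sum = float(y_sorted.sum())
    right_sq = float((y_sorted ** 2).sum())
    best = None
    for i in range(n - 1):
        yi = float(y_sorted[i])
        # Move row i from the right side to the left side in O(1)
        left_cnt += 1
        left_sum += yi
        left_sq += yi * yi
        right_cnt -= 1
        right_sum -= yi
        right_sq -= yi * yi
        # Skip ties (equal values must stay together) and too-small leaves
        if x_sorted[i] == x_sorted[i + 1]:
            continue
        if left_cnt < min_samples_leaf or right_cnt < min_samples_leaf:
            continue
        left_err = rmse_from_sums(left_cnt, left_sum, left_sq)
        right_err = rmse_from_sums(right_cnt, right_sum, right_sq)
        score = (left_err + right_err) / 2
        if best is None or score < best[0]:
            best = (score, x_sorted[i], left_err, right_err)
    return None if best is None else best[1:]
```

For one column of a node, sort once with `order = idx[np.argsort(self.x[col].values[idx])]`, then pass `self.x[col].values[order]` and `self.y.values[order]`. Each candidate threshold then costs O(1) instead of O(n). One caveat: the sum-of-squares formula can lose precision when the target values are large relative to their spread; the `max(..., 0.0)` guard only prevents a math domain error from tiny negative results, it does not restore that precision.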