Help Implementing Decision Tree from Scratch

I’m trying to implement a decision tree from scratch, but when I tested it on the bluebook dataset used in lesson 7, it was far too slow.

Here is the code I have written so far. This is the Node class.

import math
import sys

import numpy as np

class Node():
    # A node of the decision tree.
    # Rows go left if their value is <= threshold and right if it is > threshold.
    # column is the name of the feature this node splits on.
    def __init__(self, average, column=None, threshold=None, error=None, parent=None, left=None, right=None, is_leaf=False):
        self.average = average
        self.column = column
        self.threshold = threshold
        self.error = error
        self.parent = parent
        self.left = left
        self.right = right
        self.is_leaf = is_leaf

This is the DecisionTree class which creates the tree.

class DecisionTree():

    def __init__(self, x, y, min_samples_leaf=1):
        #x is the dataframe with independent variables
        self.x = x
        #y is the dependent variable (price)
        self.y = y
        #min_samples_leaf is the minimum number of items in the leaf
        self.min_samples_leaf = min_samples_leaf

    def predict(self, df):
        current_node = self.root
        while not current_node.is_leaf:
            if df[current_node.column] <= current_node.threshold:
                current_node = current_node.left
            else:
                current_node = current_node.right
        return current_node.average, current_node.error

    def create_tree(self):
        initial_avg = self.y.mean()
        initial_error = math.sqrt(((self.y - initial_avg)**2).sum()/len(self.y))
        initial_idx = np.arange(len(self.x))
        self.root = Node(average=initial_avg, error=initial_error)
        self.split(initial_idx, self.root)

    def split(self, idx, parent_node):
        #idx are the indexes of the rows that are a part of the node
        columns = self.x.columns
        parent_is_leaf = True
        best_error = parent_node.error
        best_column = columns[0]
        best_threshold = 0
        best_left_error = sys.float_info.max
        best_right_error = sys.float_info.max
        best_left_idx = np.array([])
        best_right_idx = np.array([])
        #loop through columns
        for col in columns:
            for value in self.x[col].unique():
                left_idx = np.where(self.x[col].iloc[idx] <= value)
                right_idx = np.where(self.x[col].iloc[idx] > value)
                if len(left_idx) < self.min_samples_leaf or len(right_idx) < self.min_samples_leaf:
                    # skip splits that would leave too few rows on one side
                    continue
                left_error = self.find_error(left_idx)
                right_error = self.find_error(right_idx)
                if left_error < best_error and right_error < best_error:
                    best_error = (left_error + right_error)/2
                    best_column = col
                    best_threshold = value
                    best_left_error = left_error
                    best_right_error = right_error
                    best_left_idx = left_idx
                    best_right_idx = right_idx
                    parent_is_leaf = False

        if parent_is_leaf:
            # no split improved on the parent's error, so stop here
            parent_node.is_leaf = True
        else:
            parent_node.column = best_column
            parent_node.threshold = best_threshold
            left_node = Node(average=self.find_average(best_left_idx), error=best_left_error, parent=parent_node)
            right_node = Node(average=self.find_average(best_right_idx), error=best_right_error, parent=parent_node)
            parent_node.left = left_node
            parent_node.right = right_node
            self.split(best_left_idx, left_node)
            self.split(best_right_idx, right_node)

    def find_error(self, idx):
        avg = self.y.iloc[idx].mean()
        return math.sqrt(((self.y.iloc[idx] - avg)**2).sum()/len(idx))

    def find_average(self, idx):
        return self.y.iloc[idx].mean()

The problem is in the split function. Specifically the lines

left_idx = np.where(self.x[col].iloc[idx] <= value)
right_idx = np.where(self.x[col].iloc[idx] > value)

These lines take so long that when I run the code, the recursion never even starts. Is there a way to speed this up, or am I on the wrong path altogether?

Look for operations that you repeat but don’t need to. For example, once you have left_idx, right_idx is just all the other indexes, so you don’t need to compute it separately. Also, if you sort the values before the loop, you can compute the indexes more cheaply: with an ascending sort, once a row belongs to the left branch for one threshold, it stays on the left for every larger threshold, so you never need to check it again in that loop. Something similar applies to the error calculation; Jeremy goes into some detail on this in one of the lessons.
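
Here is a rough sketch of both ideas together, not a drop-in replacement for your split(). The helper name best_split_for_column is made up for illustration, it expects plain NumPy arrays rather than pandas objects, and it scores a split by the size-weighted standard deviation of the two sides, which is an assumption and not the criterion in your code. The column is sorted once, rows move from the right side to the left one at a time, and running sums give each side's standard deviation without recomputing it from scratch.

```python
import numpy as np

def best_split_for_column(x_col, y, min_samples_leaf=1):
    """Hypothetical helper: best threshold on one column, found in one sweep."""
    order = np.argsort(x_col, kind="stable")   # sort once per column
    xs, ys = x_col[order], y[order]
    n = len(ys)

    # Running totals: the left side starts empty, the right side holds every row.
    left_n, left_sum, left_sq = 0, 0.0, 0.0
    right_n, right_sum, right_sq = n, float(ys.sum()), float((ys ** 2).sum())

    best_score, best_threshold, best_pos = np.inf, None, None
    for i in range(n - 1):
        # Move row i from the right side to the left side in constant time.
        left_n += 1
        right_n -= 1
        left_sum += ys[i]
        right_sum -= ys[i]
        left_sq += ys[i] ** 2
        right_sq -= ys[i] ** 2

        # A threshold cannot separate two equal values, so skip those positions.
        if xs[i] == xs[i + 1]:
            continue
        if left_n < min_samples_leaf or right_n < min_samples_leaf:
            continue

        # std = sqrt(mean of squares - square of mean); clamp tiny negatives
        # caused by floating-point rounding.
        left_std = np.sqrt(max(left_sq / left_n - (left_sum / left_n) ** 2, 0.0))
        right_std = np.sqrt(max(right_sq / right_n - (right_sum / right_n) ** 2, 0.0))

        # Assumed criterion: standard deviations weighted by side size.
        score = (left_std * left_n + right_std * right_n) / n
        if score < best_score:
            best_score, best_threshold, best_pos = score, xs[i], i + 1

    if best_pos is None:
        return None  # no valid split on this column
    # The right side is simply the rest of the sorted order.
    return best_threshold, best_score, order[:best_pos], order[best_pos:]
```

Inside your tree you could call it with something like self.x[col].values[idx] and self.y.values[idx]. The returned positions are relative to those slices, so index idx with them to get back to row positions in the full dataframe.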

I found a way to quickly get the indexes. Now, it’s find_error() that is slowing down the program. What is the faster way you mentioned?

I found the link, check this: