In lesson 7, we’re introduced to the non-naive version of the find_better_split() function.
def std_agg(cnt, s1, s2): return math.sqrt((s2/cnt) - (s1/cnt)**2)

def find_better_split_foo(self, var_idx):
    x,y = self.x.values[self.idxs,var_idx], self.y[self.idxs]
    sort_idx = np.argsort(x)
    sort_y,sort_x = y[sort_idx], x[sort_idx]
    rhs_cnt,rhs_sum,rhs_sum2 = self.n, sort_y.sum(), (sort_y**2).sum()
    lhs_cnt,lhs_sum,lhs_sum2 = 0,0.,0.
    for i in range(0,self.n-self.min_leaf-1):
        xi,yi = sort_x[i],sort_y[i]
        lhs_cnt += 1; rhs_cnt -= 1
        lhs_sum += yi; rhs_sum -= yi
        lhs_sum2 += yi**2; rhs_sum2 -= yi**2
        if i<self.min_leaf or xi==sort_x[i+1]:
            continue
        lhs_std = std_agg(lhs_cnt, lhs_sum, lhs_sum2)
        rhs_std = std_agg(rhs_cnt, rhs_sum, rhs_sum2)
        curr_score = lhs_std*lhs_cnt + rhs_std*rhs_cnt
        if curr_score<self.score:
            self.var_idx,self.score,self.split = var_idx,curr_score,xi
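As a side check on std_agg: it appears to rely on the identity std = sqrt(E[y^2] - E[y]^2), which is why the loop only needs a running count, sum, and sum of squares. A minimal sketch comparing it against np.std, using made-up sample values:

```python
import math
import numpy as np

def std_agg(cnt, s1, s2):
    # Population standard deviation from count, sum, and sum of squares.
    return math.sqrt((s2 / cnt) - (s1 / cnt) ** 2)

y = np.array([1.0, 2.0, 4.0, 7.0])  # made-up values, not from the lesson
from_sums = std_agg(len(y), y.sum(), (y ** 2).sum())
direct = float(np.std(y))  # np.std defaults to the population std (ddof=0)
print(from_sums, direct)   # the two values should agree up to rounding
```

Because the sums can be updated one row at a time, each candidate split costs constant work instead of a fresh pass over the rows.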
As I understand it, if we had self.n == 100 and min_leaf = 1, the for loop would iterate through all the elements of the selected rows. However, if min_leaf = 2, we would iterate through the first 99 items but also skip the first element (due to the if statement: if i<self.min_leaf or xi==sort_x[i+1]: continue). To take it further, if we had min_leaf = 10, we would iterate through the first 90 elements but also skip the first 10 items.
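One way to check these counts is to enumerate which loop indices actually get past the continue check. The sketch below does that for n = 100; the helper name scored_positions is hypothetical, and it ignores the xi==sort_x[i+1] tie check, so it shows the upper bound on scored positions.

```python
def scored_positions(n, min_leaf):
    """Indices i that reach the scoring step, ignoring ties in x (assumption)."""
    return [i for i in range(0, n - min_leaf - 1) if i >= min_leaf]

for min_leaf in (1, 2, 10):
    pos = scored_positions(100, min_leaf)
    # Print: min_leaf, first scored index, last scored index, how many are scored
    print(min_leaf, pos[0], pos[-1], len(pos))
# 1 1 97 97
# 2 2 96 95
# 10 10 88 79
```

Note that the loop still visits the early rows in order to update the running sums; only the scoring is skipped for them. A split after position i puts i+1 rows on the left and n-i-1 on the right, so if I am counting correctly each side ends up with at least min_leaf+1 rows.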
I could easily have misread this, so please correct me if so. Right now it seems that the min_leaf parameter forces the split to fall in the middle of the sorted rows, which leaves open the possibility that the “true” best split point is skipped. Is this the correct way to read this code? Thank you.
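To make the concern concrete, here is a standalone sketch of the same loop run on made-up data where the cleanest split isolates a single outlier at the edge. The function name best_split, the toy arrays, and the initial score of infinity are my assumptions, not part of the lesson code.

```python
import math
import numpy as np

def std_agg(cnt, s1, s2):
    return math.sqrt((s2 / cnt) - (s1 / cnt) ** 2)

def best_split(x, y, min_leaf):
    """Standalone copy of the loop above; returns (best score, split value)."""
    n = len(y)
    sort_idx = np.argsort(x)
    sort_x, sort_y = x[sort_idx], y[sort_idx]
    rhs_cnt, rhs_sum, rhs_sum2 = n, sort_y.sum(), (sort_y ** 2).sum()
    lhs_cnt, lhs_sum, lhs_sum2 = 0, 0.0, 0.0
    best_score, best_x = float("inf"), None  # assumed starting score
    for i in range(0, n - min_leaf - 1):
        xi, yi = sort_x[i], sort_y[i]
        # Move row i from the right-hand side to the left-hand side.
        lhs_cnt += 1; rhs_cnt -= 1
        lhs_sum += yi; rhs_sum -= yi
        lhs_sum2 += yi ** 2; rhs_sum2 -= yi ** 2
        if i < min_leaf or xi == sort_x[i + 1]:
            continue
        score = (std_agg(lhs_cnt, lhs_sum, lhs_sum2) * lhs_cnt
                 + std_agg(rhs_cnt, rhs_sum, rhs_sum2) * rhs_cnt)
        if score < best_score:
            best_score, best_x = score, xi
    return best_score, best_x

x = np.arange(10.0)                   # toy feature: 0, 1, ..., 9
y = np.array([10.0] + [0.0] * 9)      # one outlier at the smallest x
print(best_split(x, y, min_leaf=0))   # free to isolate the outlier
print(best_split(x, y, min_leaf=3))   # edge positions are never scored
```

On this toy data, min_leaf = 0 picks the split at x = 0 with a score of 0, while min_leaf = 3 cannot score that position and settles on x = 3 instead. That would be consistent with the reading above, at least for edge cases like this one.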