Find misclassified rows in a TabularPandas

Hello! I built this easy-to-use function to help find misclassified rows in the validation set of a TabularPandas. Hopefully it will be of help to some of you guys!

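```python
import numpy as np
from fastai.tabular.all import *  # provides to_np

# Assumes these already exist in your session:
#   to     - your TabularPandas object
#   y_pred - predicted y classes for the validation set
#   targs  - true y labels for the validation set
# e.g. preds, targs = learn.get_preds(); y_pred = preds.argmax(dim=1)
def find_misclassified(pred_value, true_value):
    "Return validation rows predicted as `pred_value` whose true label is `true_value`."
    a = to_np(y_pred)
    b = to_np(targs.squeeze())
    idxs = np.where((a == pred_value) & (b == true_value))[0]
    return to.valid.decode().items.iloc[idxs, :]

find_misclassified(0, 1)
```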
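If you want the same filtering without relying on the globals `to`, `y_pred` and `targs` (or outside fastai entirely), here is a minimal pandas/NumPy sketch. The function names `find_misclassified_df` and `all_misclassified_df`, the `pred`/`actual` column names, and the toy data are all hypothetical, for illustration only. It assumes the DataFrame rows are in the same order as the predictions, as with `to.valid.decode().items` in the function above.

```python
import numpy as np
import pandas as pd


def find_misclassified_df(df, y_pred, y_true, pred_value, true_value):
    """Hypothetical helper: rows of `df` predicted as `pred_value` whose true label is `true_value`.

    Assumes df, y_pred and y_true are aligned row by row (same length, same order).
    """
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()  # flattens an (n, 1) target array, like targs.squeeze()
    if not (len(df) == len(y_pred) == len(y_true)):
        raise ValueError("df, y_pred and y_true must have the same length")
    mask = (y_pred == pred_value) & (y_true == true_value)
    # Positional indexing, so a non-default DataFrame index is fine
    return df.iloc[np.flatnonzero(mask)]


def all_misclassified_df(df, y_pred, y_true):
    """Hypothetical helper: every misclassified row, with its predicted and true labels attached."""
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()
    wrong = np.flatnonzero(y_pred != y_true)
    out = df.iloc[wrong].copy()
    out["pred"] = y_pred[wrong]      # hypothetical column name
    out["actual"] = y_true[wrong]    # hypothetical column name
    return out


# Toy, made-up data
df = pd.DataFrame({"age": [25, 40, 31, 58], "income": [30, 80, 45, 90]},
                  index=[10, 11, 12, 13])
y_pred = [0, 1, 0, 1]
y_true = [1, 1, 0, 0]

print(find_misclassified_df(df, y_pred, y_true, pred_value=0, true_value=1))
print(all_misclassified_df(df, y_pred, y_true))
```

With a fastai model you would pass `to.valid.decode().items`, `to_np(y_pred)` and `to_np(targs)` as the first three arguments. Passing the data in explicitly, rather than reading globals, makes it harder to accidentally filter one DataFrame with another model's predictions.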