I built in some functionality to create confusion matrices for tabular data. I essentially adapted the ClassificationInterpretation class from the vision section to work with tabular data and added a ClassificationInterpretationTabular class in /tabular/models.py. I thought this could be useful for others as well, so I submitted a PR.
ClassificationInterpretationTabular has two methods: plot_confusion_matrix, which returns a plot, and confusion_matrix, which returns a NumPy array.
@mchaykow, Great Job!!! Just wanted to put my $.02 in, here’s a snippet to get most_confused working:
    def most_confused(self, min_val:int=1, slice_size:int=1)->Collection[Tuple[str,str,int]]:
        "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
        cm = self.confusion_matrix()
        np.fill_diagonal(cm, 0)
        res = [(self.data.classes[i],self.data.classes[j],cm[i,j])
                for i,j in zip(*np.where(cm>=min_val))]
        return sorted(res, key=itemgetter(2), reverse=True)
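For anyone who wants to try the logic outside the library, here is a minimal standalone sketch of the same idea that works on a plain NumPy confusion matrix. The free-function signature, the example matrix, and the class names are all made up for illustration; they are not part of the PR.

```python
from operator import itemgetter
from typing import List, Sequence, Tuple

import numpy as np


def most_confused(cm: np.ndarray, classes: Sequence[str],
                  min_val: int = 1) -> List[Tuple[str, str, int]]:
    """Largest off-diagonal entries as (actual, predicted, count), sorted descending."""
    cm = cm.copy()           # work on a copy so the caller's matrix is not modified
    np.fill_diagonal(cm, 0)  # correct predictions are not confusions
    res = [(classes[i], classes[j], int(cm[i, j]))
           for i, j in zip(*np.where(cm >= min_val))]
    return sorted(res, key=itemgetter(2), reverse=True)


# Hypothetical 3-class matrix: rows are actual classes, columns are predictions.
example_cm = np.array([[50, 2, 0],
                       [5, 40, 1],
                       [0, 3, 45]])
example_classes = ["<50k", "50k-100k", ">100k"]

confused = most_confused(example_cm, example_classes, min_val=2)
for actual, predicted, count in confused:
    print(f"actual={actual!r} predicted={predicted!r} count={count}")
```

Copying the matrix before zeroing the diagonal is a choice made in this sketch only, so the input can be reused afterwards.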
The only thing that had to be adjusted was the confusion_matrix input. Not sure whether you had discovered that yet.
Hey, thanks for the follow-up. I’m not sure removing the confusion_matrix input here would make a difference unless the slice_size in your most_confused call was something other than 1, because the default slice_size for confusion_matrix is already 1. Try setting slice_size=1 in your most_confused call and see if that fixes it as well.
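To make the slice_size question concrete, here is a rough sketch of a confusion matrix accumulated in slices. It assumes that slice_size only controls how many samples are processed per step, which is my reading and is not confirmed anywhere in this thread. The function below is a NumPy stand-in with a made-up signature, not the library method.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes, slice_size=1):
    """Accumulate the confusion matrix slice_size samples at a time."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for start in range(0, len(y_true), slice_size):
        t = y_true[start:start + slice_size]
        p = y_pred[start:start + slice_size]
        # Unbuffered add, so repeated (actual, predicted) pairs in a slice all count.
        np.add.at(cm, (t, p), 1)
    return cm


# Random labels, purely for illustration.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=100)
y_pred = rng.integers(0, 3, size=100)

cm_1 = confusion_matrix(y_true, y_pred, 3, slice_size=1)
cm_32 = confusion_matrix(y_true, y_pred, 3, slice_size=32)
same = np.array_equal(cm_1, cm_32)
print("identical regardless of slice size:", same)
```

Under that assumption the slice size only affects memory use and speed, so the counts should come out the same either way.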
Are you referencing this issue? I was under the impression @sgugger resolved this.
I think he did, apologies! I originally didn’t notice it in the GitHub functions, so I was not sure. One thing I will ask though: did you find a way to get top_losses working? To me it would be a bit different from the normal image-related one, in that we would instead return the idx in the dataframe (if that’s not already what topk does?).
Looks like it does, but what I would rather have for a tabular dataset is a plot_top_losses function that shows the dataframe row itself. Let me know if that is worth exploring.
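As a starting point for that idea, here is a rough sketch of what showing the dataframe rows could look like. The helper name top_losses_rows, its arguments, and the toy data are all hypothetical. It assumes the per-row losses and predictions are already available as arrays in the same order as the dataframe rows.

```python
import numpy as np
import pandas as pd


def top_losses_rows(df, losses, preds, k=5):
    """Return the k rows of df with the largest loss, plus their prediction and loss."""
    losses = np.asarray(losses)
    idxs = np.argsort(-losses)[:k]   # positions of the k largest losses
    out = df.iloc[idxs].copy()       # keep the original dataframe index
    out["prediction"] = np.asarray(preds)[idxs]
    out["loss"] = losses[idxs]
    return out


# Toy data, purely for illustration.
df = pd.DataFrame({"age": [25, 40, 31, 58],
                   "education": ["HS", "BS", "MS", "PhD"],
                   "salary": ["<50k", ">=50k", "<50k", ">=50k"]})
losses = [0.1, 2.3, 0.7, 1.5]
preds = ["<50k", "<50k", "<50k", "<50k"]

worst = top_losses_rows(df, losses, preds, k=2)
print(worst)
```

Because the original index is kept, each returned row can be traced back to its place in the source dataframe.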
No, there is not. I posted to the dev forum, but I’m unsure whether it is worth a PR, as I’m fairly new to getting the things I’m trying out implemented. How do I go about that?
You can fork the library and then add in your function. Then run the tests in the /tests/ directory to make sure everything is working OK, and follow the steps here. Commit your changes along the way as you update the code so that there’s a history of your edits. Then just wait for approval.
I also posted a gist of the tests to show they run ok as I commit new changes.