PR: Confusion matrix for Tabular data

I built some functionality for creating confusion matrices from tabular data. I essentially adapted the ClassificationInterpretation class from the vision module to work with tabular models and created a ClassificationInterpretationTabular class in /tabular/. I thought this could be useful to others as well, so I submitted a PR.

ClassificationInterpretationTabular has two methods: plot_confusion_matrix, which returns a plot, and confusion_matrix, which returns a NumPy array.
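For anyone wondering what that NumPy array contains, here is a minimal standalone sketch. It is not the PR's code: the function name and signature are hypothetical, and it assumes the model's outputs have already been reduced to integer class indices (e.g. via argmax). Rows are actual classes and columns are predicted classes, matching the "actual, predicted" order used elsewhere in this thread.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are actual classes, columns are predicted classes."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (actual, predicted) pair repeats
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Example: 3 classes, 6 rows of tabular data
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, n_classes=3)
print(cm)
```

The diagonal counts correct predictions, so its sum (np.trace) is the number of correctly classified rows.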

I also included a test in /tests/ and a gist showing the test being run:


@mchaykow, great job! Just wanted to put in my $.02; here’s a snippet to get most_confused working:

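```python
import numpy as np
from operator import itemgetter
from typing import Collection, Tuple

def most_confused(self, min_val:int=1, slice_size:int=1)->Collection[Tuple[str,str,int]]:
    "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
    cm = self.confusion_matrix()  # called without the slice_size argument
    np.fill_diagonal(cm, 0)       # zero out correct predictions
    # map indices to class names; assumes self.data.classes as in the vision version
    res = [(self.data.classes[i], self.data.classes[j], cm[i,j])
           for i,j in zip(*np.where(cm>=min_val))]
    return sorted(res, key=itemgetter(2), reverse=True)
```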

The only thing I had to adjust was the input to confusion_matrix (it’s now called without arguments). Not sure if you had discovered that yet or not :slight_smile:
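To see what most_confused returns without needing a trained model, here is the same logic as a hypothetical free function that takes a precomputed matrix and a list of class names. The class names "low", "mid" and "high" are made up for the example. Unlike the method above, it copies the matrix before zeroing the diagonal so the caller’s array is left untouched.

```python
import numpy as np
from operator import itemgetter

def most_confused(cm, classes, min_val=1):
    """Largest off-diagonal entries as (actual, predicted, count), sorted descending."""
    cm = np.array(cm, copy=True)  # copy so the caller's matrix is not modified
    np.fill_diagonal(cm, 0)       # correct predictions are not confusions
    res = [(classes[i], classes[j], int(cm[i, j]))
           for i, j in zip(*np.where(cm >= min_val))]
    return sorted(res, key=itemgetter(2), reverse=True)

cm = np.array([[5, 2, 0],
               [1, 7, 3],
               [0, 0, 9]])
print(most_confused(cm, classes=["low", "mid", "high"]))
```

Here the worst confusion is actual "mid" predicted as "high" (3 rows), followed by "low" as "mid" (2) and "mid" as "low" (1).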

Hey, thanks for the follow-up. I’m not sure removing the confusion_matrix input makes a difference here: the default slice_size for confusion_matrix is already 1, so it would only matter if you called most_confused with a slice_size other than 1. Try setting slice_size=1 in your most_confused call and see if that fixes it as well.
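Assuming slice_size works as in the vision ClassificationInterpretation, where it only sets how many rows are compared per chunk to bound memory, the resulting matrix should not depend on it at all. The NumPy sketch below (hypothetical function name, not the library code) accumulates the matrix slice by slice and checks that several slice sizes give the same answer.

```python
import numpy as np

def confusion_matrix_sliced(y_true, y_pred, n_classes, slice_size=1):
    """Hypothetical helper: build the confusion matrix slice_size rows at a time."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    classes = np.arange(n_classes)
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for start in range(0, len(y_true), slice_size):
        t = y_true[start:start + slice_size]
        p = y_pred[start:start + slice_size]
        actual = t == classes[:, None]   # (n_classes, k): row r matches actual class a
        pred = p == classes[:, None]     # (n_classes, k): row r matches predicted class b
        # (n_classes, n_classes, k) -> count rows per (actual, predicted) pair
        hits = actual[:, None, :] & pred[None, :, :]
        cm += hits.sum(axis=2)
    return cm

rng = np.random.default_rng(0)
y_true = rng.integers(0, 4, size=50)
y_pred = rng.integers(0, 4, size=50)
results = [confusion_matrix_sliced(y_true, y_pred, 4, s) for s in (1, 7, 50)]
print(all(np.array_equal(results[0], r) for r in results))
```

Larger slices trade memory for fewer Python-level iterations; the counts are identical either way.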

Are you referencing this issue? I was under the impression @sgugger resolved this.


I think he did, apologies! I hadn’t noticed it in the GitHub code, so I wasn’t sure :slight_smile: One thing I will ask, though: did you find a way to get top_losses working? I’d expect it to work a bit differently from the image version, returning the row index in the dataframe instead (unless that’s already what topk does?).

Looks like it does, but what I would rather have for a tabular dataset is a plot_top_losses function that shows the dataframe rows themselves. Let me know if that’s worth exploring.
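A rough sketch of that idea: compute a per-row loss, take the indices of the largest ones, and use them to pull the original rows back out of the dataframe. Everything here is an assumption for illustration: the helper name top_losses_table is hypothetical, the loss is plain cross-entropy, probs and y_true are assumed to line up row-for-row with df (e.g. df is the validation dataframe), and np.argsort stands in for torch.topk so the example runs without PyTorch.

```python
import numpy as np
import pandas as pd

def top_losses_table(df, probs, y_true, k=3):
    """Hypothetical helper: the k rows of df with the highest cross-entropy loss."""
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)
    # per-row cross-entropy: -log(probability assigned to the true class)
    losses = -np.log(probs[np.arange(len(y_true)), y_true])
    # positions of the k largest losses, highest first (the role torch.topk plays)
    idx = np.argsort(-losses, kind="stable")[:k]
    out = df.iloc[idx].copy()  # iloc: positions, not index labels
    out["actual"] = y_true[idx]
    out["predicted"] = probs[idx].argmax(axis=1)
    out["loss"] = losses[idx]
    return out

df = pd.DataFrame({"age": [25, 40, 33, 51], "hours": [40, 20, 45, 60]})
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
y_true = [0, 0, 1, 1]
print(top_losses_table(df, probs, y_true, k=2))
```

Using iloc rather than loc matters if the dataframe’s index is not a plain 0..n-1 range, which is common after a train/validation split.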


That would be nice actually.

So, I actually did it already, oops. It was ‘fairly’ simple to figure out once I reverse-engineered how to map everything back together.

The code is here:

Nice, is there a PR?

No, there isn’t. I posted to the dev forum, but I’m not sure it’s worth a PR, since I’m fairly new to getting things I’ve been trying out implemented in the library. How do I go about that? :smile:

You can fork the library and add your function. Then run the tests in the /tests/ directory to make sure everything works, and follow the steps here. Commit your changes as you go so there’s a history of your edits. Then just wait for approval.

I also posted a gist of the test run to show the tests pass as I commit new changes.


Thanks :slight_smile: it just got approved!


Nice! Can you share the link?


Though I’m currently running into some issues because some changes were made to … so it’s not usable yet, apologies! Working on getting there.

It’s usable now. If you want to do this for a labeled test set, please see my post about it here: