Use of plot_confusion_matrix seems inaccurate

I just realized that some of my confusion matrices had the wrong labels. Digging further into the plot_confusion_matrix implementation, I see that passing a dict (such as class_indices) as the xtick labels is not reliable. Here is an example to prove my point.
Notice the difference in the ordering of the cat(s)/dog(s) labels on the axes. In both cases, row/column 1 should correspond to cats and row/column 2 to dogs.
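To see why the label order matters, here is a minimal pure-Python sketch (no Keras or scikit-learn; the predictions are made-up toy values) of how a confusion matrix is laid out. Row i counts samples whose true class index is i, and column j counts samples predicted as index j, so any tick labels have to be listed in class-index order.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """Return an n_classes x n_classes list of lists of counts.

    cm[i][j] = number of samples with true index i predicted as index j.
    """
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


# Hypothetical toy data: index 0 = cats, index 1 = dogs
y_true = [0, 0, 0, 1, 1]
y_pred = [0, 0, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred, n_classes=2)
print(cm)  # [[2, 1], [0, 2]]
```

Here row 0 is "true cats" and row 1 is "true dogs" purely because of the integer indices; the names only enter when the plot labels the axes.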

Passing a list is more reliable. It can be built from class_indices by sorting the class names by their index values, as follows:
import operator
labels = [name for name, idx in sorted(val_batches.class_indices.items(), key=operator.itemgetter(1))]
plot_confusion_matrix(cm, labels)
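The same idea as a small standalone helper, sketched under the assumption that class_indices maps each class name to a distinct integer index starting at 0 (as Keras' class_indices does, e.g. {'cats': 0, 'dogs': 1}). The name labels_in_index_order is hypothetical and not part of the fast.ai utils.

```python
def labels_in_index_order(class_indices):
    """Return class names ordered by their integer index.

    Assumes class_indices maps name -> index with indices 0..n-1.
    """
    labels = sorted(class_indices, key=class_indices.get)
    # Sanity check: the indices must be exactly 0..n-1, otherwise
    # position k in the list would not line up with row/column k of the matrix.
    if [class_indices[name] for name in labels] != list(range(len(labels))):
        raise ValueError("class indices must be exactly 0..n-1")
    return labels


print(labels_in_index_order({'dogs': 1, 'cats': 0}))  # ['cats', 'dogs']
```

The returned list can then be passed as the second argument to plot_confusion_matrix. Sorting by the value (rather than by the name) keeps the labels correct even if the names do not sort alphabetically in index order.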

I am not entirely sure why the dict is unreliable, but my guess is that its keys come out in an order based on their hashes rather than on their index values. The hash-based order may happen to put cats before dogs, but that need not hold for cat/dog.
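The underlying point can be illustrated in modern Python: iterating a dict yields its keys in iteration order and ignores the values. Since Python 3.7 that order is insertion order; in Python 2 it depended on the key hashes, which fits the guess above. The mapping below is hypothetical and deliberately built in a different order from its indices.

```python
# Hypothetical mapping whose insertion order differs from its index order
class_indices = {'dogs': 1, 'cats': 0}

# Anything that simply iterates the dict sees the keys in insertion order...
iteration_order = list(class_indices)
print(iteration_order)  # ['dogs', 'cats']

# ...whereas the confusion matrix rows/columns follow the index values.
index_order = sorted(class_indices, key=class_indices.get)
print(index_order)  # ['cats', 'dogs']
```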


Thanks for your thorough analysis of this and helpful recommendation :slight_smile:
