Confusion Matrix Assert Error on MNIST


I tried plotting the confusion matrix for the MNIST model trained in Lesson 13 of the 2020 course.

I am getting an assertion error.

Here is the code:

from fastai.vision.all import *

path = untar_data(URLs.MNIST)

def get_dls(bs=64):
    return DataBlock(
        blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
        get_items=get_image_files,
        splitter=GrandparentSplitter('training', 'testing'),
        get_y=parent_label,
    ).dataloaders(path, bs=bs)

dls = get_dls()

def conv(ni, nf, ks=3, act=True):
    res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

simple_cnn = sequential(
    conv(1 ,4),             #14x14
    conv(4 ,8),             #7x7
    conv(8 ,16),            #4x4
    conv(16,32),            #2x2
    conv(32,10, act=False), #1x1
)

learn = Learner(dls, simple_cnn, loss_func=F.cross_entropy, metrics=accuracy)

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Here is the error:

AssertionError                            Traceback (most recent call last)
<ipython-input-19-64d44d20f204> in <module>()
     30 interp = ClassificationInterpretation.from_learner(learn)
---> 31 interp.plot_confusion_matrix()

4 frames
/usr/local/lib/python3.6/dist-packages/fastai/ in plot_confusion_matrix(self, normalize, title, cmap, norm_dec, plot_txt, **kwargs)
     68         "Plot the confusion matrix, with `title` and using `cmap`."
     69         # This function is mainly copied from the sklearn docs
---> 70         cm = self.confusion_matrix()
     71         if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
     72         fig = plt.figure(**kwargs)

/usr/local/lib/python3.6/dist-packages/fastai/ in confusion_matrix(self)
     60         "Confusion matrix as an `np.ndarray`."
     61         x = torch.arange(0, len(self.vocab))
---> 62         d,t = flatten_check(self.decoded, self.targs)
     63         cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
     64         return to_np(cm)

/usr/local/lib/python3.6/dist-packages/fastai/ in flatten_check(inp, targ)
    769     "Check that `out` and `targ` have the same number of elements and flatten them."
    770     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 771     test_eq(len(inp), len(targ))
    772     return inp,targ

/usr/local/lib/python3.6/dist-packages/fastcore/ in test_eq(a, b)
     33 def test_eq(a,b):
     34     "`test` that `a==b`"
---> 35     test(a,b,equals, '==')
     37 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/ in test(a, b, cmp, cname)
     23     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     24     if cname is None: cname=cmp.__name__
---> 25     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     27 # Cell

AssertionError: ==:

The error is the same whether I run it in Colab or locally.

Can someone help me out on this one?



Hello, sorry for answering this late.

I had the same Assert Error while experimenting on the whole MNIST dataset after chapter 4.

I noticed that if I don’t provide the Learner with an argument for loss_func, then ClassificationInterpretation will work just fine.

I’m not entirely sure why this happens. In my case I was using nn.CrossEntropyLoss() with cnn_learner. Investigating how Learner initializes the loss function on its own might give some clues.
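To illustrate the mechanism behind the assertion, here is a minimal torch-only sketch of my understanding: ClassificationInterpretation asks the loss function to decode raw model outputs into class predictions, and fastai's CrossEntropyLossFlat provides an argmax `decodes` for that. A plain PyTorch loss like F.cross_entropy has no such method, so the raw (n, 10) logits end up compared against (n,) targets, and `flatten_check` fails. The shapes and sample counts below are made up for illustration.

```python
import torch

n_samples, n_classes = 8, 10
logits = torch.randn(n_samples, n_classes)      # raw model outputs
targs = torch.randint(0, n_classes, (n_samples,))

# What a `decodes`-aware loss (CrossEntropyLossFlat) would produce:
decoded_ok = logits.argmax(dim=-1)              # shape (8,) -- matches targs

# What you get with a plain PyTorch loss (no `decodes`): the logits as-is.
decoded_bad = logits                            # shape (8, 10)

# flatten_check asserts both sides have the same number of elements:
assert decoded_ok.view(-1).numel() == targs.view(-1).numel()   # 8 == 8, passes
assert decoded_bad.view(-1).numel() != targs.view(-1).numel()  # 80 != 8, would fail
```

This would explain why dropping loss_func avoids the error: the default loss fastai picks for a CategoryBlock knows how to decode its outputs.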


I faced the exact same issue, and removing the loss function from the cnn_learner() call avoided it. I am amazed that others have not seen it; it should be commonplace.

1 Like

Experiencing the same error when specifying my own loss function. Has anyone else been able to resolve it?

1 Like

I faced the same problem.
When I tried nn.CrossEntropyLoss with the weights specified, I got the AssertionError after interp.plot_confusion_matrix().
However, if I use CrossEntropyLossFlat() from the fastai library rather than PyTorch's nn version and specify the weights there, there is no problem when interp.plot_confusion_matrix() is called.
Not quite sure what is going on.


I just faced the same issue today and the solution provided by @Joan was enough!

1 Like

Had the same issue just now, and it turns out it was caused by specifying a PyTorch loss function.

You can either specify CrossEntropyLossFlat() as @Joan stated, or not specify a loss function at all, because CrossEntropyLossFlat() is the default loss function that cnn_learner chooses anyway.

1 Like