Hello,
I tried plotting the confusion matrix for the MNIST model trained in Lesson 13 of the 2020 course.
I am getting an assertion error.
Here is the code:
from fastai.data.all import *
from fastai.vision.all import *
path = untar_data(URLs.MNIST)
def get_dls(bs=64):
return DataBlock(
blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
get_items=get_image_files,
splitter=GrandparentSplitter('training','testing'),
get_y=parent_label,
batch_tfms=Normalize()
).dataloaders(path, bs=bs)
dls = get_dls()
def conv(ni, nf, ks=3, act=True):
res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
if act: res = nn.Sequential(res, nn.ReLU())
return res
simple_cnn = sequential(
conv(1 ,4), #14x14
conv(4 ,8), #7x7
conv(8 ,16), #4x4
conv(16,32), #2x2
conv(32,10, act=False), #1x1
Flatten(),
)
learn = Learner(dls, simple_cnn, loss_func=F.cross_entropy, metrics=accuracy)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
Here is the error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-19-64d44d20f204> in <module>()
29
30 interp = ClassificationInterpretation.from_learner(learn)
---> 31 interp.plot_confusion_matrix()
4 frames
/usr/local/lib/python3.6/dist-packages/fastai/interpret.py in plot_confusion_matrix(self, normalize, title, cmap, norm_dec, plot_txt, **kwargs)
68 "Plot the confusion matrix, with `title` and using `cmap`."
69 # This function is mainly copied from the sklearn docs
---> 70 cm = self.confusion_matrix()
71 if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
72 fig = plt.figure(**kwargs)
/usr/local/lib/python3.6/dist-packages/fastai/interpret.py in confusion_matrix(self)
60 "Confusion matrix as an `np.ndarray`."
61 x = torch.arange(0, len(self.vocab))
---> 62 d,t = flatten_check(self.decoded, self.targs)
63 cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
64 return to_np(cm)
/usr/local/lib/python3.6/dist-packages/fastai/torch_core.py in flatten_check(inp, targ)
769 "Check that `out` and `targ` have the same number of elements and flatten them."
770 inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 771 test_eq(len(inp), len(targ))
772 return inp,targ
/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test_eq(a, b)
33 def test_eq(a,b):
34 "`test` that `a==b`"
---> 35 test(a,b,equals, '==')
36
37 # Cell
/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test(a, b, cmp, cname)
23 "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
24 if cname is None: cname=cmp.__name__
---> 25 assert cmp(a,b),f"{cname}:\n{a}\n{b}"
26
27 # Cell
AssertionError: ==:
100000
10000
The error is the same in colab or running it locally for me.
Can someone help me out on this one?
Regards,
Christian