[Fastbook] Chapter 7: plot_confusion_matrix() raises an error when using MixUp in cnn_learner

Hi, I am new to fastai and just wanted to try out the MixUp callback mentioned in chapter 7 of the fastbook.

model_arch = resnet18
learn = cnn_learner(dls, model_arch, loss_func=CrossEntropyLossFlat(), cbs=MixUp(), metrics=accuracy, model_dir=model_path)

When I call the plot_confusion_matrix() method, it returns the message below. I also noticed that for some of my images the prediction scores are not between 0 and 1 (roughly 1.2 to 5). Is this a bug in fastai?

AssertionError Traceback (most recent call last)
      1 interp = ClassificationInterpretation.from_learner(learn)
----> 2 interp.plot_confusion_matrix()

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/interpret.py in plot_confusion_matrix(self, normalize, title, cmap, norm_dec, plot_txt, **kwargs)
     68     "Plot the confusion matrix, with title and using cmap."
     69     # This function is mainly copied from the sklearn docs
---> 70     cm = self.confusion_matrix()
     71     if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
     72     fig = plt.figure(**kwargs)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/interpret.py in confusion_matrix(self)
     60     "Confusion matrix as an np.ndarray."
     61     x = torch.arange(0, len(self.vocab))
---> 62     d,t = flatten_check(self.decoded, self.targs)
     63     cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
     64     return to_np(cm)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
    751     "Check that out and targ have the same number of elements and flatten them."
    752     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 753     test_eq(len(inp), len(targ))
    754     return inp,targ

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastcore/test.py in test_eq(a, b)
     32 def test_eq(a,b):
     33     "test that a==b"
---> 34     test(a,b,equals, '==')
     36 # Cell

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/fastcore/test.py in test(a, b, cmp, cname)
     22     "assert that cmp(a,b); display inputs and cname or cmp.__name__ if it fails"
     23     if cname is None: cname=cmp.__name__
---> 24     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     26 # Cell

AssertionError: ==:

I am having the exact same problem. Anyone have any thoughts on this?

You need to remove the MixUp callback before moving on; otherwise it will break everything.


Just realized this not 2 seconds ago. LOL

learn.remove_cbs([ShowGraphCallback, cutmix])


In general, if the callback is something you only want during training, you should pass it to your fit function rather than to the Learner.
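To see why that matters, here is a minimal sketch of the bookkeeping involved. ToyLearner is a hypothetical stand-in, not fastai code; it only models the idea that callbacks passed to the constructor persist, while callbacks passed to fit are added for that call and removed afterwards (which is what fastai's add_cbs/remove_cbs machinery does around fitting):

```python
class ToyLearner:
    """Toy stand-in for a Learner: tracks which callbacks are active."""
    def __init__(self, cbs=None):
        self.cbs = list(cbs or [])   # callbacks passed here persist for the
                                     # lifetime of the learner (used by
                                     # predictions, interpretation, etc.)

    def fit(self, n_epoch, cbs=None):
        extra = list(cbs or [])
        self.cbs += extra            # fit-scoped callbacks are added...
        active_during_fit = list(self.cbs)
        for cb in extra:
            self.cbs.remove(cb)      # ...and removed again once fitting ends
        return active_during_fit

learn = ToyLearner(cbs=["Recorder"])
during = learn.fit(1, cbs=["MixUp"])
print(during)      # MixUp is active while fitting
print(learn.cbs)   # but gone afterwards, so later steps are unaffected
```

Passing MixUp to the constructor corresponds to it staying in learn.cbs permanently, which is why it was still interfering when the interpretation code ran.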


Can you please enlighten me as to how you found the solution? I had exactly the same problem and would not have known where to start answering it myself. Thanks!

Since we passed MixUp to cbs, the first thing to check is what learn.cbs gives. Anything in there runs whenever we call learn for anything related to predictions or training. If we then look through Learner for anything callback-related, there are add_cbs and remove_cbs functions (needed because if, say, we passed a callback into a fit call, the learner would need some way to add that callback beforehand and remove it afterwards).

Also, here is how to tell it was MixUp doing something: MixUp affects the y values in our data, so if something that would normally run gives me a mismatch error, and I know it is not a loss or metric, it must be a callback that affects the data.
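A toy sketch of the kind of mismatch this produces (this is deliberately simplified Python, not fastai's exact internals): MixUp replaces each integer label with a blend of two one-hot vectors, so the flattened targets end up with more elements than the flattened decoded predictions, which is exactly the equality that flatten_check asserts:

```python
def flatten_len(x):
    """Count elements after flattening, like fastai's flatten_check does."""
    if x and isinstance(x[0], list):
        return sum(len(row) for row in x)
    return len(x)

# Plain classification: one integer label per sample.
preds   = [2, 0, 1, 1]   # 4 decoded predictions
targets = [2, 0, 1, 2]   # 4 integer targets -> lengths match, check passes
print(flatten_len(preds), flatten_len(targets))

# With MixUp active, each target is a blend of two one-hot vectors,
# e.g. 0.7 * one_hot(a) + 0.3 * one_hot(b), over 3 classes:
mixed = [[0.0, 0.3, 0.7],
         [0.7, 0.0, 0.3],
         [0.3, 0.7, 0.0],
         [0.0, 0.7, 0.3]]
print(flatten_len(preds), flatten_len(mixed))  # lengths differ -> AssertionError
```

The 4-vs-12 element mismatch here mirrors the failed test_eq(len(inp), len(targ)) in the traceback above.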

Thank you for your reply. My question was about exactly that line of argument. After reading "Deep Learning for Coders with fastai and PyTorch: AI Applications Without a PhD" and nothing else of the fastai content, I would not have been able to reach your conclusions at all. So I am wondering what sources you consumed to acquire this deep knowledge.

No sources. Just go and play with the library and read the documentation to get a true feel for it. For instance, Learner's docs mention the callback handling: https://docs.fast.ai/learner#Callback-handling


I tried this but still no change. I checked my callbacks; the list was [Recorder, TrainEvalCallback, ProgressCallback]. I tried removing all of them, but no luck; it gives the same AssertionError.

How do you know all this? Thanks for answering.
