AssertionError: Exception occured in Recorder when calling event after_batch
For lab1 (Is it a bird?) I’m trying to modify the learner to work in a multiclass setting, i.e. with 3 categories instead of 2.
According to the tutorial in the docs this should be straightforward, but I’m having trouble fine-tuning the model.
When re-training with an optimal learning rate for 3 epochs, the following rather lengthy error message was produced:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/tmp/ipykernel_18/568062170.py in <module>
10 learn2=vision_learner(dls2, resnet18, metrics=[partial(accuracy_multi, thresh=0.5), fbeta_macro, fbeta_sample] )
11 learn2.lr_find() #find the optimal learning rate
---> 12 learn2.fine_tune(3, 1e-3) #args: n epochs, learning rate. Not in documentation for some reason...ugh
/opt/conda/lib/python3.7/site-packages/fastai/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs)
163 "Fine tune with `Learner.freeze` for `freeze_epochs`, then with `Learner.unfreeze` for `epochs`, using discriminative LR."
164 self.freeze()
--> 165 self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
166 base_lr /= 2
167 self.unfreeze()
/opt/conda/lib/python3.7/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
118 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
120
121 # %% ../../nbs/14_callback.schedule.ipynb 50
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
262 self.opt.set_hypers(lr=self.lr if lr is None else lr)
263 self.n_epoch = n_epoch
--> 264 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
265
266 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
197
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_fit(self)
251 for epoch in range(self.n_epoch):
252 self.epoch=epoch
--> 253 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
254
255 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
197
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch(self)
246 def _do_epoch(self):
247 self._do_epoch_train()
--> 248 self._do_epoch_validate()
249
250 def _do_fit(self):
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
242 if dl is None: dl = self.dls[ds_idx]
243 self.dl = dl
--> 244 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
245
246 def _do_epoch(self):
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
197
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in all_batches(self)
203 def all_batches(self):
204 self.n_iter = len(self.dl)
--> 205 for o in enumerate(self.dl): self.one_batch(*o)
206
207 def _backward(self): self.loss_grad.backward()
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in one_batch(self, i, b)
233 b = self._set_device(b)
234 self._split(b)
--> 235 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
236
237 def _do_epoch_train(self):
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
--> 201 self(f'after_{event_type}'); final()
202
203 def all_batches(self):
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in __call__(self, event_name)
170
171 def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 172 def __call__(self, event_name): L(event_name).map(self._call_one)
173
174 def _call_one(self, event_name):
/opt/conda/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
154 def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
155
--> 156 def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))
157 def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
158 def argfirst(self, f, negate=False):
/opt/conda/lib/python3.7/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
838 res = map(g, iterable)
839 if gen: return res
--> 840 return list(res)
841
842 # %% ../nbs/01_basics.ipynb 336
/opt/conda/lib/python3.7/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
823 if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
824 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 825 return self.func(*fargs, **kwargs)
826
827 # %% ../nbs/01_basics.ipynb 326
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in _call_one(self, event_name)
174 def _call_one(self, event_name):
175 if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 176 for cb in self.cbs.sorted('order'): cb(event_name)
177
178 def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)
/opt/conda/lib/python3.7/site-packages/fastai/callback/core.py in __call__(self, event_name)
60 try: res = getcallable(self, event_name)()
61 except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
---> 62 except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
63 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
64 return res
/opt/conda/lib/python3.7/site-packages/fastai/callback/core.py in __call__(self, event_name)
58 res = None
59 if self.run and _run:
---> 60 try: res = getcallable(self, event_name)()
61 except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
62 except Exception as e: raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in after_batch(self)
558 if len(self.yb) == 0: return
559 mets = self._train_mets if self.training else self._valid_mets
--> 560 for met in mets: met.accumulate(self.learn)
561 if not self.training: return
562 self.lrs.append(self.opt.hypers[-1]['lr'])
/opt/conda/lib/python3.7/site-packages/fastai/learner.py in accumulate(self, learn)
480 def accumulate(self, learn):
481 bs = find_bs(learn.yb)
--> 482 self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
483 self.count += bs
484 @property
/opt/conda/lib/python3.7/site-packages/fastai/metrics.py in accuracy_multi(inp, targ, thresh, sigmoid)
199 def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
200 "Compute accuracy when `inp` and `targ` are the same size."
--> 201 inp,targ = flatten_check(inp,targ)
202 if sigmoid: inp = inp.sigmoid()
203 return ((inp>thresh)==targ.bool()).float().mean()
/opt/conda/lib/python3.7/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
785 "Check that `inp` and `targ` have the same number of elements and flatten them."
786 inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
--> 787 test_eq(len(inp), len(targ))
788 return inp,targ
789
/opt/conda/lib/python3.7/site-packages/fastcore/test.py in test_eq(a, b)
35 def test_eq(a,b):
36 "`test` that `a==b`"
---> 37 test(a,b,equals, cname='==')
38
39 # %% ../nbs/00_test.ipynb 25
/opt/conda/lib/python3.7/site-packages/fastcore/test.py in test(a, b, cmp, cname)
25 "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
26 if cname is None: cname=cmp.__name__
---> 27 assert cmp(a,b),f"{cname}:\n{a}\n{b}"
28
29 # %% ../nbs/00_test.ipynb 16
AssertionError: Exception occured in `Recorder` when calling event `after_batch`:
==:
96
32
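If I’m reading the final assertion correctly, the ==: 96 vs 32 is a size mismatch inside flatten_check: with a batch size of 32 and 3 categories, the flattened predictions contain 32*3 = 96 elements while the flattened targets contain only 32. My guess is that accuracy_multi expects one-hot (multi-label) targets, whereas my DataBlock yields integer category labels. A minimal sketch of what I think is failing (the batch size and the integer targets are my assumptions, not verified):

import torch
from fastai.torch_core import flatten_check

preds = torch.randn(32, 3)          #model output: 32 samples x 3 classes = 96 elements
targs = torch.randint(0, 3, (32,))  #integer category labels: 32 elements, not one-hot
flatten_check(preds, targs)         #raises AssertionError: ==: 96 32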
Here is my code:
from fastai.vision.all import * #provides vision_learner, resnet18, accuracy_multi, FBetaMulti, partial

#establish which DL model to use
#here we need to change the performance metric, as the one-dimensional error_rate won't work
#brie_pop=BrierScoreMulti(thresh=0.4) did not work, due to the non-mutual exclusivity of the strictly proper scoring rule?
fbeta_macro=FBetaMulti(beta=1.1, thresh=0.5, average='macro')
fbeta_macro.name='FBeta (macro)'
fbeta_sample=FBetaMulti(beta=1.1, thresh=0.5, average='samples')
fbeta_sample.name='FBeta (samples)'
#multiclass modification
learn2=vision_learner(dls2, resnet18, metrics=[partial(accuracy_multi, thresh=0.5), fbeta_macro, fbeta_sample])
learn2.lr_find() #find the optimal learning rate
learn2.fine_tune(3, 1e-3) #args: n epochs, base learning rate. Not in the documentation for some reason...ugh
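For reference, here is a quick inspection of what dls2 actually yields (the shapes in the comments are what I’d expect from my setup, not verified output):

xb, yb = dls2.one_batch()  #grab one batch from the training set
print(xb.shape, yb.shape)  #I'd expect e.g. torch.Size([32, 3, 224, 224]) and torch.Size([32])
print(dls2.c)              #number of categories; should be 3

If yb comes back as a flat vector of class indices rather than a 32x3 one-hot tensor, that would line up with the 96 vs 32 mismatch above.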
Has anyone else experienced a similar error, or does anyone have the API expertise to pin down the root cause of this issue?
Thanks for your time!