unet_learner fine_tune() fails with "metrics" argument

I’m trying to add metrics to a U-Net model created with the unet_learner() function.

I used the stock CamVid example but keep getting an AssertionError. I’ve tried accuracy_multi, JaccardMulti, and F1ScoreMulti without success.

from fastai.vision.all import *
path = untar_data(URLs.CAMVID_TINY)
codes = np.loadtxt(path/'codes.txt', dtype=str)
fnames = get_image_files(path/"images")

def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"

dls = SegmentationDataLoaders.from_label_func(
    path, bs=8, fnames=fnames, label_func=label_func, codes=codes
)

learn = unet_learner(dls, resnet34, metrics=[accuracy_multi]) 
learn.fine_tune(6)

Omitting the metrics argument removes the error and I can train normally. So this works fine: learn = unet_learner(dls, resnet34) followed by learn.fine_tune(6).

This is the error I get:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-8-1606e5742a18> in <module>
      1 learn = unet_learner(dls, resnet34, metrics=[accuracy_multi])
----> 2 learn.fine_tune(6)

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs)
    156     "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
    157     self.freeze()
--> 158     self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
    159     base_lr /= 2
    160     self.unfreeze()

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    219             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    220             self.n_epoch = n_epoch
--> 221             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    222 
    223     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    210         for epoch in range(self.n_epoch):
    211             self.epoch=epoch
--> 212             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    213 
    214     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    205     def _do_epoch(self):
    206         self._do_epoch_train()
--> 207         self._do_epoch_validate()
    208 
    209     def _do_fit(self):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    192         b = self._set_device(b)
    193         self._split(b)
--> 194         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    195 
    196     def _do_epoch_train(self):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
--> 165         self(f'after_{event_type}');  final()
    166 
    167     def all_batches(self):

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
    139 
    140     def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141     def __call__(self, event_name): L(event_name).map(self._call_one)
    142 
    143     def _call_one(self, event_name):

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    153 
--> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    156     def filter(self, f=noop, negate=False, gen=False, **kwargs):

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    664     res = map(g, iterable)
    665     if gen: return res
--> 666     return list(res)
    667 
    668 # Cell

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    649             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    650         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651         return self.func(*fargs, **kwargs)
    652 
    653 # Cell

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
    143     def _call_one(self, event_name):
    144         if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145         for cb in self.cbs.sorted('order'): cb(event_name)
    146 
    147     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
     43                (self.run_valid and not getattr(self, 'training', False)))
     44         res = None
---> 45         if self.run and _run: res = getattr(self, event_name, noop)()
     46         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     47         return res

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in after_batch(self)
    502         if len(self.yb) == 0: return
    503         mets = self._train_mets if self.training else self._valid_mets
--> 504         for met in mets: met.accumulate(self.learn)
    505         if not self.training: return
    506         self.lrs.append(self.opt.hypers[-1]['lr'])

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in accumulate(self, learn)
    424     def accumulate(self, learn):
    425         bs = find_bs(learn.yb)
--> 426         self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
    427         self.count += bs
    428     @property

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/metrics.py in accuracy_multi(inp, targ, thresh, sigmoid)
    193 def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    194     "Compute accuracy when `inp` and `targ` are the same size."
--> 195     inp,targ = flatten_check(inp,targ)
    196     if sigmoid: inp = inp.sigmoid()
    197     return ((inp>thresh)==targ.bool()).float().mean()

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
    810     "Check that `out` and `targ` have the same number of elements and flatten them."
    811     inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
--> 812     test_eq(len(inp), len(targ))
    813     return inp,targ

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test_eq(a, b)
     34 def test_eq(a,b):
     35     "`test` that `a==b`"
---> 36     test(a,b,equals, '==')
     37 
     38 # Cell

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test(a, b, cmp, cname)
     24     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     25     if cname is None: cname=cmp.__name__
---> 26     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     27 
     28 # Cell

AssertionError: ==:
3145728
98304 

I’m training on a cluster with the following output from fastai.test_utils.show_install():

=== Software === 
python        : 3.8.2
fastai        : 2.4
fastcore      : 1.3.20
fastprogress  : 0.2.7
torch         : 1.9.0
nvidia driver : 450.51
torch cuda    : 10.2 / is available
torch cudnn   : 7605 / is enabled

=== Hardware === 
nvidia gpus   : 8
torch devices : 2
  - gpu0      : Tesla K80
  - gpu1      : Tesla K80

=== Environment === 
platform      : Linux-3.10.0-1127.19.1.el7.x86_64-x86_64-with-glibc2.10
distro        : #1 SMP Tue Aug 25 17:23:54 UTC 2020
conda env     : dl
python        : /home/usr/.conda/envs/dl/bin/python
sys.path      : 
/home/usr/.conda/envs/dl/lib/python38.zip
/home/usr/.conda/envs/dl/lib/python3.8
/home/usr/.conda/envs/dl/lib/python3.8/lib-dynload
/home/usr/.local/lib/python3.8/site-packages
/home/usr/.conda/envs/dl/lib/python3.8/site-packages

The docstring for accuracy_multi says:

"Compute accuracy when `inp` and `targ` are the same size."

The frame three up from the bottom of your traceback (flatten_check) has a docstring that says:

"Check that `out` and `targ` have the same number of elements and flatten them."

and that check is what raises the AssertionError. I just looked at the files for camvid_tiny, and those input images and labels are indeed the same size.

I suggest checking that your inputs and targets are the same size; if they’re not, that’s the problem.
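
A quick way to check from the notebook (a minimal sketch, assuming the `dls` and `learn` built in your snippet above):

import torch

xb, yb = dls.one_batch()
print(xb.shape)  # input images, e.g. torch.Size([8, 3, 96, 128])
print(yb.shape)  # integer masks, e.g. torch.Size([8, 96, 128])

# run one batch through the model and compare total element counts
with torch.no_grad():
    preds = learn.model(xb)
print(preds.shape, preds.numel(), yb.numel())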

I just checked the image files, and the images and masks are indeed the same size for every pair, as you say. So why do I still get the dimension error? How could the model train without the metrics argument if the dimensions were really an issue? Shouldn’t unet_learner handle both the with- and without-metrics cases?

Since they start out the same size, the images must be getting resized somewhere, but I don’t have control over that when using the stock unet_learner and SegmentationDataLoaders.

This feels like a bug in the expected behavior. Do I have to skip the convenience functions and write the Learner from scratch? Or maybe implement my own metrics that handle the flattening?

accuracy_multi is meant for multi-label problems; in segmentation terms, that would be each pixel having multiple labels at once. It does not here, so you just want accuracy instead, since each pixel has a single label with a single correct prediction.
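
That also explains the numbers in your assertion error. For segmentation, learn.pred has shape (bs, n_classes, h, w) while the target mask is (bs, h, w), so accuracy_multi’s flatten_check compares element counts that differ by a factor of n_classes. A quick sanity check with the CAMVID_TINY shapes (bs=8, 32 codes, 96x128 masks — inferred from your traceback):

bs, n_classes, h, w = 8, 32, 96, 128
print(bs * n_classes * h * w)  # 3145728 -- flattened prediction in the error
print(bs * h * w)              # 98304   -- flattened target in the error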

Ahh, I see. I thought that because there were multiple labels in the same image, this counted as a ‘multilabel’ problem. Thanks for the clarification.

However, I still see a similar AssertionError using accuracy instead of accuracy_multi (still using the stock CamVid example):

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-7-fe42416eeef3> in <module>
      9 
     10 learn = unet_learner(dls, resnet34, metrics=[accuracy])
---> 11 learn.fine_tune(6)

[... identical stack frames to the first traceback above, from fine_tune down through accumulate ...]

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/metrics.py in accuracy(inp, targ, axis)
     97 def accuracy(inp, targ, axis=-1):
     98     "Compute accuracy with `targ` when `pred` is bs * n_classes"
---> 99     pred,targ = flatten_check(inp.argmax(dim=axis), targ)
    100     return (pred == targ).float().mean()
    101 

~/.conda/envs/dl/lib/python3.8/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
    810     "Check that `out` and `targ` have the same number of elements and flatten them."
    811     inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
--> 812     test_eq(len(inp), len(targ))
    813     return inp,targ

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test_eq(a, b)
     34 def test_eq(a,b):
     35     "`test` that `a==b`"
---> 36     test(a,b,equals, '==')
     37 
     38 # Cell

~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test(a, b, cmp, cname)
     24     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     25     if cname is None: cname=cmp.__name__
---> 26     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     27 
     28 # Cell

AssertionError: ==:
24576
98304

Any further help is much appreciated.

Try this?

metrics = [partial(accuracy, axis=1)]
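
The reason: for segmentation, learn.pred has shape (bs, n_classes, h, w), so the class scores live along dim 1, while accuracy defaults to axis=-1 and takes the argmax over the width dimension instead. A quick sketch with the shapes inferred from your tracebacks:

import torch

pred = torch.randn(8, 32, 96, 128)         # (bs, n_classes, h, w)
targ = torch.randint(0, 32, (8, 96, 128))  # (bs, h, w) integer mask

print(pred.argmax(dim=-1).numel())  # 24576 -- wrong axis, the count in your second error
print(pred.argmax(dim=1).numel())   # 98304 -- matches targ.numel()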

Yes, this worked!! I was also able to add the Jaccard metric with metrics=[Jaccard(axis=1, average='micro'), partial(accuracy, axis=1)]. I was able to use this on my own dataset as well.
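
In case it helps future readers, here is the full learner setup that worked for me (everything else as in my original snippet):

from functools import partial

# point both metrics at the class dimension (dim 1)
learn = unet_learner(dls, resnet34,
                     metrics=[Jaccard(axis=1, average='micro'),
                              partial(accuracy, axis=1)])
learn.fine_tune(6)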

Thank you so much!
