I’m trying to add metrics to a U-Net model created using the `unet_learner()` method.
I used the stock CamVid example but keep getting this `AssertionError`. I’ve tried `accuracy_multi`, `JaccardMulti`, and `F1ScoreMulti` without success.
from fastai.vision.all import *
# Minimal reproduction: fetch CAMVID_TINY, build segmentation DataLoaders,
# then fine-tune a U-Net learner with a metric attached.
path = untar_data(URLs.CAMVID_TINY)
codes = np.loadtxt(path/'codes.txt', dtype=str)
fnames = get_image_files(path/"images")


def label_func(fn):
    "Return the mask path for image `fn`: `labels/<stem>_P<suffix>` under `path`."
    mask_name = f"{fn.stem}_P{fn.suffix}"
    return path/"labels"/mask_name


dls = SegmentationDataLoaders.from_label_func(
    path,
    fnames=fnames,
    label_func=label_func,
    codes=codes,
    bs=8,
)

# Passing `metrics` here is what leads to the AssertionError in the traceback
# below; the same two calls without `metrics` train normally.
# NOTE(review): the failing check reports 3145728 vs 98304 elements (a factor
# of 32) -- presumably the prediction has a per-class dimension that the
# target mask lacks; confirm against the model output shape.
learn = unet_learner(dls, resnet34, metrics=[accuracy_multi])
learn.fine_tune(6)
Omitting the `metrics` argument removes the error and I can train normally. So, this works fine: `learn = unet_learner(dls, resnet34)` followed by `learn.fine_tune(6)`.
This is the error I get:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-8-1606e5742a18> in <module>
1 learn = unet_learner(dls, resnet34, metrics=[accuracy_multi])
----> 2 learn.fine_tune(6)
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs)
156 "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
157 self.freeze()
--> 158 self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
159 base_lr /= 2
160 self.unfreeze()
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
111 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
112 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
114
115 # Cell
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
219 self.opt.set_hypers(lr=self.lr if lr is None else lr)
220 self.n_epoch = n_epoch
--> 221 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
222
223 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
161
162 def _with_events(self, f, event_type, ex, final=noop):
--> 163 try: self(f'before_{event_type}'); f()
164 except ex: self(f'after_cancel_{event_type}')
165 self(f'after_{event_type}'); final()
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
210 for epoch in range(self.n_epoch):
211 self.epoch=epoch
--> 212 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
213
214 def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
161
162 def _with_events(self, f, event_type, ex, final=noop):
--> 163 try: self(f'before_{event_type}'); f()
164 except ex: self(f'after_cancel_{event_type}')
165 self(f'after_{event_type}'); final()
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
205 def _do_epoch(self):
206 self._do_epoch_train()
--> 207 self._do_epoch_validate()
208
209 def _do_fit(self):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
201 if dl is None: dl = self.dls[ds_idx]
202 self.dl = dl
--> 203 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
204
205 def _do_epoch(self):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
161
162 def _with_events(self, f, event_type, ex, final=noop):
--> 163 try: self(f'before_{event_type}'); f()
164 except ex: self(f'after_cancel_{event_type}')
165 self(f'after_{event_type}'); final()
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
167 def all_batches(self):
168 self.n_iter = len(self.dl)
--> 169 for o in enumerate(self.dl): self.one_batch(*o)
170
171 def _do_one_batch(self):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
192 b = self._set_device(b)
193 self._split(b)
--> 194 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
195
196 def _do_epoch_train(self):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
163 try: self(f'before_{event_type}'); f()
164 except ex: self(f'after_cancel_{event_type}')
--> 165 self(f'after_{event_type}'); final()
166
167 def all_batches(self):
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
139
140 def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141 def __call__(self, event_name): L(event_name).map(self._call_one)
142
143 def _call_one(self, event_name):
~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
152 def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
153
--> 154 def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
155 def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
156 def filter(self, f=noop, negate=False, gen=False, **kwargs):
~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
664 res = map(g, iterable)
665 if gen: return res
--> 666 return list(res)
667
668 # Cell
~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
649 if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
650 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651 return self.func(*fargs, **kwargs)
652
653 # Cell
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
143 def _call_one(self, event_name):
144 if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145 for cb in self.cbs.sorted('order'): cb(event_name)
146
147 def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
43 (self.run_valid and not getattr(self, 'training', False)))
44 res = None
---> 45 if self.run and _run: res = getattr(self, event_name, noop)()
46 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
47 return res
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in after_batch(self)
502 if len(self.yb) == 0: return
503 mets = self._train_mets if self.training else self._valid_mets
--> 504 for met in mets: met.accumulate(self.learn)
505 if not self.training: return
506 self.lrs.append(self.opt.hypers[-1]['lr'])
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/learner.py in accumulate(self, learn)
424 def accumulate(self, learn):
425 bs = find_bs(learn.yb)
--> 426 self.total += learn.to_detach(self.func(learn.pred, *learn.yb))*bs
427 self.count += bs
428 @property
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/metrics.py in accuracy_multi(inp, targ, thresh, sigmoid)
193 def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
194 "Compute accuracy when `inp` and `targ` are the same size."
--> 195 inp,targ = flatten_check(inp,targ)
196 if sigmoid: inp = inp.sigmoid()
197 return ((inp>thresh)==targ.bool()).float().mean()
~/.conda/envs/dl/lib/python3.8/site-packages/fastai/torch_core.py in flatten_check(inp, targ)
810 "Check that `out` and `targ` have the same number of elements and flatten them."
811 inp,targ = TensorBase(inp.contiguous()).view(-1),TensorBase(targ.contiguous()).view(-1)
--> 812 test_eq(len(inp), len(targ))
813 return inp,targ
~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test_eq(a, b)
34 def test_eq(a,b):
35 "`test` that `a==b`"
---> 36 test(a,b,equals, '==')
37
38 # Cell
~/.conda/envs/dl/lib/python3.8/site-packages/fastcore/test.py in test(a, b, cmp, cname)
24 "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
25 if cname is None: cname=cmp.__name__
---> 26 assert cmp(a,b),f"{cname}:\n{a}\n{b}"
27
28 # Cell
AssertionError: ==:
3145728
98304
I’m training on a cluster with the following output from `fastai.test_utils.show_install()`:
=== Software ===
python : 3.8.2
fastai : 2.4
fastcore : 1.3.20
fastprogress : 0.2.7
torch : 1.9.0
nvidia driver : 450.51
torch cuda : 10.2 / is available
torch cudnn : 7605 / is enabled
=== Hardware ===
nvidia gpus : 8
torch devices : 2
- gpu0 : Tesla K80
- gpu1 : Tesla K80
=== Environment ===
platform : Linux-3.10.0-1127.19.1.el7.x86_64-x86_64-with-glibc2.10
distro : #1 SMP Tue Aug 25 17:23:54 UTC 2020
conda env : dl
python : /home/usr/.conda/envs/dl/bin/python
sys.path :
/home/usr/.conda/envs/dl/lib/python38.zip
/home/usr/.conda/envs/dl/lib/python3.8
/home/usr/.conda/envs/dl/lib/python3.8/lib-dynload
/home/usr/.local/lib/python3.8/site-packages
/home/usr/.conda/envs/dl/lib/python3.8/site-packages