Issue with Calculating Per-Item Losses using Segmentation Models

I'm currently working on calculating per-item losses with a segmentation model, but I'm running into an issue when using the PyTorch segmentation models (smp) library instead of fastai's built-in unet_learner. Here's the code snippet I'm using:

dls = block.dataloaders(path/'seg_train/images', bs=4)
model = smp.UnetPlusPlus(
    encoder_name='efficientnet-b5',
    encoder_weights='imagenet',
    in_channels=3,
    classes=num_classes,
    activation=None
)
learn = Learner(dls, model, loss_func=criterion, metrics=[foreground_acc, DiceMulti()], opt_func=ranger).to_fp16()
learn.load('deeplabv3plus_efficientnetb5_augs_v4')

However, when I try to retrieve predictions and losses via learn.get_preds with with_loss=True, I encounter the following error:

probs, targs, losses = learn.get_preds(ds_idx=1, with_loss=True)


RuntimeError                              Traceback (most recent call last)
Cell In [5], line 1
----> 1 probs, targs, losses = learn.get_preds(ds_idx=1, with_loss=True)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:308, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    306 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    307 with ContextManagers(ctx_mgrs):
--> 308     self._do_epoch_validate(dl=dl)
    309     if act is None: act = getcallable(self.loss_func, 'activation')
    310     res = cb.all_tensors()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:244, in Learner._do_epoch_validate(self, ds_idx, dl)
    242 if dl is None: dl = self.dls[ds_idx]
    243 self.dl = dl
--> 244 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}'); f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}'); final()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:205, in Learner.all_batches(self)
    203 def all_batches(self):
    204     self.n_iter = len(self.dl)
--> 205     for o in enumerate(self.dl): self.one_batch(*o)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:235, in Learner.one_batch(self, i, b)
    233 b = self._set_device(b)
    234 self._split(b)
--> 235 self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:201, in Learner._with_events(self, f, event_type, ex, final)
    199 try: self(f'before_{event_type}'); f()
    200 except ex: self(f'after_cancel_{event_type}')
--> 201 self(f'after_{event_type}'); final()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:172, in Learner.__call__(self, event_name)
--> 172 def __call__(self, event_name): L(event_name).map(self._call_one)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/foundation.py:156, in L.map(self, f, *args, **kwargs)
--> 156 def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))

File ~/mambaforge/lib/python3.10/site-packages/fastcore/basics.py:840, in map_ex(iterable, f, gen, *args, **kwargs)
    838 res = map(g, iterable)
    839 if gen: return res
--> 840 return list(res)

File ~/mambaforge/lib/python3.10/site-packages/fastcore/basics.py:825, in bind.__call__(self, *args, **kwargs)
    823 if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    824 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 825 return self.func(*fargs, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:176, in Learner._call_one(self, event_name)
    174 def _call_one(self, event_name):
    175     if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 176     for cb in self.cbs.sorted('order'): cb(event_name)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/core.py:62, in Callback.__call__(self, event_name)
    60 try: res = getcallable(self, event_name)()
    61 except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
--> 62 except Exception as e: raise modify_exception(e, f'Exception occured in {self.__class__.__name__} when calling event {event_name}:\n\t{e.args[0]}', replace=True)
    63 if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
    64 return res

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/core.py:60, in Callback.__call__(self, event_name)
    58 res = None
    59 if self.run and _run:
--> 60     try: res = getcallable(self, event_name)()
    61     except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
    62     except Exception as e: raise modify_exception(e, f'Exception occured in {self.__class__.__name__} when calling event {event_name}:\n\t{e.args[0]}', replace=True)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/core.py:149, in GatherPredsCallback.after_batch(self)
    147 if self.with_loss:
    148     bs = find_bs(self.yb)
--> 149     loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
    150     self.losses.append(self.learn.to_detach(loss))

File ~/mambaforge/lib/python3.10/site-packages/fastai/torch_core.py:382, in TensorBase.__torch_function__(cls, func, types, args, kwargs)
    380 if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, args, kwargs)
    381 if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)
--> 382 res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
    383 dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))
    384 if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_copy=True)

File ~/mambaforge/lib/python3.10/site-packages/torch/_tensor.py:1278, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1275     return NotImplemented
   1277 with _C.DisableTorchFunction():
-> 1278     ret = func(*args, **kwargs)
   1279 if func in get_default_nowrap_functions():
   1280     return ret

RuntimeError: Exception occured in GatherPredsCallback when calling event after_batch:
shape '[4, -1]' is invalid for input of size 1

I would greatly appreciate any insights into resolving this issue. Thank you in advance for your help!

Warm regards,
Bilal

I've made progress in resolving the issue I encountered with my custom loss function (criterion) in my fastai segmentation model. I realised that the problem lies in the custom loss function I used to train the model. Interestingly, when I switch the learner's loss_func to FocalLossFlat(axis=1), the code works and generates per-item losses as expected.

Here’s a snippet of the relevant code defining the custom loss:

TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)

def tversky_loss(y_pred, y_true):
    y_true = F.one_hot(y_true, num_classes=num_classes).permute(0, 3, 1, 2).float()
    return TverskyLoss(y_pred, y_true)

def criterion(y_pred, y_true, **kwargs):
    return 0.5*FocalLossFlat(axis=1)(y_pred, y_true) + 0.5*tversky_loss(y_pred, y_true)
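For context, the error message matches this diagnosis: both components of criterion return an already-reduced scalar, so when fastai tries to reshape the loss into one value per item, the reshape fails. A minimal sketch of the failing step in plain PyTorch (no fastai or smp needed):

```python
import torch

# GatherPredsCallback computes per-item losses via loss.view(bs, -1).mean(1),
# which assumes reduction='none' left one loss value per element. A loss that
# always reduces to a scalar (numel() == 1) cannot be reshaped to (4, -1):
scalar_loss = torch.tensor(0.42)  # stand-in for criterion's reduced output
bs = 4
try:
    scalar_loss.view(bs, -1)
except RuntimeError as e:
    print(e)  # shape '[4, -1]' is invalid for input of size 1
```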

Any guidance on making the custom loss function work with learn.get_preds(ds_idx=0, with_loss=True) would be greatly appreciated.

Thanks in advance.

fastai losses are more complex than plain PyTorch loss functions. Search for BaseLoss here and in the docs for more info.

What I just found in the docs: "Note: If you want to use the option with_loss=True on a custom loss function, make sure you have implemented a reduction attribute that supports 'none'"

Source: fastai - Learner, Metrics, Callbacks

I think smp does not have that.
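To illustrate what that docs note means: when the loss honours reduction='none', it returns one value per element, and fastai then averages per item. A sketch of that contract using torch's cross_entropy as a stand-in (smp losses reduce internally, so they don't support this; the sizes below are illustrative, not from the thread):

```python
import torch
import torch.nn.functional as F

bs, num_classes, h, w = 4, 3, 8, 8  # illustrative sizes
y_pred = torch.randn(bs, num_classes, h, w)
y_true = torch.randint(0, num_classes, (bs, h, w))

# reduction='none' keeps one loss value per pixel...
raw = F.cross_entropy(y_pred, y_true, reduction='none')  # shape (4, 8, 8)
# ...which fastai's GatherPredsCallback averages into one loss per item:
per_item = raw.view(bs, -1).mean(1)                      # shape (4,)
print(per_item.shape)  # torch.Size([4])
```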

Thank you, @Archaeologist, for your insights. I’ve delved into the BaseLoss, and it has certainly enriched my understanding of fastai’s loss functions.

I'm considering the idea of creating a wrapper around the PyTorch segmentation loss, integrating the necessary reduction attribute so that the with_loss=True option works with get_preds(). What are your thoughts on this approach?
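One possible shape for such a wrapper, as a sketch rather than a tested solution (PerItemLossWrapper and the per-sample loop are my own invention, not fastai or smp API; it emulates reduction='none' by calling the scalar-only base loss once per sample):

```python
import torch

class PerItemLossWrapper:
    """Give a scalar-only loss a `reduction` attribute that supports 'none'.

    fastai's with_loss=True path temporarily sets loss_func.reduction to
    'none' and expects one loss value per item. This sketch emulates that
    by calling the wrapped loss (assumed to take (pred, targ) and return
    a 0-dim tensor) once per sample. Slower than a natively unreduced
    loss, but it yields correct per-item values.
    """
    def __init__(self, base_loss):
        self.base_loss = base_loss
        self.reduction = 'mean'  # fastai toggles this to 'none' as needed

    def __call__(self, y_pred, y_true):
        # one scalar per sample, stacked into a (bs,) tensor
        per_item = torch.stack([self.base_loss(p[None], t[None])
                                for p, t in zip(y_pred, y_true)])
        if self.reduction == 'none':
            return per_item
        return per_item.sum() if self.reduction == 'sum' else per_item.mean()
```

Usage would then be something like `Learner(dls, model, loss_func=PerItemLossWrapper(criterion), ...)`, at the cost of bs extra forward passes through the loss.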

This repository seems to already have a fastai loss wrapper. Have a look if you like:

That is incredible. Thanks a lot for sharing the link. Much appreciated.
