Image segmentation unet: Expected input batch_size (3584) to match target batch_size (401408)

I am working on a segmentation task for medical images. I have a set of training MR images (QSM) along with ground truth segmentations.

Since these are retrieved from 3D NIfTI files rather than bitmap image files, I have written my own PyTorch dataset class for this:

class QSM_2D_With_Seg(torch.utils.data.Dataset):
    def __init__(self, sample_details, transform=None):
        self.sample_details = sample_details
        self.transform = transform
        self.c = 2

    def __len__(self):
        return len(self.sample_details)

    def __getitem__(self, idx):
        # convert idx to list if tensor
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # convert idx to image and slice numbers
        qsm_path, seg_path, slice_id = self.sample_details[idx]

        # load data and scale from estimated range of -8,+8 to 0,1
        qsm = (nib.load(qsm_path).get_fdata()[:,:,int(slice_id)]+8) / (8*2)
        seg = (nib.load(seg_path).get_fdata()[:,:,int(slice_id)])

        # resize images to common size
        qsm = torch.Tensor(cv2.resize(qsm, dsize=(224, 224)))
        seg = torch.Tensor(cv2.resize(seg, dsize=(224, 224)))

        # expand QSM over 3 channels so it works with RGB models
        qsm = qsm.expand(3, 224, 224)
        
        # apply any necessary transformations
        if self.transform:
            qsm = self.transform(qsm)
            seg = self.transform(seg)

        return fastai.torch_core.TensorImage(qsm), fastai.torch_core.TensorMask(seg)

    def __iter__(self):
        for idx in range(len(self.sample_details)):
            yield self.__getitem__(idx)

I later create a dataloaders object and learner:

# create dataloaders from datasets
train_ds = QSM_2D_With_Seg(train_samples)
valid_ds = QSM_2D_With_Seg(valid_samples)
dls = fastai.data.core.DataLoaders.from_dsets(train_ds, valid_ds, batch_size=8, device='cuda:0')

# build a unet learner from dls and arch
learn = fastai.vision.learner.unet_learner(
    dls=dls,                            # data loader
    arch=fastai.vision.models.resnet34, # model architecture
    loss_func=fastai.losses.CrossEntropyLossFlat(), # loss function for evaluation during training
    model_dir='models',                  # save directory for trained model
)

But, when I try to train the model or get some dummy predictions, I get the error from the title:

inp, pred, target = learn.get_preds(with_input = True)


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/clusterdata/uqaste15/data/prostate/prostate.ipynb Cell 19 in ()
----> 1 inp, pred, target = learn.get_preds(with_input = True)

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:243, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    241 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    242 with ContextManagers(ctx_mgrs):
--> 243     self._do_epoch_validate(dl=dl)
    244     if act is None: act = getattr(self.loss_func, 'activation', noop)
    245     res = cb.all_tensors()

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:193, in Learner._do_epoch_validate(self, ds_idx, dl)
    191 if dl is None: dl = self.dls[ds_idx]
    192 self.dl = dl
--> 193 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:160, in Learner._with_events(self, f, event_type, ex, final)
    159 def _with_events(self, f, event_type, ex, final=noop):
--> 160     try: self(f'before_{event_type}');  f()
    161     except ex: self(f'after_cancel_{event_type}')
    162     self(f'after_{event_type}');  final()

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:166, in Learner.all_batches(self)
    164 def all_batches(self):
    165     self.n_iter = len(self.dl)
--> 166     for o in enumerate(self.dl): self.one_batch(*o)

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:184, in Learner.one_batch(self, i, b)
    182 self.iter = i
    183 self._split(b)
--> 184 self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:160, in Learner._with_events(self, f, event_type, ex, final)
    159 def _with_events(self, f, event_type, ex, final=noop):
--> 160     try: self(f'before_{event_type}');  f()
    161     except ex: self(f'after_cancel_{event_type}')
    162     self(f'after_{event_type}');  final()

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:172, in Learner._do_one_batch(self)
    170 self('after_pred')
    171 if len(self.yb):
--> 172     self.loss_grad = self.loss_func(self.pred, *self.yb)
    173     self.loss = self.loss_grad.clone()
    174 self('after_loss')

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/losses.py:35, in BaseLoss.__call__(self, inp, targ, **kwargs)
     33 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     34 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/modules/module.py:727, in Module._call_impl(self, *input, **kwargs)
    725     result = self._slow_forward(*input, **kwargs)
    726 else:
--> 727     result = self.forward(*input, **kwargs)
    728 for hook in itertools.chain(
    729         _global_forward_hooks.values(),
    730         self._forward_hooks.values()):
    731     hook_result = hook(self, input, result)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/modules/loss.py:961, in CrossEntropyLoss.forward(self, input, target)
    960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961     return F.cross_entropy(input, target, weight=self.weight,
    962                            ignore_index=self.ignore_index, reduction=self.reduction)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2462, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2460     tens_ops = (input, target)
   2461     if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462         return handle_torch_function(
   2463             cross_entropy, tens_ops, input, target, weight=weight,
   2464             size_average=size_average, ignore_index=ignore_index, reduce=reduce,
   2465             reduction=reduction)
   2466 if size_average is not None or reduce is not None:
   2467     reduction = _Reduction.legacy_get_string(size_average, reduce)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/overrides.py:1063, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1059 # Call overrides
   1060 for overloaded_arg in overloaded_args:
   1061     # Use `public_api` instead of `implementation` so __torch_function__
   1062     # implementations can do equality/identity comparisons.
-> 1063     result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1065     if result is not NotImplemented:
   1066         return result

File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/torch_core.py:325, in TensorBase.__torch_function__(self, func, types, args, kwargs)
    323 convert=False
    324 if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325 res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    326 if convert: res = convert(res)
    327 if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/tensor.py:995, in Tensor.__torch_function__(cls, func, types, args, kwargs)
    992     return NotImplemented
    994 with _C.DisableTorchFunction():
--> 995     ret = func(*args, **kwargs)
    996     return _convert(ret, cls)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2468, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466 if size_average is not None or reduce is not None:
   2467     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2261, in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2258     raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
   2260 if input.size(0) != target.size(0):
-> 2261     raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
   2262                      .format(input.size(0), target.size(0)))
   2263 if dim == 2:
   2264     ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (3584) to match target batch_size (401408).

I include the full notebook and outputs here.

Any idea what could be causing this or how I could fix it? Any help would be greatly appreciated!

Hey,
those mismatches in batch sizes most likely come from problems with the loss function.

I’ve done a quick investigation and in the docs segmentation example, if you use the infered loss function it has the axis-value of Cross Entropy set to 1 instead of the default -1.
This also matches the shapes the error returns since 3584 = 8*2*224, missing the last channel (axis = -1) and 401.408=8*224*224 missing the second channel (expecting axis=1). So maybe

learn = fastai.vision.learner.unet_learner(
    dls=dls,                            # data loader
    arch=fastai.vision.models.resnet34, # model architecture
    loss_func=fastai.losses.CrossEntropyLossFlat(axis=1), # loss function for evaluation during training
    model_dir='models',                  # save directory for trained model
)

gets you going.
Please come back if that didn’t work or you need further help, this seems like an amazing project, hope you have a great impact with that!

2 Likes

Ah, thank you so much for your assistance! Your advice progressed me to another error about the data type of the target being Float when it expected Long, but this is because it was resizing the segmentation without using nearest neighbours interpolation. After fixing this next step, the problem is solved. :slight_smile:

1 Like