I am working on a segmentation task for medical images. I have a set of training MR images (QSM) along with ground truth segmentations.
Since these are retrieved from 3D NIfTI files rather than bitmap image files, I have written my own PyTorch dataset class for this:
class QSM_2D_With_Seg(torch.utils.data.Dataset):
def __init__(self, sample_details, transform=None):
self.sample_details = sample_details
self.transform = transform
self.c = 2
def __len__(self):
return len(self.sample_details)
def __getitem__(self, idx):
# convert idx to list if tensor
if torch.is_tensor(idx):
idx = idx.tolist()
# convert idx to image and slice numbers
qsm_path, seg_path, slice_id = self.sample_details[idx]
# load data and scale from estimated range of -8,+8 to 0,1
qsm = (nib.load(qsm_path).get_fdata()[:,:,int(slice_id)]+8) / (8*2)
seg = (nib.load(seg_path).get_fdata()[:,:,int(slice_id)])
# resize images to common size
qsm = torch.Tensor(cv2.resize(qsm, dsize=(224, 224)))
seg = torch.Tensor(cv2.resize(seg, dsize=(224, 224)))
# expand QSM over 3 channels so it works with RGB models
qsm = qsm.expand(3, 224, 224)
# apply any necessary transformations
if self.transform:
qsm = self.transform(qsm)
seg = self.transform(seg)
return fastai.torch_core.TensorImage(qsm), fastai.torch_core.TensorMask(seg)
def __iter__(self):
for idx in range(len(self.sample_details)):
yield self.__getitem__(idx)
I later create a dataloaders object and learner:
# create dataloaders from datasets
train_ds = QSM_2D_With_Seg(train_samples)
valid_ds = QSM_2D_With_Seg(valid_samples)
dls = fastai.data.core.DataLoaders.from_dsets(train_ds, valid_ds, batch_size=8, device='cuda:0')
# build a unet learner from dls and arch
learn = fastai.vision.learner.unet_learner(
dls=dls, # data loader
arch=fastai.vision.models.resnet34, # model architecture
loss_func=fastai.losses.CrossEntropyLossFlat(), # loss function for evaluation during training
model_dir='models', # save directory for trained model
)
But, when I try to train the model or get some dummy predictions, I get the error from the title:
inp, pred, target = learn.get_preds(with_input = True)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/clusterdata/uqaste15/data/prostate/prostate.ipynb Cell 19 in ()
----> 1 inp, pred, target = learn.get_preds(with_input = True)
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:243, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
241 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
242 with ContextManagers(ctx_mgrs):
--> 243 self._do_epoch_validate(dl=dl)
244 if act is None: act = getattr(self.loss_func, 'activation', noop)
245 res = cb.all_tensors()
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:193, in Learner._do_epoch_validate(self, ds_idx, dl)
191 if dl is None: dl = self.dls[ds_idx]
192 self.dl = dl
--> 193 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:160, in Learner._with_events(self, f, event_type, ex, final)
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:166, in Learner.all_batches(self)
164 def all_batches(self):
165 self.n_iter = len(self.dl)
--> 166 for o in enumerate(self.dl): self.one_batch(*o)
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:184, in Learner.one_batch(self, i, b)
182 self.iter = i
183 self._split(b)
--> 184 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:160, in Learner._with_events(self, f, event_type, ex, final)
159 def _with_events(self, f, event_type, ex, final=noop):
--> 160 try: self(f'before_{event_type}'); f()
161 except ex: self(f'after_cancel_{event_type}')
162 self(f'after_{event_type}'); final()
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/learner.py:172, in Learner._do_one_batch(self)
170 self('after_pred')
171 if len(self.yb):
--> 172 self.loss_grad = self.loss_func(self.pred, *self.yb)
173 self.loss = self.loss_grad.clone()
174 self('after_loss')
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/losses.py:35, in BaseLoss.__call__(self, inp, targ, **kwargs)
33 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
34 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/modules/module.py:727, in Module._call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
730 self._forward_hooks.values()):
731 hook_result = hook(self, input, result)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/modules/loss.py:961, in CrossEntropyLoss.forward(self, input, target)
960 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961 return F.cross_entropy(input, target, weight=self.weight,
962 ignore_index=self.ignore_index, reduction=self.reduction)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2462, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2460 tens_ops = (input, target)
2461 if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462 return handle_torch_function(
2463 cross_entropy, tens_ops, input, target, weight=weight,
2464 size_average=size_average, ignore_index=ignore_index, reduce=reduce,
2465 reduction=reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/overrides.py:1063, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1059 # Call overrides
1060 for overloaded_arg in overloaded_args:
1061 # Use `public_api` instead of `implementation` so __torch_function__
1062 # implementations can do equality/identity comparisons.
-> 1063 result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
1065 if result is not NotImplemented:
1066 return result
File ~/.local/lib/python3.8/site-packages/fastai-2.2.5-py3.8.egg/fastai/torch_core.py:325, in TensorBase.__torch_function__(self, func, types, args, kwargs)
323 convert=False
324 if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325 res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
326 if convert: res = convert(res)
327 if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/tensor.py:995, in Tensor.__torch_function__(cls, func, types, args, kwargs)
992 return NotImplemented
994 with _C.DisableTorchFunction():
--> 995 ret = func(*args, **kwargs)
996 return _convert(ret, cls)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2468, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2466 if size_average is not None or reduce is not None:
2467 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File ~/.local/lib/python3.8/site-packages/torch-1.7.0-py3.8-linux-x86_64.egg/torch/nn/functional.py:2261, in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2258 raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
2260 if input.size(0) != target.size(0):
-> 2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (3584) to match target batch_size (401408).
I include the full notebook and outputs here.
Any idea what could be causing this or how I could fix it? Any help would be greatly appreciated!