Input/Target batch size error for Multi-Class Classification with Custom Dataset

I implemented a custom Dataset for multi-class classification, but I get this error:

ValueError: Expected input batch_size (64) to match target batch_size (640).

I noticed that the target batch size is off by a factor of the number of classes (640 / 64 = 10), so I'm assuming I'm missing something obvious in the __getitem__ function. I don't know how to diagnose the problem or how to tweak the returns from my custom Dataset class. Any recommendations for how to proceed?

Here are some other data points. In the custom Dataset class:
.n_inp: 1
.c: 10
__getitem__ return: (torch.Size([3, 8, 8]), torch.Size([1, 10]))
Label: torch.Size([1, 10])

Other:
batch size: 64
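
These shapes came from indexing the Dataset directly; a minimal sketch of that check:

item = data[0]
print(item[0].shape)  # torch.Size([3, 8, 8])
print(item[1].shape)  # torch.Size([1, 10])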

Code snippets:

class MyDataset(Dataset):
    """My dataset."""

    def __init__(self, csv_file):
        """
        Args:
            csv_file (string): Path to the csv file with annotations and FEN.
        """
        self.df = pd.read_csv(csv_file)
        
        # create the labels
        mlb = MultiLabelBinarizer()
        res = pd.DataFrame(mlb.fit_transform(self.df.Labels.apply(lambda x: x.split(' ')).tolist()),
                          columns=mlb.classes_,
                          index=self.df.index)
        self.df = pd.concat([self.df, res], axis=1)
        self.c = len(mlb.classes_)
        self.classes = mlb.classes_
        self.n_inp = 1

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = createImg(self.df.iloc[idx]['inputString'])
        output = transforms.ToTensor()(img).unsqueeze_(0).squeeze().float()
        labels = torch.as_tensor(self.df.iloc[idx][self.classes].tolist()).unsqueeze_(0)
        
        return (output, labels)

data = MyDataset(dtapath)

train_loader = DataLoader(data, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(data, batch_size=batch_size, sampler=valid_sampler)

dls = DataLoaders(train_loader, validation_loader, device='cpu')


learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat(),
                    n_out=len(validThemes), normalize=False,
                    metrics=[accuracy_multi])
learn.lr_find()

Full stacktrace below:

ValueError                                Traceback (most recent call last)
<ipython-input-136-d81c6bd29d71> in <module>()
----> 1 learn.lr_find()

/usr/local/lib/python3.6/dist-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
    220     n_epoch = num_it//len(self.dls.train) + 1
    221     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    223     if show_plot: self.recorder.plot_lr_find()
    224     if suggestions:

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    209             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    210             self.n_epoch = n_epoch
--> 211             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    212 
    213     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_fit(self)
    200         for epoch in range(self.n_epoch):
    201             self.epoch=epoch
--> 202             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    203 
    204     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch(self)
    194 
    195     def _do_epoch(self):
--> 196         self._do_epoch_train()
    197         self._do_epoch_validate()
    198 

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_epoch_train(self)
    186     def _do_epoch_train(self):
    187         self.dl = self.dls.train
--> 188         self._with_events(self.all_batches, 'train', CancelTrainException)
    189 
    190     def _do_epoch_validate(self, ds_idx=1, dl=None):

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in all_batches(self)
    164     def all_batches(self):
    165         self.n_iter = len(self.dl)
--> 166         for o in enumerate(self.dl): self.one_batch(*o)
    167 
    168     def _do_one_batch(self):

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in one_batch(self, i, b)
    182         self.iter = i
    183         self._split(b)
--> 184         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    185 
    186     def _do_epoch_train(self):

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

/usr/local/lib/python3.6/dist-packages/fastai/learner.py in _do_one_batch(self)
    170         self('after_pred')
    171         if len(self.yb):
--> 172             self.loss_grad = self.loss_func(self.pred, *self.yb)
    173             self.loss = self.loss_grad.clone()
    174         self('after_loss')

/usr/local/lib/python3.6/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     36 
     37 # Cell

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    960     def forward(self, input: Tensor, target: Tensor) -> Tensor:
    961         return F.cross_entropy(input, target, weight=self.weight,
--> 962                                ignore_index=self.ignore_index, reduction=self.reduction)
    963 
    964 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2463                 cross_entropy, tens_ops, input, target, weight=weight,
   2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,
-> 2465                 reduction=reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)

/usr/local/lib/python3.6/dist-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1061         # Use `public_api` instead of `implementation` so __torch_function__
   1062         # implementations can do equality/identity comparisons.
-> 1063         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1064 
   1065         if result is not NotImplemented:

/usr/local/lib/python3.6/dist-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    323         convert=False
    324         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    326         if convert: res = convert(res)
    327         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

/usr/local/lib/python3.6/dist-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
    993 
    994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
    996             return _convert(ret, cls)
    997 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2260     if input.size(0) != target.size(0):
   2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
   2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (64) to match target batch_size (640).

Could you try retrieving a single batch from your dataloader and checking the shapes before training? It would also be good to call show_batch to make sure everything looks okay to you.

Check the example below.
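
A minimal sketch of that check, assuming the train_loader and dls from the snippets above:

# pull one batch straight from the DataLoader and inspect the shapes
xb, yb = next(iter(train_loader))
print(xb.shape)  # expect torch.Size([64, 3, 8, 8])
print(yb.shape)  # if this is torch.Size([64, 1, 10]), the targ.view(-1) inside
                 # CrossEntropyLossFlat will flatten it to 640 targets

# visual sanity check of a batch
dls.show_batch()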

64 (intended batch size) x 10 (# of classes) = 640. So it appears as if the ys are not being batched correctly.
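
Since the labels come from MultiLabelBinarizer, each target is a binarized (multi-hot) vector, which is really a multi-label setup rather than single-label multi-class. A hedged sketch of one fix, assuming multi-label is the intent: return a flat float label vector from __getitem__ and swap CrossEntropyLossFlat for BCEWithLogitsLossFlat, which expects targets shaped [batch, n_classes] and pairs naturally with accuracy_multi:

# in __getitem__: drop the unsqueeze_ so each label is shape [10], not [1, 10]
labels = torch.as_tensor(self.df.iloc[idx][self.classes].tolist()).float()

# multi-label loss: logits and targets are both [batch_size, n_classes]
learn = cnn_learner(dls, resnet50, loss_func=BCEWithLogitsLossFlat(),
                    n_out=len(validThemes), normalize=False,
                    metrics=[accuracy_multi])

If each image really does have exactly one class, keep CrossEntropyLossFlat but return a single class index (a scalar long tensor) per item instead of a multi-hot vector.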

While show_batch had reasonable-looking dimensions, checking the shapes directly revealed the mismatch. Thanks for sharing this article!