Help with Error: input batch size does not match target batch size

Hello all, I have done some searching, and it seems most of the threads on this forum about similar issues never reached a solution.

I am trying to create an image segmentation model using .jpg files for both the input images and their corresponding masks. Below is the code that I am using to try to get something working:

from fastai.vision.all import *


def get_masks(img_path):
    # Map an image path like data/IMAGES/<name>.jpg to its mask at data/MASKS/CM-<name>.jpg
    return Path(f"data/MASKS/CM-{img_path.stem}{img_path.suffix}")


def main():
    path = Path(f'data/IMAGES')
    codes = ['n', 'y']
    block = (ImageBlock, MaskBlock(codes))
    dblock = DataBlock(blocks=block,
                       get_items=get_image_files,
                       get_y=get_masks,
                       splitter=RandomSplitter(),
                       item_tfms=Resize(192),
                       batch_tfms=aug_transforms())

    dls = dblock.dataloaders(path)
    learn = vision_learner(dls, resnet34, metrics=accuracy)
    learn.fine_tune(3)


if __name__ == "__main__":
    main()

Here is the error that I get when running this code:

epoch     train_loss  valid_loss  accuracy  time    
Traceback (most recent call last):
  File "/Users/adamslay/Documents/dev/practice/img_seg/fast.py", line 25, in <module>
    main()
  File "/Users/adamslay/Documents/dev/practice/img_seg/fast.py", line 21, in main
    learn.fine_tune(3)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py", line 165, in fine_tune
    self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py", line 119, in fit_one_cycle
    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 256, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 193, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 245, in _do_fit
    self._with_events(self._do_epoch, 'epoch', CancelEpochException)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 193, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 239, in _do_epoch
    self._do_epoch_train()
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 231, in _do_epoch_train
    self._with_events(self.all_batches, 'train', CancelTrainException)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 193, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 199, in all_batches
    for o in enumerate(self.dl): self.one_batch(*o)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 227, in one_batch
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 193, in _with_events
    try: self(f'before_{event_type}');  f()
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/learner.py", line 208, in _do_one_batch
    self.loss_grad = self.loss_func(self.pred, *self.yb)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/losses.py", line 54, in __call__
    return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1150, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2832, in cross_entropy
    return handle_torch_function(
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/overrides.py", line 1355, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/fastai/torch_core.py", line 378, in __torch_function__
    res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/_tensor.py", line 1051, in __torch_function__
    ret = func(*args, **kwargs)
  File "/Users/adamslay/opt/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (64) to match target batch_size (2359296).

Process finished with exit code 1

Here is an example of a .jpg from the IMAGES directory:

and a .jpg from the MASKS directory:

Any help regarding this issue would be appreciated.

Still working on this. Could anyone help me understand what the “target batch_size” refers to and where that number comes from? I can’t work out where it originates, so I’m not sure how to reconcile it with the “input batch_size”.

Usually you would use a U-Net for segmentation tasks; see the segmentation tutorial in the fastai docs. Right now you are using a standard image-classification model, which has an output layer of size len(codes) (one prediction per image), and you are trying to compare that output to 192×192 masks. The loss function flattens the target, so your batch of 64 masks becomes 64 × 192 × 192 = 2,359,296 target values, which is being compared against only 64 predictions.
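A minimal sketch of that change, assuming dls is built with the same DataBlock as in your post and that unet_learner's defaults are acceptable for your data, would look something like this:

from fastai.vision.all import *

# dls is built exactly as in the original post:
# dls = dblock.dataloaders(path)

# Sanity check: the target is one class index per pixel, so flattening it
# gives batch_size * height * width elements (64 * 192 * 192 = 2,359,296 here).
xb, yb = dls.one_batch()
print(xb.shape, yb.shape)  # e.g. torch.Size([64, 3, 192, 192]) torch.Size([64, 192, 192])

# unet_learner builds a U-Net on top of the resnet34 backbone, so the model
# predicts a class for every pixel ([batch, len(codes), 192, 192]) instead of
# one class per image ([batch, len(codes)]), matching the mask targets.
learn = unet_learner(dls, resnet34)
learn.fine_tune(3)

The plain accuracy metric is dropped here because it is meant for single-label classification; for segmentation you would typically use a per-pixel or Dice-style metric instead.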