test_dl throws "expected 4D input (got 3D input)" while the original dl batches correctly

Hi,

I am receiving the following error when trying to run get_preds.

# NOTE(review): with with_decoded=True, get_preds returns three values here; for an
# unlabelled test_dl the second value is presumably the (absent) targets, not predictions —
# confirm against the fastai Learner.get_preds docs before relying on `pred`.
probs, pred, idxs = learn.get_preds(dl=test_dl, with_decoded=True, act=activation_function)

Error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [48], in <cell line: 2>()
      1 #Expecting 4d got 3d, need another dimension for the batch size
----> 2 probs, pred, idxs = learn.get_preds(dl=test_dl, with_decoded=True, act=activation_function)

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:300, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    298 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    299 with ContextManagers(ctx_mgrs):
--> 300     self._do_epoch_validate(dl=dl)
    301     if act is None: act = getcallable(self.loss_func, 'activation')
    302     res = cb.all_tensors()

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:236, in Learner._do_epoch_validate(self, ds_idx, dl)
    234 if dl is None: dl = self.dls[ds_idx]
    235 self.dl = dl
--> 236 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:199, in Learner.all_batches(self)
    197 def all_batches(self):
    198     self.n_iter = len(self.dl)
--> 199     for o in enumerate(self.dl): self.one_batch(*o)

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:227, in Learner.one_batch(self, i, b)
    225 b = self._set_device(b)
    226 self._split(b)
--> 227 self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
    192 def _with_events(self, f, event_type, ex, final=noop):
--> 193     try: self(f'before_{event_type}');  f()
    194     except ex: self(f'after_cancel_{event_type}')
    195     self(f'after_{event_type}');  final()

File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:205, in Learner._do_one_batch(self)
    204 def _do_one_batch(self):
--> 205     self.pred = self.model(*self.xb)
    206     self('after_pred')
    207     if len(self.yb):

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.9/dist-packages/fastai/layers.py:408, in SequentialEx.forward(self, x)
    406 for l in self.layers:
    407     res.orig = x
--> 408     nres = l(res)
    409     # We have to remove res.orig to avoid hanging refs and therefore memory leaks
    410     res.orig, nres.orig = None, None

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py:139, in Sequential.forward(self, input)
    137 def forward(self, input):
    138     for module in self:
--> 139         input = module(input)
    140     return input

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/batchnorm.py:135, in _BatchNorm.forward(self, input)
    134 def forward(self, input: Tensor) -> Tensor:
--> 135     self._check_input_dim(input)
    137     # exponential_average_factor is set to self.momentum
    138     # (when it is available) only so that it gets updated
    139     # in ONNX graph when this node is exported to ONNX.
    140     if self.momentum is None:

File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/batchnorm.py:407, in BatchNorm2d._check_input_dim(self, input)
    405 def _check_input_dim(self, input):
    406     if input.dim() != 4:
--> 407         raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

ValueError: expected 4D input (got 3D input)

I have looked at the following and see the discrepancy.
Training dataloader:

# Print the shape of every element of one training batch (inputs, then targets).
batch = dl.one_batch()
for part in batch:
    print(part.shape)

Output:

torch.Size([4, 8, 256, 256])
torch.Size([4, 256, 256])

After successfully training the unet_learner,
I create the test_dl from the learner's DataLoaders:

# Build an unlabelled DataLoader over the test files, reusing the training pipeline.
test_dl = learn.dls.test_dl(test_files, bs=4)

# Print the shape of every element of one test batch.
test_batch = test_dl.one_batch()
for part in test_batch:
    print(part.shape)

Output:

torch.Size([8, 256, 256])
torch.Size([8, 256, 256])
torch.Size([8, 256, 256])
torch.Size([8, 256, 256])

I have attempted to set create_batch in the test_dl like so:

def test_batch(xs):
    """Collate a list of samples into one batch by stacking along a new dim 0.

    ``xs`` may hold bare tensors, in which case a single stacked tensor is
    returned. It may also hold tuples of tensors, such as ``(x,)`` items
    from an unlabelled test set; each field is then stacked separately and
    a tuple of batched tensors is returned.

    This is a ``create_batch``-style collate function. Passing it as
    ``before_batch`` instead hands fastai's ``fa_collate`` an already-stacked
    tensor, which fails with ``TypeError: stack(): argument 'tensors' ...
    must be tuple of Tensors, not TensorImage``.
    """
    if len(xs) and isinstance(xs[0], tuple):
        # Stack field-by-field so (x,) / (x, y) items become (X,) / (X, Y).
        return tuple(torch.stack(field) for field in zip(*xs))
    return torch.stack(xs)


# The name starts with ``test_``; stop pytest from collecting it as a test case.
test_batch.__test__ = False

# Attempted workaround: override create_batch with the stacking helper.
# Per the post, this still raised the same "expected 4D input (got 3D input)" error.
test_dl = learn.dls.test_dl(test_files[:6], bs=4, verbose=True, create_batch=test_batch)

But I end up with the same error message. I’m not sure why it’s not batching the test results up.

Here’s the code for the dataloader and learner.

class SmashTransform(ItemTransform):
    """Merge an item's sequence of input tensors into a single input tensor.

    As an ``ItemTransform``, ``encodes`` receives the whole item. Its first
    element is a sequence of tensors; these are concatenated along dim 0
    (the channel axis, judging by the per-item ``[C, H, W]`` shapes in the
    batches). Any remaining elements, such as the mask on labelled items,
    are passed through unchanged.
    """

    def encodes(self, inp):
        """Return ``inp`` with its first element replaced by the concatenation.

        Always returns a tuple: ``(x, y)`` for labelled training/validation
        items and ``(x,)`` for unlabelled test items. Returning a bare tensor
        for the single-input case breaks batching in ``test_dl``.

        Raises:
            ValueError: if the first element contains no tensors.
        """
        xs = inp[0]
        if len(xs) == 0:
            raise ValueError("SmashTransform: no input tensors to concatenate")
        # One cat over all tensors instead of repeated pairwise concatenation.
        x_smash = torch.cat(tuple(xs))
        # Keep the item a tuple so unlabelled items collate like labelled ones.
        return (x_smash, *inp[1:])

def get_dataloader(file_set, months=MONTHS, bs=4):
    """Build fastai DataLoaders for the multi-month segmentation data.

    Args:
        file_set: Source items passed to ``DataBlock.dataloaders``.
        months: Forwarded to ``get_all_files_month`` to choose the input files.
        bs: Batch size (default 4).

    Returns:
        The ``DataLoaders`` produced by ``DataBlock.dataloaders``. The
        ``RandomSplitter`` puts 20% of items in the validation set (seed 42).
    """
    blocks = (
        # Input block: per-month inputs, merged into one tensor by SmashTransform.
        TransformBlock([partial(get_all_files_month, months=months)]),
        # Target block: label_func, then open the result as a TensorMask.
        TransformBlock([label_func, partial(open_tif, cls=TensorMask)]),
    )
    db = DataBlock(
        blocks=blocks,
        splitter=RandomSplitter(valid_pct=0.2, seed=42),
        n_inp=1,
        item_tfms=[SmashTransform()],
    )
    # Build the DataLoaders directly; a separate db.datasets(...) call would
    # construct a second, unused Datasets over the same files.
    return db.dataloaders(source=file_set, bs=bs)

def get_learner(dl, arch=resnet34, no_months=6, bands_per_month=4):
    """Create a U-Net learner for single-channel (binary) mask prediction.

    Args:
        dl: DataLoaders to train on (e.g. from ``get_dataloader``).
        arch: Encoder architecture passed to ``unet_learner``.
        no_months: Number of months stacked into each input.
        bands_per_month: Channels contributed by each month (default 4).

    Returns:
        The learner from ``unet_learner`` with ``n_out=1``,
        ``BCEWithLogitsLossFlat`` loss and Recall/F1 metrics.
    """
    return unet_learner(
        dl,
        arch,
        # NOTE(review): must equal the channel count of the input batches
        # produced by SmashTransform — confirm against dl.one_batch().
        n_in=bands_per_month * no_months,
        n_out=1,
        loss_func=BCEWithLogitsLossFlat(),
        metrics=[Recall(axis=1), F1Score(axis=1)],
    )

I was able to find a workaround by passing the same function as `before_batch` instead:

# NOTE: with before_batch, the items are stacked *before* fastai's own collation, so
# fa_collate then receives an already-stacked TensorImage and torch.stack raises
# "TypeError: stack(): argument 'tensors' ... must be tuple of Tensors, not TensorImage".
test_dl = learn.dls.test_dl(test_files[:6], bs=4, before_batch=test_batch)

I would still like to understand what the training dataloader was doing that the test dataloader wasn’t. From the course, it seems we shouldn’t have to do a workaround like this.

The workaround led to another issue, which I hope someone can help with.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [27], in <cell line: 1>()
----> 1 test_batch = test_dl.one_batch()
      2 print(test_batch.tfms)
      3 #TODO need to add a batch dimension to the tensors

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:172, in DataLoader.one_batch(self)
    170 def one_batch(self):
    171     if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
--> 172     with self.fake_l.no_multiproc(): res = first(self)
    173     if hasattr(self, 'it'): delattr(self, 'it')
    174     return res

File /usr/local/lib/python3.9/dist-packages/fastcore/basics.py:660, in first(x, f, negate, **kwargs)
    658 x = iter(x)
    659 if f: x = filter_ex(x, f=f, negate=negate, gen=True, **kwargs)
--> 660 return next(x, None)

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:127, in DataLoader.__iter__(self)
    125 self.before_iter()
    126 self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
--> 127 for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
    128     # pin_memory causes tuples to be converted to lists, so convert them back to tuples
    129     if self.pin_memory and type(b) == list: b = tuple(b)
    130     if self.device is not None: b = to_device(b, self.device)

File /usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:652, in _BaseDataLoaderIter.__next__(self)
    649 if self._sampler_iter is None:
    650     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    651     self._reset()  # type: ignore[call-arg]
--> 652 data = self._next_data()
    653 self._num_yielded += 1
    654 if self._dataset_kind == _DatasetKind.Iterable and \
    655         self._IterableDataset_len_called is not None and \
    656         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:692, in _SingleProcessDataLoaderIter._next_data(self)
    690 def _next_data(self):
    691     index = self._next_index()  # may raise StopIteration
--> 692     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    693     if self._pin_memory:
    694         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py:39, in _IterableDatasetFetcher.fetch(self, possibly_batched_index)
     37         raise StopIteration
     38 else:
---> 39     data = next(self.dataset_iter)
     40 return self.collate_fn(data)

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:138, in DataLoader.create_batches(self, samps)
    136 if self.dataset is not None: self.it = iter(self.dataset)
    137 res = filter(lambda o:o is not None, map(self.do_item, samps))
--> 138 yield from map(self.do_batch, self.chunkify(res))

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:168, in DataLoader.do_batch(self, b)
--> 168 def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:164, in DataLoader.create_batch(self, b)
    163 def create_batch(self, b): 
--> 164     try: return (fa_collate,fa_convert)[self.prebatched](b)
    165     except Exception as e: 
    166         if not self.prebatched: collate_error(e,b)

File /usr/local/lib/python3.9/dist-packages/fastai/data/load.py:51, in fa_collate(t)
     49 "A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s"
     50 b = t[0]
---> 51 return (default_collate(t) if isinstance(b, _collate_types)
     52         else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
     53         else default_collate(t))

File /usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/collate.py:141, in default_collate(batch)
    139         storage = elem.storage()._new_shared(numel, device=elem.device)
    140         out = elem.new(storage).resize_(len(batch), *list(elem.size()))
--> 141     return torch.stack(batch, 0, out=out)
    142 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
    143         and elem_type.__name__ != 'string_':
    144     if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
    145         # array of string classes and object

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not TensorImage

I would guess it has something to do with me messing with the batching?

I seem to be going around and around here. I think the second error was an issue with what I added and not a solution to the original problem.

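I found the issue. The `encodes` method of `SmashTransform` returned a bare tensor instead of a tuple when the item had only one element. Test items have no label, so each one contains only the input and always took that branch.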

The return value needs to be wrapped in a tuple.
The fix:

# Excerpt from SmashTransform.encodes: always return a tuple, even when the item has only
# the input (no label, as in a test set), so it is batched like the labelled items.
if len(inp) == 2:
        return x_smash, inp[1]
else:
        return (x_smash,) # <--- Fix