How to use transformations of my choice during inference with Test Time Augmentation?

I am using Test Time Augmentation during inference, like so:

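from fastai.vision.all import *

# learn_ is the trained Learner; N_IMAGES is the number of augmented passes
file_path = '/path/to/file.jpg'

dl = learn_.dls.test_dl([file_path])
pred, _ = learn_.tta(dl=dl, n=N_IMAGES)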

However, when I try to add transformations of my own choosing, it fails.

If I pass additional transforms through the item_tfms or batch_tfms parameters, as described in the docs, like this:

pred, _ = learn_.tta(dl=dl,
                     n=N_IMAGES,
                     item_tfms=Resize(256),
                     batch_tfms=Zoom(p=1, draw=2.0))

I get this error:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'fastai.vision.core.PILImage'>

Full Error Message
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-86c798126984> in <module>()
      1 # tta
      2 dl = learn_.dls.test_dl([file_path])
----> 3 pred, _ = learn_.tta(dl=dl, n=N_IMAGES, item_tfms=Resize(256), batch_tfms=Zoom(p=1, draw=2.0))
      4 cat = learn_.dls.vocab[torch.argmax(pred).item()]
      5 cat.lstrip()

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    423             # have message field
    424             raise self.exc_type(message=msg)
--> 425         raise self.exc_type(msg)
    426 
    427 

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
    data = next(self.dataset_iter)
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 118, in create_batches
    yield from map(self.do_batch, self.chunkify(res))
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 144, in do_batch
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 143, in create_batch
    def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 50, in fa_collate
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 50, in <listcomp>
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/usr/local/lib/python3.7/dist-packages/fastai/data/load.py", line 51, in fa_collate
    else default_collate(t))
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 86, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'fastai.vision.core.PILImage'>
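The last frames show where it fails: fastai's fa_collate hands the items to PyTorch's default_collate, which only knows how to stack tensors, numpy arrays, numbers, dicts and lists, and here it received PILImage objects. A minimal sketch reproducing the same error outside fastai, using plain PIL images as a stand-in (an assumption; fastai's PILImage is a subclass of PIL's Image), and showing that converting each image to a tensor first lets collation succeed:

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataloader import default_collate

# Two small RGB images, standing in for items that never went through a ToTensor step
imgs = [Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)) for _ in range(2)]

try:
    default_collate(imgs)  # PIL images cannot be stacked into a batch tensor
    raised_type_error = False
except TypeError as e:
    raised_type_error = True
    print(e)

# Converting each image to a CHW uint8 tensor (roughly what fastai's ToTensor does) fixes collation
tensors = [torch.from_numpy(np.array(im)).permute(2, 0, 1) for im in imgs]
batch = default_collate(tensors)
print(batch.shape)  # torch.Size([2, 3, 8, 8])
```

The resulting batch is still uint8, which is why a float conversion (and normalization) is also needed before the model sees it.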

Is there any way I can use additional transformations during inference time with tta?


I ran into the exact same issue. @rghosh, did you find a solution?


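If you choose to override the transforms, you'll need to make sure you include ToTensor(), IntToFloatTensor() and a new Normalize with the stats of the old one. Passing item_tfms or batch_tfms to tta replaces the DataLoader's existing after_item and after_batch pipelines instead of adding to them, so without ToTensor the images reach collation as PILImages, which is exactly the error above.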
