Need help understanding DataBlock’s batch_tfms

I am completely confused about what batch_tfms does. I need access to image/target tuples at the mini-batch level. I thought that a setup such as batch_tfms=aug_tfms (where aug_tfms is a function that receives xb, yb as inputs) would work, just like the dl_tfms parameter in fastai v1.

However, surprisingly I find the following:

  1. The function aug_tfms receives only one tensor xb as input

  2. It fires only once, on the first tuple of the mini-batch

  3. When it fires, it cycles 3 times: the first pass gives me the image (in PIL format), the second gives me the bounding box, and the third the label. That’s it!

The end result is that batch_tfms gives me only a single tuple of the mini batch.

I am obviously not doing things right. Any guidance will be appreciated.

I have also looked at after_batch and before_batch, but these don’t help. I tried to reverse engineer aug_transforms() but got lost in the library internals, which are not easy to traverse.

Thanks for any help on the matter

Batch transforms run transforms on a batch, and they’re applied based on the type of each element in the tuple (ala typedispatch). For example, let’s look at bounding boxes. There is a transform called flip_lr, which flips an image left to right. If we check what this augmentation looks like in the source, we see the following:

    @patch
    def flip_lr(x:Image.Image): return x.transpose(Image.FLIP_LEFT_RIGHT)
    @patch
    def flip_lr(x:TensorImageBase): return x.flip(-1)
    @patch
    def flip_lr(x:TensorPoint): return TensorPoint(_neg_axis(x.clone(), 0))
    @patch
    def flip_lr(x:TensorBBox):  return TensorBBox(TensorPoint(x.view(-1,2)).flip_lr().view(-1,4))

What does this mean? It means that if, at some point in the pipeline, either the x or the y is a TensorBBox or an Image.Image (PILImage, etc.), then the matching version of the transform is applied to it; otherwise it is not. aug_transforms applies a number of these transforms to an x or y, and if a type has no declaration for that transform in the pipeline, the transform is simply not run on your y.
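The dispatch idea can be sketched in plain Python with functools.singledispatch. This is a toy illustration, not fastai’s actual TypeDispatch, and ImageLike/BBoxLike are made-up stand-ins for the real tensor subclasses:

```python
from functools import singledispatch

# Toy stand-ins for fastai's tensor subclasses -- illustrative names only.
class ImageLike(list): pass
class BBoxLike(list): pass

@singledispatch
def flip_lr(x):
    # Default: types with no registered version pass through unchanged,
    # which is why plain labels are left alone by this augmentation.
    return x

@flip_lr.register
def _(x: ImageLike):
    # "Flip" an image by reversing its pixel columns (toy version).
    return ImageLike(reversed(x))

@flip_lr.register
def _(x: BBoxLike):
    # Mirror box x-coordinates across an image of width 10 (toy version).
    return BBoxLike([10 - v for v in x])

batch = (ImageLike([1, 2, 3]), BBoxLike([2, 8]), "cat")
print(tuple(flip_lr(o) for o in batch))  # each element gets its own rule
```

Each element of the tuple is routed to the version of flip_lr registered for its type; the label string hits the default and passes through untouched.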

Another example is looking at the DataBlock for BBox. If we look at it we see:

    BBoxBlock = TransformBlock(type_tfms=TensorBBox.create, item_tfms=PointScaler, dls_kwargs={'before_batch': bb_pad})

This means that when we have our batches, before doing anything we pad our outputs with bb_pad, and along with this we run the PointScaler transform on our inputs.
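Why pad before batching? Samples in a batch can have different numbers of bounding boxes, so they have to be padded to the same length before they can be collated into one tensor. Here is a simplified plain-list sketch of that idea (pad_bboxes is a made-up toy, not fastai’s actual bb_pad):

```python
def pad_bboxes(samples, pad_idx=0):
    # Each sample is (image, list_of_boxes, list_of_labels); pad every
    # sample up to the longest box list in the batch so they can be stacked.
    max_boxes = max(len(bboxes) for _, bboxes, _ in samples)
    padded = []
    for img, bboxes, labels in samples:
        n_pad = max_boxes - len(bboxes)
        padded.append((
            img,
            bboxes + [[0, 0, 0, 0]] * n_pad,  # dummy zero-area boxes
            labels + [pad_idx] * n_pad,       # background label for padding
        ))
    return padded

samples = [
    ("img0", [[0, 0, 5, 5], [1, 1, 2, 2]], [1, 2]),
    ("img1", [[3, 3, 4, 4]],               [3]),
]
for s in pad_bboxes(samples):
    print(s)
```

After padding, every sample has the same number of boxes, so the default collation can stack them into a single batch tensor.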

Basically, it only takes in one parameter because each transform looks at the current element available to it, rather than at everything at once. @jeremy or @sgugger please correct me if my understanding is wrong in any way, but that’s what I believe is happening with the data types.

Have a look at the source for DataBlock:

    def dataloaders(self, source, path='.', verbose=False, **kwargs):
        dsets = self.datasets(source)
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

So we see that batch_tfms is what DataLoader calls after_batch. So let’s look at DataLoader:

    def __iter__(self):
        for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
            if self.device is not None: b = to_device(b, self.device)
            yield self.after_batch(b)
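A Pipeline here is essentially a composition of transforms applied in order to each yielded batch. A minimal plain-Python sketch (fastai’s real Pipeline additionally handles type dispatch, transform ordering, setup, and decoding; MiniPipeline is a made-up name):

```python
class MiniPipeline:
    # Minimal stand-in for fastai's Pipeline: calls each function in turn.
    def __init__(self, fns):
        self.fns = fns
    def __call__(self, batch):
        for f in self.fns:
            batch = f(batch)
        return batch

def to_float(b): return [float(x) for x in b]
def scale(b):    return [x / 10 for x in b]

after_batch = MiniPipeline([to_float, scale])
print(after_batch([1, 2, 3]))  # -> [0.1, 0.2, 0.3]
```

So every batch coming out of the loader has passed through every function in after_batch, in order.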

So each batch is passed to after_batch, which is a Pipeline. I think the main thing you’re missing is an understanding of Transform, Pipeline, and friends. Try this tutorial:

Note that Transform and Pipeline are the key abstractions on which all data processing in fastai2 is based. They’re not complicated, but they are different. So understanding the underlying concepts is important if you’re wanting to move beyond the pre-written applications and do things that are more custom.
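As a rough illustration of the Transform idea: a transform pairs an encodes step with a decodes step that reverses it. This toy class wires that up by hand (fastcore’s real Transform does it via metaclasses and TypeDispatch; Categorize here is a simplified imitation of the real one):

```python
class Categorize:
    # Toy transform: encode a label string as an int index, decode it back.
    def __init__(self, vocab):
        self.vocab = list(vocab)
        self.o2i = {o: i for i, o in enumerate(self.vocab)}
    def encodes(self, x):
        return self.o2i[x]
    def decodes(self, i):
        return self.vocab[i]
    def __call__(self, x):
        # Calling the transform applies encodes, mirroring fastcore's API.
        return self.encodes(x)

tfm = Categorize(["cat", "dog"])
print(tfm("dog"), tfm.decodes(1))  # -> 1 dog
```

Pipelines compose many such transforms, and decoding runs them in reverse, which is how fastai can show you human-readable batches after all the encoding.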


Thank you for your guidance.

I understand from the code that transformations that can run on the GPU go in batch_tfms, since data is pushed to the device before the transformation, whereas transformations such as random cropping, which crops each item at a different location, run on the CPU, so we pass them in item_tfms. Is my understanding correct?

Or is there a list of transforms that the GPU supports, so that we run those on the GPU by passing them in batch_tfms?