Fastai v2 Recipes (Tips and Tricks) - Wiki

split_idx, or how to selectively apply a transform to train and valid (test) datasets?

First of all, I would like to credit @jeremy for sharing the 10 episodes of his fastai2 daily code walk-thrus, @sgugger for answering questions on the forum, and @arora_aman, @akashpalrecha, and @init_27 for posting their code walk-thrus.

split_idx allows a given Transform to be applied to a specific data subset: train, valid and test datasets. For example, it enables a specific image augmentation to be applied to only the train dataset and not to the valid dataset.

By setting your transform's split_idx attribute, you can have the transform applied:

  • to both the train and valid (test) datasets if you set (leave) your transform's split_idx to None
  • or only to the train dataset (and not to the valid (test) dataset) if you set your transform's split_idx to 0
  • or only to the valid (test) dataset (and not to the train dataset) if you set your transform's split_idx to 1
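
The three cases above can be written as a small truth table. This is only a sketch in plain Python: is_applied is a hypothetical helper written for illustration, not a fastai2 function.

```python
def is_applied(transform_split_idx, dataset_split_idx):
    """Return True if a transform with `transform_split_idx` would run
    on a dataset whose split index is `dataset_split_idx`."""
    # None means "no restriction"; otherwise the two indices must match.
    return transform_split_idx is None or transform_split_idx == dataset_split_idx


TRAIN, VALID = 0, 1  # the valid and test datasets both use 1

for tfm_idx in (None, 0, 1):
    print(f"split_idx={tfm_idx}: "
          f"train={is_applied(tfm_idx, TRAIN)}, "
          f"valid={is_applied(tfm_idx, VALID)}")
```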

In which files can split_idx be found?
split_idx can be found:
1- in fastai2.data.core.py: split_idx is used by both TfmLists, and Datasets classes (Datasets uses TfmLists objects) where:
○ the train dataset has a split_idx=0,
○ the valid dataset has a split_idx=1,
○ the test dataset also has a split_idx=1,
○ There is also the set_split_idx() method, which sets the split_idx of a given dataset. That method is used by the “test time augmentation” tta() method found in fastai2.learner.py

2- in TfmdDL: split_idx is used by the before_iter() method in order to set the split_idx of each batch_tfms Pipeline object to the same split_idx as the corresponding dataset (train and valid datasets)

3- in fastai2.vision.augment.py: split_idx is used by several Transform classes as shown further below (e.g. RandTransform, and Resize Transforms)

4- also in fastai2.learner.py: it is used by the “test time augmentation” tta() method

How does it work?
A Transform has a split_idx attribute and defines the following _call() method:

def _call(self, fn, x, split_idx=None, **kwargs):
    if split_idx!=self.split_idx and self.split_idx is not None: return x
    return self._do_call(getattr(self, fn), x, **kwargs)

As you might notice, we pass a split_idx argument to the _call() method. That argument is checked against the Transform's own self.split_idx in the if statement: when self.split_idx is not None and differs from the passed split_idx, the input x is returned unchanged. This check produces the behavior summarized above.

We generally don’t explicitly call a Transform. A Pipeline, which is a class that stores a list of Transform objects, is responsible for calling the _call() method of each of its Transform objects.

Pipeline objects are used in the TfmLists class (and in the Datasets class, because the latter uses TfmLists objects). A Pipeline also stores a split_idx as an attribute. Both Datasets and TfmLists generally have a train dataset (with split_idx=0), a valid dataset (with split_idx=1), and sometimes a test dataset (also with split_idx=1). When a Pipeline of Transform objects is applied to one of the 3 datasets, the Pipeline calls each of its Transform objects, passing the split_idx of the dataset that we are about to transform.

Therefore, if we are transforming a train dataset, the Pipeline passes split_idx=0 to the _call() method of each of its Transform objects. Similarly, for both the valid dataset and the test dataset, the Pipeline passes split_idx=1 to the _call() method of each of its Transform objects.

Now, back to our Transform _call() method. It compares the passed argument (the dataset split_idx, supplied by the Pipeline) to its self.split_idx (self is the Transform object), and decides either to ignore the call by returning the input x without any change, or to apply the transform through return self._do_call(getattr(self, fn), x, **kwargs), following the rules mentioned above.
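
To make the mechanism concrete, here is a minimal sketch in plain Python. ToyTransform and ToyPipeline are hypothetical stand-ins written for this post, not the fastai2 classes; they only reproduce the split_idx check quoted above and the way a Pipeline forwards its own split_idx.

```python
class ToyTransform:
    """Simplified stand-in for a Transform (not the fastai2 class)."""

    def __init__(self, func, split_idx=None):
        # split_idx: None = apply everywhere, 0 = train only, 1 = valid/test only
        self.func, self.split_idx = func, split_idx

    def __call__(self, x, split_idx=None):
        # Same condition as in _call(): skip when our split_idx is set
        # and differs from the one passed in by the pipeline.
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.func(x)


class ToyPipeline:
    """Simplified stand-in for a Pipeline: holds transforms and a split_idx."""

    def __init__(self, tfms, split_idx=None):
        self.tfms, self.split_idx = tfms, split_idx

    def __call__(self, x):
        # Forward the pipeline's split_idx to every transform.
        for tfm in self.tfms:
            x = tfm(x, split_idx=self.split_idx)
        return x


tfms = [
    ToyTransform(lambda x: x + 1),                # applied to both datasets
    ToyTransform(lambda x: x * 10, split_idx=0),  # train only
    ToyTransform(lambda x: -x, split_idx=1),      # valid (test) only
]

train_out = ToyPipeline(tfms, split_idx=0)(1)  # (1 + 1) * 10
valid_out = ToyPipeline(tfms, split_idx=1)(1)  # -(1 + 1)
print(train_out, valid_out)
```

Note that the same list of transforms is shared by both pipelines; only the split_idx stored on the pipeline changes which transforms actually run.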

Let’s check some Transform examples:

Transform examples in augment.py:

Resize Transform

class Resize(RandTransform):
    split_idx = None
    mode,mode_mask,order,final_size = Image.BILINEAR,Image.NEAREST,1,None
    "Resize image to `size` using `method`"

Resize has split_idx=None, meaning that it is applied to both the train and valid (test) datasets.

RandTransform

class RandTransform(Transform):
    "A transform that before_call its state at each `__call__`"
    do,nm,supports,split_idx = True,None,[],0

RandTransform has split_idx=0, meaning that it is only applied to the train dataset. Be aware that even on the train dataset a given item may be left untransformed: the transform is applied with a given probability p, so not all the train dataset items are transformed.
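
The interaction between split_idx=0 and the probability p can be sketched as follows. ToyRandTransform is a hypothetical class written for illustration (it is not the fastai2 RandTransform), and the seed argument is only there to make the example repeatable.

```python
import random


class ToyRandTransform:
    """Simplified stand-in for a random, train-only transform."""
    split_idx = 0  # train only, like RandTransform

    def __init__(self, func, p=0.5, seed=None):
        self.func, self.p = func, p
        self.rng = random.Random(seed)  # private RNG so the sketch is repeatable

    def __call__(self, x, split_idx=None):
        # First gate: the split_idx check (valid/test items are never touched).
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        # Second gate: apply the transform only with probability p.
        if self.rng.random() >= self.p:
            return x
        return self.func(x)


flip = ToyRandTransform(lambda s: s[::-1], p=0.5, seed=0)

train_items = [flip("abc", split_idx=0) for _ in range(1000)]
valid_items = [flip("abc", split_idx=1) for _ in range(1000)]

n_flipped = train_items.count("cba")
print(f"train: {n_flipped} of 1000 items flipped")
print(f"valid: {valid_items.count('cba')} of 1000 items flipped")
```

So an untouched train item does not mean split_idx is wrong; with p below 1 some train items are simply skipped.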

Test Time Augmentation tta() method
split_idx is also used in the learner.py tta() method. Test time augmentation can significantly improve accuracy.

tta() combines the predictions of several augmented images. To calculate the prediction of a given image, we follow the steps shown below (assuming we are using the default value n=4 for the number of augmented images):

1- First, we create 4 augmented images using the train dataset Pipeline Transforms. This is why dl.dataset.set_split_idx(0) is called: it makes sure the Pipeline objects pass split_idx=0 to each Transform call. Each augmented image gets its own prediction. aug_preds, representing the predictions of the 4 images, is then reduced using either the max or the mean value,

2- Then, we calculate only one prediction (preds) using the valid (test) dataset Pipeline of Transforms. Calling dl.dataset.set_split_idx(1) ensures that only the Pipeline Transforms set for the valid (test) dataset (split_idx=1, or None) are applied,

3- Finally, we combine aug_preds and preds using either the max or a linear interpolation function.
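
The arithmetic of steps 1 and 3 can be sketched in plain Python. This is not the tta() source: reduce_aug_preds and combine are hypothetical helpers, the numbers are made up, and both the weight name beta with its value 0.25 and the direction of the interpolation (from aug_preds toward preds) are assumptions made for illustration.

```python
def reduce_aug_preds(aug_preds, use_max=False):
    """Step 1: reduce the per-augmentation predictions element-wise."""
    columns = list(zip(*aug_preds))  # one tuple per class
    if use_max:
        return [max(col) for col in columns]
    return [sum(col) / len(col) for col in columns]


def combine(aug_pred, pred, beta=0.25, use_max=False):
    """Step 3: merge the reduced augmented prediction with the plain one."""
    if use_max:
        return [max(a, p) for a, p in zip(aug_pred, pred)]
    # Linear interpolation: beta=0 keeps aug_pred, beta=1 keeps pred.
    return [a + beta * (p - a) for a, p in zip(aug_pred, pred)]


# Made-up predictions for one image over two classes.
aug_preds = [[0.6, 0.4], [0.8, 0.2], [0.7, 0.3], [0.5, 0.5]]  # n=4 augmented copies
pred = [0.9, 0.1]                                             # step 2: un-augmented copy

aug_mean = reduce_aug_preds(aug_preds)
final_lerp = combine(aug_mean, pred, beta=0.25)
final_max = combine(reduce_aug_preds(aug_preds, use_max=True), pred, use_max=True)
print(aug_mean, final_lerp, final_max)
```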
