split_idx, or how to selectively apply a transform to train and valid (test) datasets?
First of all, I would like to credit @jeremy for sharing the 10 episodes of his fastai2 daily code walk-thrus, and @sgugger for answering questions on the forum, as well as @arora_aman, @akashpalrecha, and @init_27 for posting their code walk-thrus.
split_idx allows a given Transform to be applied to a specific data subset: the train, valid, or test dataset. For example, it lets a specific image augmentation be applied only to the train dataset and not to the valid dataset.
By setting your transform's split_idx flag, you can have the transform applied:
- to both the train and valid (test) datasets if you set (or leave) the transform's split_idx to None
- or only to the train dataset (and not to the valid (test) dataset) if you set the transform's split_idx to 0
- or only to the valid (test) dataset (and not to the train dataset) if you set the transform's split_idx to 1
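The three settings above can be illustrated with a minimal, self-contained sketch. It does not use fastai2: ToyTransform and apply_to_split are hypothetical names that only mimic the rule described in the list.

```python
class ToyTransform:
    """Hypothetical stand-in for a transform that carries a split_idx flag."""
    def __init__(self, func, split_idx=None):
        self.func, self.split_idx = func, split_idx

    def apply_to_split(self, x, split_idx):
        # Skip when the flag is set and differs from the dataset's split_idx
        if self.split_idx is not None and split_idx != self.split_idx:
            return x
        return self.func(x)

def double(x):
    return x * 2

TRAIN, VALID = 0, 1  # valid and test datasets share split_idx=1

both = ToyTransform(double)                     # split_idx=None: train and valid
train_only = ToyTransform(double, split_idx=0)  # train dataset only
valid_only = ToyTransform(double, split_idx=1)  # valid (test) dataset only

# Each tuple is (output on a train item, output on a valid item) for input 10
results = {
    "both": (both.apply_to_split(10, TRAIN), both.apply_to_split(10, VALID)),
    "train_only": (train_only.apply_to_split(10, TRAIN), train_only.apply_to_split(10, VALID)),
    "valid_only": (valid_only.apply_to_split(10, TRAIN), valid_only.apply_to_split(10, VALID)),
}
print(results)
```

An unchanged value of 10 means the transform was skipped for that dataset.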
In which files can split_idx be found?
split_idx can be found:
1- split_idx is used by both the TfmLists and Datasets classes (the latter uses TfmLists objects), where:
○ the train dataset has a split_idx=0
○ the valid dataset has a split_idx=1
○ the test dataset also has a split_idx=1
○ There is also the set_split_idx() method, which sets a split_idx on a given dataset. That method is used in the “test time augmentation” tta() method found in learner.py
2- in TfmdDL:
split_idx is used by the before_iter() method to set the split_idx of each batch_tfms Pipeline object to the same split_idx as the corresponding dataset (train or valid dataset)
3- split_idx is used by several Transform classes, as shown further below (e.g. the RandTransform and Resize transforms)
4- also in fastai2.learner.py: it is used by the “test time augmentation” tta() method
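The before_iter() step from item 2 can be pictured with a small sketch. This is not the fastai2 code: ToyPipeline, ToyDataset and ToyLoader are hypothetical classes, and the only behaviour modelled is copying the dataset's split_idx onto the batch-level pipeline before iteration starts.

```python
class ToyPipeline:
    """Hypothetical pipeline; it only stores a split_idx."""
    def __init__(self):
        self.split_idx = None

class ToyDataset:
    def __init__(self, items, split_idx):
        self.items, self.split_idx = items, split_idx

class ToyLoader:
    def __init__(self, dataset, batch_tfms):
        self.dataset, self.batch_tfms = dataset, batch_tfms

    def before_iter(self):
        # Give the batch transforms the same split_idx as the dataset
        self.batch_tfms.split_idx = getattr(self.dataset, "split_idx", None)

    def __iter__(self):
        self.before_iter()
        return iter(self.dataset.items)

train_dl = ToyLoader(ToyDataset([1, 2, 3], split_idx=0), ToyPipeline())
valid_dl = ToyLoader(ToyDataset([4, 5], split_idx=1), ToyPipeline())

before = (train_dl.batch_tfms.split_idx, valid_dl.batch_tfms.split_idx)
train_items, valid_items = list(train_dl), list(valid_dl)  # triggers before_iter()
after = (train_dl.batch_tfms.split_idx, valid_dl.batch_tfms.split_idx)
print(before, after)
```

Until a loader is iterated, its pipeline in this sketch has no split_idx; afterwards it matches the dataset it wraps.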
How does it work?
A Transform has a split_idx attribute and defines the following _call() method:
def _call(self, fn, x, split_idx=None, **kwargs):
    if split_idx!=self.split_idx and self.split_idx is not None: return x
    return self._do_call(getattr(self, fn), x, **kwargs)
As you might notice, we pass a split_idx argument to the _call() method. That split_idx argument is checked against the Transform's self.split_idx in the if statement. That check determines the behavior of the Transform, as summarized above.
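To see every case of that if statement at once, here is a small truth-table sketch. is_applied is a hypothetical helper that reuses the same condition as the _call() method quoted above; nothing else from the library is involved.

```python
def is_applied(transform_split_idx, passed_split_idx):
    """Mirror of the `if` test in _call(): True when the transform runs."""
    skipped = (passed_split_idx != transform_split_idx
               and transform_split_idx is not None)
    return not skipped

table = {(t, p): is_applied(t, p)
         for t in (None, 0, 1)    # the transform's own self.split_idx
         for p in (None, 0, 1)}   # the split_idx argument passed to _call()

for (t, p), applied in table.items():
    print(f"self.split_idx={t!s:>4}  passed={p!s:>4}  applied={applied}")
```

One consequence of the condition worth noting: a transform whose split_idx is 0 or 1 is skipped when it is called with no split_idx at all (the default None), because None differs from its own value.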
We generally don’t call a Transform explicitly. A Pipeline, which is a class that stores a list of Transform objects, is responsible for calling each of its Transform objects.
Pipelines are used in the TfmLists class (and in the Datasets class, because the latter uses TfmLists objects). A Pipeline also stores a split_idx as an attribute. Both Datasets and TfmLists generally have a train dataset (with split_idx=0), a valid dataset (with split_idx=1), and sometimes a test dataset (also with split_idx=1). When a Pipeline of Transforms is applied to one of the 3 datasets, the Pipeline calls each of its Transform objects, passing the split_idx of the dataset that is about to be transformed.
Therefore, if we are transforming a train dataset, the Pipeline passes split_idx=0 to the _call() method of each of its Transform objects. Similarly, for both the valid dataset and the test dataset, the Pipeline passes split_idx=1 to each of its Transform objects.
Now, back to our Transform _call() method. It compares the passed argument (the dataset split_idx supplied by the Pipeline) to its own self.split_idx (self being the Transform object), and decides either to ignore the call by returning the input x unchanged, or to apply the transform through return self._do_call(getattr(self, fn), x, **kwargs), following the rules mentioned above.
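The interplay between a Pipeline and its Transforms can be sketched as follows. SketchTransform and SketchPipeline are hypothetical, heavily simplified classes (no type dispatch, no setup), written only to show the pipeline forwarding its own split_idx to every transform it calls.

```python
class SketchTransform:
    def __init__(self, func, split_idx=None):
        self.func, self.split_idx = func, split_idx

    def _call(self, x, split_idx=None):
        # Same test as in the _call() method quoted earlier
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.func(x)

class SketchPipeline:
    def __init__(self, tfms, split_idx=None):
        self.tfms, self.split_idx = tfms, split_idx

    def __call__(self, x):
        # Forward the pipeline's split_idx to each transform in turn
        for t in self.tfms:
            x = t._call(x, split_idx=self.split_idx)
        return x

tfms = [
    SketchTransform(lambda s: s.strip()),                   # both datasets
    SketchTransform(lambda s: s + "+aug", split_idx=0),     # train only
    SketchTransform(lambda s: s + "+center", split_idx=1),  # valid (test) only
]

train_out = SketchPipeline(tfms, split_idx=0)(" img ")
valid_out = SketchPipeline(tfms, split_idx=1)(" img ")
print(train_out, valid_out)
```

The same list of transforms yields different results depending only on the split_idx held by the pipeline that runs them.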
Let’s check some Transform examples:
Transform examples in augment.py:
split_idx = None
mode,mode_mask,order,final_size = Image.BILINEAR,Image.NEAREST,1,None
"Resize image to `size` using `method`"
Resize has a split_idx=None, meaning that it will be applied to both the train and valid (test) datasets.
"A transform that before_call its state at each `__call__`"
do,nm,supports,split_idx = True,None,[],0
RandTransform has a split_idx=0, meaning that it will only be applied to the train dataset. Be aware that even on the train dataset a given item may be left untransformed, because the transform is applied with a given probability p (meaning not all the train dataset items are transformed).
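The two conditions on a random transform (right dataset, then a draw against p) can be sketched like this. SketchRandTransform is a hypothetical class, and the seed argument is an addition made here only to keep the example repeatable.

```python
import random

class SketchRandTransform:
    split_idx = 0  # train dataset only, as described for RandTransform

    def __init__(self, func, p=0.5, seed=None):
        self.func, self.p = func, p
        self.rng = random.Random(seed)

    def __call__(self, x, split_idx=None):
        # First gate: skip entirely when the dataset is not the train one
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        # Second gate: apply only with probability p
        self.do = self.rng.random() < self.p
        return self.func(x) if self.do else x

flip = SketchRandTransform(lambda s: s[::-1], p=0.5, seed=0)
train_out = [flip("abc", split_idx=0) for _ in range(1000)]
valid_out = [flip("abc", split_idx=1) for _ in range(1000)]

n_flipped = train_out.count("cba")
print(f"train items flipped: {n_flipped} / 1000")
print(f"valid items flipped: {valid_out.count('cba')} / 1000")
```

Only a share of the train items is flipped, while no valid item ever is.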
Test Time Augmentation tta() method
split_idx is also used in the learner.py tta() method. Test time augmentation can significantly improve accuracy.
tta() combines predictions of several augmented images. To calculate the prediction of a given image, we follow the steps shown here below (assuming we are using the default n=4 value for the number of the augmented images):
1- First, we create 4 augmented images using the train dataset Pipeline Transforms. This is why dl.dataset.set_split_idx(0) is called: it makes sure the Pipeline objects pass split_idx=0 to each Transform call. Each augmented image gets its prediction. aug_preds, representing the predictions of the 4 images, is then reduced using either the max or the mean value.
2- Then, we calculate only one prediction (preds) using the valid (test) dataset Pipeline of Transforms. The use of dl.dataset.set_split_idx(1) ensures that only the Pipeline Transforms set for the valid (test) dataset are applied.
3- Finally, we combine aug_preds and preds using either the max or a linear interpolation function.
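The three steps can be put together in a rough, library-free sketch. toy_tta, predict and augment are hypothetical names, predictions are plain Python lists, and the beta weight of 0.25 and the direction of the interpolation are assumptions made for illustration rather than values taken from the text above.

```python
def mean(rows):
    return [sum(col) / len(col) for col in zip(*rows)]

def elementwise_max(rows):
    return [max(col) for col in zip(*rows)]

def lerp(start, end, weight):
    # Linear interpolation: start when weight=0, end when weight=1
    return [s + weight * (e - s) for s, e in zip(start, end)]

def toy_tta(predict, image, augment, n=4, beta=0.25, use_max=False):
    # Step 1: n predictions on train-style augmented copies (the split_idx=0 pass)
    aug_rows = [predict(augment(image, i)) for i in range(n)]
    aug_preds = elementwise_max(aug_rows) if use_max else mean(aug_rows)
    # Step 2: a single prediction with valid-style transforms (the split_idx=1 pass)
    preds = predict(image)
    # Step 3: combine both with the max or a linear interpolation
    if use_max:
        return elementwise_max([preds, aug_preds])
    return lerp(aug_preds, preds, beta)

def predict(x):
    # Toy two-class "model": the image is a single number
    return [x, 1 - x]

def augment(x, i):
    # Toy augmentation: shift the value by a different amount for each copy
    return x + i * 0.125

blended = toy_tta(predict, 0.5, augment)
maxed = toy_tta(predict, 0.5, augment, use_max=True)
print(blended, maxed)
```

With beta=0 the result in this sketch is aug_preds alone, and with beta=1 it is preds alone.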