split_idx, or how to selectively apply a transform to train and valid (test) datasets?
First of all, I would like to credit @jeremy for sharing the 10 episodes of his fastai2 daily code walk-thrus, and @sgugger for answering questions on the forum, as well as @arora_aman, @akashpalrecha, and @init_27 for posting their code walk-thrus.
split_idx allows a given Transform to be applied to a specific data subset: the train, valid, or test dataset. For example, it lets a specific image augmentation be applied only to the train dataset and not to the valid dataset.
By setting the split_idx flag of your transform, you can have it applied:
- to both the train and valid (test) datasets if you set (or leave) its split_idx to None
- only to the train dataset (and not to the valid (test) dataset) if you set its split_idx to 0
- only to the valid (test) dataset (and not to the train dataset) if you set its split_idx to 1
In which files can split_idx be found?
split_idx can be found:
1- in fastai2.data.core.py: split_idx is used by both the TfmdLists and Datasets classes (Datasets uses TfmdLists objects), where:
○ the train dataset has split_idx=0,
○ the valid dataset has split_idx=1,
○ the test dataset also has split_idx=1,
○ there is also the set_split_idx() method, which sets the split_idx of a given dataset. That method is used in the “test time augmentation” tta() method found in fastai2.learner.py
2- in TfmdDL: split_idx is used by the before_iter() method to set the split_idx of each batch_tfms Pipeline object to the same split_idx as the corresponding dataset (train or valid)
3- in fastai2.vision.augment.py: split_idx is used by several Transform classes, as shown further below (e.g. RandTransform and Resize)
4- also in fastai2.learner.py: it is used by the “test time augmentation” tta() method
How does it work?
A Transform has a split_idx attribute and defines the following _call() method:
    def _call(self, fn, x, split_idx=None, **kwargs):
        if split_idx!=self.split_idx and self.split_idx is not None: return x
        return self._do_call(getattr(self, fn), x, **kwargs)
As you might notice, a split_idx argument is passed to the _call() method. That argument is checked against the Transform's own self.split_idx in the if statement, and this check produces the behavior summarized in the list above.
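The check can be tried out without fastai2 installed. What follows is a minimal sketch, not the library code: ToyTransform and add_one are hypothetical names made up for illustration, and only the if test is taken from the _call() method shown above.

```python
class ToyTransform:
    """Hypothetical stand-in for a fastai2 Transform (illustration only)."""

    def __init__(self, func, split_idx=None):
        self.func = func
        self.split_idx = split_idx

    def __call__(self, x, split_idx=None):
        # Same test as in _call(): skip the transform when it is tied to one
        # subset and the caller is working on a different one.
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.func(x)


def add_one(x):
    return x + 1


both = ToyTransform(add_one, split_idx=None)   # every subset
train_only = ToyTransform(add_one, split_idx=0)
valid_only = ToyTransform(add_one, split_idx=1)

for name, tfm in [("both", both), ("train_only", train_only), ("valid_only", valid_only)]:
    # 0 stands for the train subset, 1 for the valid (test) subset
    print(name, "| train:", tfm(10, split_idx=0), "| valid:", tfm(10, split_idx=1))
```

An output of 11 means the transform ran, and 10 means the input came back unchanged. One consequence of the test is worth noting: in this sketch, a transform with split_idx=0 that is called with no split_idx at all is skipped, because None is not equal to 0.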
We generally don’t call a Transform explicitly. A Pipeline, which is a class that stores a list of Transform objects, is responsible for calling the _call() method of each of its Transform objects.
Pipelines are used in the TfmdLists class (and in the Datasets class, because the latter uses TfmdLists objects). A Pipeline also stores a split_idx as an attribute. Both Datasets and TfmdLists generally have a train dataset (with split_idx=0), a valid dataset (with split_idx=1), and sometimes a test dataset (also with split_idx=1). When a Pipeline of Transforms is applied to one of the 3 datasets, the Pipeline calls each of its Transform objects, passing the split_idx of the dataset that is about to be transformed.
Therefore, if we are transforming a train dataset, the Pipeline passes split_idx=0 to the _call() method of each of its Transform objects. Similarly, for both the valid dataset and the test dataset, the Pipeline passes split_idx=1 to the _call() method of each of its Transform objects.
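To make the hand-off concrete, here is a small sketch of a pipeline that forwards its own split_idx to every transform it holds. ToyPipeline and ToyTransform are hypothetical classes written for this post; they only imitate the behavior described above and are not the fastai2 implementation.

```python
class ToyTransform:
    """Hypothetical transform that honours split_idx (illustration only)."""

    def __init__(self, func, split_idx=None):
        self.func = func
        self.split_idx = split_idx

    def __call__(self, x, split_idx=None):
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.func(x)


class ToyPipeline:
    """Hypothetical pipeline: a list of transforms plus the subset's split_idx."""

    def __init__(self, tfms, split_idx=None):
        self.tfms = list(tfms)
        self.split_idx = split_idx

    def __call__(self, x):
        # Every transform receives the split_idx of the subset being processed
        for tfm in self.tfms:
            x = tfm(x, split_idx=self.split_idx)
        return x


tfms = [
    ToyTransform(str.strip),                        # split_idx=None: all subsets
    ToyTransform(str.upper, split_idx=0),           # train only
    ToyTransform(lambda s: s + "!", split_idx=1),   # valid (test) only
]

train_pipe = ToyPipeline(tfms, split_idx=0)
valid_pipe = ToyPipeline(tfms, split_idx=1)

print(repr(train_pipe("  cat ")))  # 'CAT'
print(repr(valid_pipe("  cat ")))  # 'cat!'
```

The same list of transforms gives different results depending only on the split_idx stored in the pipeline, which is the whole point of the mechanism.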
Now, back to the Transform _call() method. It compares the argument passed by the Pipeline (the dataset's split_idx) to its own self.split_idx (self being the Transform object) and decides either to ignore the call, returning the input x unchanged, or to apply the transform through return self._do_call(getattr(self, fn), x, **kwargs), following the rules listed above.
Let’s check some Transform examples:
Transform examples in augment.py:
class Resize(RandTransform):
    split_idx = None
    mode,mode_mask,order,final_size = Image.BILINEAR,Image.NEAREST,1,None
    "Resize image to `size` using `method`"
Resize has split_idx=None, meaning that it will be applied to both the train and valid (test) datasets.
class RandTransform(Transform):
    "A transform that before_call its state at each `__call__`"
    do,nm,supports,split_idx = True,None,[],0
RandTransform has split_idx=0, meaning that it will only be applied to the train dataset. Be aware that even on the train dataset a given item may be left untouched: the transform is applied with a given probability p, so not all the train dataset items are transformed.
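The two conditions (right subset, then a random draw with probability p) can be illustrated with a short sketch. ToyRandTransform is a hypothetical class written for this post. It does not claim to mirror how the library organises these two checks, only their combined effect as described above.

```python
import random


class ToyRandTransform:
    """Hypothetical stand-in for a random, train-only transform."""

    split_idx = 0  # train only, like RandTransform

    def __init__(self, func, p=0.5):
        self.func = func
        self.p = p

    def __call__(self, x, split_idx=None):
        # Gate 1: the subset check
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        # Gate 2: the random draw, applied with probability p
        if random.random() >= self.p:
            return x
        return self.func(x)


random.seed(0)  # only to make the run repeatable
flip = ToyRandTransform(lambda s: s[::-1], p=0.5)

train_out = [flip("abc", split_idx=0) for _ in range(1000)]
valid_out = [flip("abc", split_idx=1) for _ in range(1000)]

print("train items changed:", train_out.count("cba"), "out of 1000")
print("valid items changed:", valid_out.count("cba"), "out of 1000")
```

On the train subset roughly half of the items come back reversed, while on the valid subset none do, whatever the value of p.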
Test Time Augmentation tta() method
split_idx is also used in the tta() method in learner.py. Test time augmentation can significantly improve accuracy.
tta() combines the predictions of several augmented images. To calculate the prediction for a given image, we follow the steps below (assuming the default value n=4 for the number of augmented images):
1- First, we create 4 augmented images using the train dataset Pipeline Transforms. This is why dl.dataset.set_split_idx(0) is called: it makes sure the Pipeline objects pass split_idx=0 to each Transform call. Each augmented image gets its own prediction. aug_preds, which holds the predictions for the 4 images, is then reduced using either the max or the mean value,
2- Then, we calculate a single prediction (preds) using the valid (test) dataset Pipeline of Transforms. Calling dl.dataset.set_split_idx(1) ensures that only the Transforms meant for the valid (test) dataset are applied (those with split_idx=1 or None),
3- Finally, we combine aug_preds and preds using either the max or a linear interpolation function.
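The arithmetic of the three steps can be sketched in plain Python. This is an illustration only: the probabilities are made up, the helper functions are hypothetical, and the weight name beta with its value 0.25 is an assumption chosen for the example rather than something taken from the steps above.

```python
def mean_preds(preds_list):
    """Element-wise mean over several prediction vectors."""
    n = len(preds_list)
    return [sum(col) / n for col in zip(*preds_list)]


def max_preds(preds_list):
    """Element-wise max over several prediction vectors."""
    return [max(col) for col in zip(*preds_list)]


def lerp(start, end, weight):
    """Linear interpolation: start + weight * (end - start)."""
    return [s + weight * (e - s) for s, e in zip(start, end)]


# Step 1: predictions for n=4 augmented versions of one image (made-up numbers)
aug_runs = [[0.6, 0.4], [0.8, 0.2], [0.7, 0.3], [0.5, 0.5]]
aug_preds = mean_preds(aug_runs)       # or max_preds(aug_runs)

# Step 2: a single prediction obtained with the valid (test) transforms
preds = [0.9, 0.1]

# Step 3: combine both, by linear interpolation or by taking the max
beta = 0.25                            # assumed weight, for illustration
combined_lerp = lerp(aug_preds, preds, beta)
combined_max = max_preds([aug_preds, preds])

print("aug_preds:    ", aug_preds)
print("combined_lerp:", combined_lerp)
print("combined_max: ", combined_max)
```

With a weight of 0 the interpolation returns aug_preds unchanged, and with a weight of 1 it returns preds, so the weight sets how much the un-augmented prediction counts in the final result.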