Get a subset of the available training/validation set using high level DataBlocks


I have N training samples in a folder and I want only x% of them in my DataLoaders' training set. One use case is plotting learning curves; another (the one I actually need) is re-implementing self-supervised learning from Epoching's Blog using fastai2. At one point the blog shows how an SSL-trained classifier can learn digits from only 180 labeled samples. So I need a DataLoaders object with 180 training samples and 1000 validation samples.

I have found two solutions, but I would like to know which one is the most idiomatic for fastai2, i.e. the one I can build on with data augmentation and so on.
Motivation: I need all the "magic" defaults for image loading, conversion to normalized float tensors, label categorization, batching, etc., so I would rather stick to the high-level API (DataBlocks).

In what follows, data_path points to the MNIST dataset downloaded with fastai2.

Solution 1. Write a custom data splitter:
I need to take a percentage of the training set while keeping the validation set intact. The first point where the training and validation sets exist as distinct entities is after splitting, so I wrote a custom splitter that handles both the actual split and the subsetting of the training data:

```python
from fastai2.vision.all import *

def custom_splitter(train_name, valid_name, train_pct):
    def fn(name_list):
        train_idx, valid_idx = GrandparentSplitter(train_name=train_name, valid_name=valid_name)(name_list)
        # Shuffle so the kept subset is a random sample, not the first files on disk
        train_idx = list(np.random.permutation(train_idx))
        train_len = int(len(train_idx) * train_pct)
        return train_idx[:train_len], valid_idx
    return fn

mn_db = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,  # collect the image paths under data_path
                  get_y=parent_label,
                  splitter=custom_splitter(train_name='training', valid_name='testing', train_pct=0.003))
mnist_small_dls = mn_db.dataloaders(data_path)
print(f"Training dataset: {len(mnist_small_dls.train_ds)} Validation dataset: {len(mnist_small_dls.valid_ds)}")
```

The custom splitter calls GrandparentSplitter, shuffles the training indices, and keeps only the first train_pct fraction of them. It returns two lists of indices, as the DataBlock expects.
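The splitter logic itself does not depend on fastai, so it can be checked in isolation. Below is a minimal plain-Python sketch, assuming POSIX-style paths laid out as `<root>/<split>/<label>/<file>` (the layout of the fastai MNIST download). `grandparent_subset_split` is a hypothetical name, not a fastai function; it returns index lists in the same (train, valid) shape a DataBlock splitter returns, and the `seed` argument is an addition for reproducibility.

```python
import random
from pathlib import PurePosixPath


def grandparent_subset_split(items, train_name="training", valid_name="testing",
                             train_pct=0.003, seed=None):
    """Split items by grandparent folder, then keep a random fraction of train.

    Hypothetical helper: returns (train_idx, valid_idx) as lists of integer
    indices into `items`. The validation indices are returned untouched.
    """
    train_idx, valid_idx = [], []
    for i, item in enumerate(items):
        grandparent = PurePosixPath(item).parent.parent.name
        if grandparent == train_name:
            train_idx.append(i)
        elif grandparent == valid_name:
            valid_idx.append(i)
    # Shuffle before slicing so the subset is random, not the first files listed.
    random.Random(seed).shuffle(train_idx)
    train_len = int(len(train_idx) * train_pct)
    return train_idx[:train_len], valid_idx


# Tiny fake MNIST-like layout: 1000 training files, 100 testing files.
items = ([f"mnist/training/{d}/{n}.png" for d in range(10) for n in range(100)]
         + [f"mnist/testing/{d}/{n}.png" for d in range(10) for n in range(10)])
train_idx, valid_idx = grandparent_subset_split(items, train_pct=0.1, seed=42)
print(len(train_idx), len(valid_idx))
```

Here 10% of the 1000 training files are kept while all 100 testing files remain in the validation split; with the real MNIST folders, train_pct=0.003 of the 60000 training images gives the 180 samples needed.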

Solution 2. Use a trick intended for testing on new samples

As pointed out in this thread, one can create a new test dataloader from the validation dataloader and use it in place of the training dataloader:

```python
from fastai2.vision.all import *

mnist_block = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                        get_items=get_image_files,  # collect the image paths under data_path
                        get_y=parent_label,
                        splitter=GrandparentSplitter(train_name='training', valid_name='testing'))
mnist_dls = mnist_block.dataloaders(source=data_path)
# Pick 180 distinct training items at random
selected_items = np.random.choice(mnist_dls.train_ds.items, 180, replace=False)
# Create a new dataloader and replace the existing train dataloader
mnist_dls.train = mnist_dls.test_dl(selected_items, with_labels=True)

print(f"Training dataset: {len(mnist_dls.train_ds)} Validation dataset: {len(mnist_dls.valid_ds)}")
```
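If the 180-sample subset has to be identical across runs (for example, to compare against the blog's result), the draw can be seeded. A minimal stdlib sketch, assuming the training items are a plain list of file paths; `sample_train_items` is a hypothetical helper, and `random.Random(seed).sample` is swapped in for `np.random.choice(..., replace=False)`. Only training items are sampled, so the validation set stays whole.

```python
import random


def sample_train_items(train_items, n, seed=None):
    """Draw n distinct training items, reproducibly when a seed is given.

    Hypothetical helper; equivalent in spirit to
    np.random.choice(train_items, n, replace=False).
    random.sample raises ValueError if n exceeds the number of items.
    """
    return random.Random(seed).sample(list(train_items), n)


# Fake list standing in for the 60000 MNIST training paths.
train_items = [f"training/{i % 10}/{i}.png" for i in range(60000)]
selected = sample_train_items(train_items, 180, seed=0)
print(len(selected), len(set(selected)))
```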

Also, writing my own get_items, as suggested by sgugger in the same thread, won't do the trick, because I want to keep the whole validation set intact.


IMHO the first approach is more fastai2-like, because it plugs into the DataBlock through the splitter callback. The second approach is a bit hacky. Both require shuffling, and passing shuffle=True to mnist_block.dataloaders throws an error.
This would probably be more natural with the mid-level API, but there, AFAIK, you have to write all the transforms yourself (e.g. loading, converting to float32, normalization, categorization), so no magic. Also, for the second solution, will any data augmentation still be applied?

Are there other ways, more towards the fastai2 philosophy?

Thank you!


@visoft, I agree with you. I would choose the 1st solution.

Warning concerning the 2nd solution
I would like to point out that, for the second solution, the mnist_dls.train created with the test_dl() method will have split_idx = 1, which can cause surprises in some cases.

Indeed, if you are using a transform such as FlipItem(), which inherits from RandTransform (whose split_idx = 0), and try to apply it to your mnist_dls.train dataset, it will simply be ignored, because the split_idx of mnist_dls.train differs from the split_idx of FlipItem(). The reason is that the Transform __call__() method checks whether the Transform's split_idx is equal to the one passed to __call__(). If they are equal, the transform is applied; if not, the call is ignored and the Transform returns the original input (x) unchanged, i.e. it does nothing. Therefore FlipItem() won't be applied to the mnist_dls.train dataset.

Now, we may ask ourselves how to apply a transform to both the train and valid datasets. The answer is either to set split_idx = None in the transform or simply not to declare it at all (split_idx then defaults to None). The latter convention seems to be the one chosen in fastai.

Here is the kicker: if your transform has split_idx = None, it will be applied to your mnist_dls.train (even though the latter has split_idx = 1), and you may think it's business as usual. But if your Transform has split_idx = 0, you may end up wondering why it is not applied; the explanation above shows why.

NB: When we create datasets, the train dataset automatically gets split_idx = 0 and the valid dataset gets split_idx = 1.
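To make the gating rule concrete, here is a toy model in plain Python. It is not fastai code: `ToyTransform` and `run_pipeline` are hypothetical stand-ins that only reproduce the check described in this post (a transform whose split_idx is set returns its input unchanged unless the loader passes the same split_idx, while split_idx = None means "apply on every split"). The loader indices follow the thread: train 0, valid 1, and a train dataloader built with test_dl() also 1.

```python
class ToyTransform:
    """Toy stand-in for a fastai Transform's split_idx gating (not fastai code)."""

    def __init__(self, fn, split_idx=None):
        self.fn = fn
        self.split_idx = split_idx

    def __call__(self, x, split_idx=None):
        # A transform with a set split_idx returns x unchanged on any other split.
        if self.split_idx is not None and split_idx != self.split_idx:
            return x
        return self.fn(x)


flip = ToyTransform(lambda x: x[::-1], split_idx=0)  # train-only, like a RandTransform
tag = ToyTransform(lambda x: x + ["tagged"])         # split_idx=None: every split


def run_pipeline(x, split_idx):
    """Apply both toy transforms the way a loader with this split_idx would."""
    for tfm in (flip, tag):
        x = tfm(x, split_idx=split_idx)
    return x


# Split indices as described in the thread: train 0, valid 1, a test_dl 1.
for loader, idx in [("train", 0), ("valid", 1), ("train from test_dl", 1)]:
    print(f"{loader}: {run_pipeline([1, 2, 3], idx)}")
```

With these rules the regular train loader gets both the flip and the tag, while the valid loader and the test_dl-built train loader only get the tag: the split_idx = 0 augmentation silently disappears, which is the surprise described above.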


Nice! Now I know something more! The documentation barely mentions split_idx; I only noticed it because I read your post first. I have now dug a little deeper into the code for Transform and ItemTransform. Thank you!


You are very welcome. You might check out the Fastai v2 Recipes (Tips and Tricks) - Wiki thread I started a couple of weeks ago, if you haven't already. You might find some interesting stuff there during your code exploration.

I will be adding new tips regularly from now on. By the way, that thread is a wiki, and everybody is welcome to contribute their own tips.


I don't know how closely related it is and I haven't used it, but the thread "Data block split_subsets - define train and val size independently" looks like it could work.

I guess split_subsets is outdated and no longer works.