Train CNN on subset of image dataset to speed up feedback?

Running my first learn.fit_one_cycle(x) on a pathology dataset takes quite some time.
The whole dataset is about 6 GB. It takes at least 10 minutes for 1 epoch.

To speed up the initial feedback, is there a method available to train on a subset of the whole dataset?

Just to speed up the process in the beginning? Or is this not a good approach?

Definitely the right approach, because it lets you try different solutions quickly and choose the best one. Jeremy has done this since the first version of the course (v1/2016).

With a big dataset like the one you’re using, I usually start assessing different models with a number of samples that lets you run an epoch in a minute at most (usually around 5-10% of the total samples).
Be sure to sample the original dataset “properly” when extracting the smaller version (i.e., for classification, randomly shuffle the data before sampling and verify that the class distribution is similar to the original one). In my experience, training on the whole dataset usually improves accuracy by around 10-20%.
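
To make the “sample properly” point concrete, here is a minimal sketch in plain Python of one way to do it: sample per class so the subset keeps roughly the same label mix, then compare the distributions. The helper names (stratified_subsample, class_distribution) and the example labels are made up for illustration and are not part of fastai.

```python
import random
from collections import Counter, defaultdict

def stratified_subsample(labels, fraction, seed=42):
    """Return shuffled indices of a subset that keeps roughly the same class mix.

    `labels` is a list with one label per sample (hypothetical input).
    """
    rng = random.Random(seed)
    # Group sample indices by their label.
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    picked = []
    for idxs in by_label.values():
        # Keep at least one sample per class so rare classes do not vanish.
        k = max(1, round(len(idxs) * fraction))
        picked.extend(rng.sample(idxs, k))
    rng.shuffle(picked)
    return picked

def class_distribution(labels):
    """Return the share of each label as a dict {label: fraction}."""
    total = len(labels)
    return {label: count / total for label, count in Counter(labels).items()}

# Example with made-up labels: 20% "tumor", 80% "normal".
labels = ["tumor"] * 200 + ["normal"] * 800
subset_idx = stratified_subsample(labels, fraction=0.1)
subset_labels = [labels[i] for i in subset_idx]
print(class_distribution(labels))
print(class_distribution(subset_labels))
```

If the two printed distributions differ a lot, the subset is probably not representative and results on it may not transfer to the full dataset.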

I prefer to subsample the dataset manually, but AFAIK fast.ai has a built-in method to do that:
https://docs.fast.ai/data_block.html#ItemList.use_partial_data


Here is an example of manual dataset reduction that keeps 10% of the files:

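import random

# all_files is assumed to be a list of all image file paths in the dataset
items_count = int(len(all_files) * 0.1)  # keep 10% of the files
files = random.sample(all_files, items_count)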
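
If you want the same subset on every run (so that model comparisons are fair), a seeded variant of the snippet above could look roughly like this. The function name sample_files and the example file names are hypothetical; it is just a sketch using the standard library.

```python
import random

def sample_files(all_files, fraction=0.1, seed=42):
    """Return a reproducible random subset of `all_files`.

    Sorting first makes the result independent of the order in which
    the file list was collected (e.g. from a directory listing).
    """
    rng = random.Random(seed)  # private generator, leaves global random state alone
    items_count = int(len(all_files) * fraction)
    return rng.sample(sorted(all_files), items_count)

# Example with made-up file names.
all_files = [f"img_{i:04d}.png" for i in range(1000)]
files = sample_files(all_files, fraction=0.1, seed=42)
print(len(files))
```

Changing the seed gives a different 10%, which is a cheap way to check that your conclusions do not depend on one lucky subset.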

I had the same question, so here’s a working sample using the “use_partial_data” fastai method:

src = (ImageItemList.from_folder(path)
       .use_partial_data(0.2, seed=seed)
       .split_by_folder()
       .label_from_folder())

It takes a random 20% of the source data for the ImageItemList. Set “seed” so that the same random items are selected every time.


Thanks. In the end I used the following chain of methods, similar to yours.

data = (ImageItemList.from_csv(PATH, folder='train', csv_file='labels.csv', cols={your_fname_col_name})
           .use_partial_data(sample_pct = .1, seed= 34)
           .random_split_by_pct(valid_pct=0.2, seed=34)
           .label_from_df(cols={your_label_cols_name})
           .transform(tfms, size = 96)
           .databunch(bs=64)).normalize(imagenet_stats)