Stratified label sampling

I can offer a worked example. I use a custom dataset class to accomplish this, but everything else is vanilla fastai. I have a pandas dataframe with all of my data, which I split into train, validation, and test sets (for my current use case, this is much easier for me than shuffling files around into different folders). A custom dataset makes this really easy. For example, here is a classifier dataset derived from my scalar dataset.

merged is my pandas dataframe, which contains (at least) a PosixPath column named 'file_path' and a column, identified by the trait variable, that holds my categorical outcomes of interest.
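
As a concrete (made-up) illustration of that structure, this is roughly what the rest of the code assumes; the file names and the 'species' trait here are hypothetical:

from pathlib import Path
import pandas as pd

trait = 'species'  # hypothetical name for the categorical column of interest
merged = pd.DataFrame({
    'file_path': [Path('images/a.jpg'), Path('images/b.jpg'), Path('images/c.jpg')],
    trait:       ['cat', 'dog', 'cat'],
})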

import numpy as np
import torch.nn.functional as F
from pandas import DataFrame
from fastai.vision import *  # for ImageDataset, open_image, etc. in this fastai version

class ImageCategoricalDataset(ImageDataset):
    def __init__(self, df:DataFrame, path_column:str='file_path', dependent_variable:str=None):
        # list(set(x)) takes the values of x, turns them into map keys (which
        # de-duplicates them), then turns the result back into a list
        self.classes = list(set(df[dependent_variable]))
        self.class2idx = {v:k for k,v in enumerate(self.classes)}
        y = np.array([self.class2idx[o] for o in df[dependent_variable]], dtype=np.int64)

        # The superclass does nice things for us like tensorizing the numpy input
        super().__init__(df[path_column], y)

        self.loss_func = F.cross_entropy
        self.loss_fn = self.loss_func  # some fastai versions look for loss_fn rather than loss_func

    def __getitem__(self, i:int):
        return open_image(self.x[i]), self.y[i]

    def __len__(self)->int:
        return len(self.y)

I then assign a random value to each row:

np.random.seed(31337)
merged['rand'] = np.random.uniform(low=0.0, high=1.0, size=len(merged))

I then generate datasets based on cutpoints of that random value:

dat_train = ImageCategoricalDataset(merged[merged['rand'] < 0.7], 'file_path', trait)
dat_valid = ImageCategoricalDataset(merged[(merged['rand'] >= 0.7) & (merged['rand'] < 0.9)], 'file_path', trait)
dat_test = ImageCategoricalDataset(merged[merged['rand'] >= 0.9], 'file_path', trait)
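
Since the thread is about keeping the label distribution balanced, one way to sanity-check what this random split produces is to look at the class proportions in each subset. This is a minimal sketch against the same merged dataframe and cutpoints (nothing here is fastai-specific):

# check the class balance per split using the same cutpoints as above
train_mask = merged['rand'] < 0.7
valid_mask = (merged['rand'] >= 0.7) & (merged['rand'] < 0.9)
test_mask  = merged['rand'] >= 0.9
for name, mask in [('train', train_mask), ('valid', valid_mask), ('test', test_mask)]:
    print(name, merged.loc[mask, trait].value_counts(normalize=True).to_dict())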

And finally I create my data bunch from these datasets and feed them into a learner:

data = ImageDataBunch.create(dat_train, dat_valid, dat_test,
    ds_tfms=get_transforms(),
    bs=128,
    size=128)

learn = ConvLearner(data, 
                    models.resnet50, 
                    metrics=[accuracy, dice], 
                    ps=0.5,
                    callback_fns=ShowGraph)
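
From there, training is standard fastai; a minimal sketch, assuming the usual one-cycle API is available in this version (the epoch count and learning rate are illustrative):

learn.fit_one_cycle(5, max_lr=1e-3)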

My background is in Go, not Python, so there are likely more efficient ways to accomplish this.
