ImageDataBunch from 500 megapixel images as tiles

Something I wrote up quickly:

import torch
from torch.utils import data

class SegmentedDataset(data.Dataset):
    """Dataset that exposes each large image as `segments_per_image` tiles.

    Index `i` maps to tile ``i % segments_per_image`` of image
    ``i // segments_per_image``; every tile inherits its image's label.
    Subclasses must override :meth:`get_segment` to do the actual cropping.
    Assumes ``len(images) == len(labels)`` — TODO confirm at the call site.
    """

    def __init__(self, images, labels, segments_per_image: int):
        self.images = images
        self.labels = labels
        self.segments_per_image = segments_per_image

    def get_segment(self, image, segment_ID):
        """Return one tile of `image`. Must be implemented by a subclass."""
        # Raising (instead of the original silent `pass`, which returned
        # None) surfaces a missing override immediately rather than
        # producing None batches downstream.
        raise NotImplementedError("subclasses must implement get_segment")

    def __len__(self) -> int:
        """Total number of samples: one per (image, tile) pair."""
        return len(self.labels) * self.segments_per_image

    def __getitem__(self, index):
        """Return ``(tile, label)`` for flat sample index `index`."""
        segment_ID = index % self.segments_per_image
        image_ID = index // self.segments_per_image

        image = self.images[image_ID]

        # data and label
        X = self.get_segment(image, segment_ID)
        y = self.labels[image_ID]

        return X, y

# NOTE: the constructor takes three arguments — the original example omitted
# `segments_per_image`, which would raise a TypeError.
training_set = SegmentedDataset(train_images, train_labels, segments_per_image)
training_generator = data.DataLoader(training_set, batch_size=batch_size, shuffle=True)  # tune args as needed

validation_set = SegmentedDataset(validation_images, validation_labels, segments_per_image)
validation_generator = data.DataLoader(validation_set, batch_size=batch_size)

# Wrap the two loaders for fastai; pass any other ImageDataBunch arguments here.
data_bunch = ImageDataBunch(train_dl=training_generator, valid_dl=validation_generator)

Hope it helps :slight_smile:

2 Likes