Something I wrote quickly:
import torch
from torch.utils import data
class SegmentedDataset(data.Dataset):
    """Dataset that exposes every image as several consecutive segment samples.

    Sample ``index`` maps to image ``index // segments_per_image`` and to
    segment ``index % segments_per_image`` of that image. All segments of one
    image share that image's label.

    Subclasses must override :meth:`get_segment` to do the actual cutting.
    """

    def __init__(self, images, labels, segments_per_image):
        """Store the data and the number of segments produced per image.

        Args:
            images: Indexable collection of images.
            labels: Indexable, sized collection with one label per image.
            segments_per_image: Number of segments each image is split into.

        Raises:
            ValueError: If ``segments_per_image`` is smaller than 1 (it is
                used as a divisor in ``__getitem__``).
        """
        if segments_per_image < 1:
            raise ValueError(
                "segments_per_image must be >= 1, got "
                f"{segments_per_image!r}"
            )
        self.images = images
        self.labels = labels
        self.segments_per_image = segments_per_image

    def get_segment(self, image, segment_ID):
        """Return segment number ``segment_ID`` of ``image``.

        This is the hook to implement: the base class does not know how an
        image should be cut, so it fails loudly instead of returning ``None``
        (which would only surface later as an obscure batching error).

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            "get_segment must be implemented to return one segment of one image"
        )

    def __len__(self):
        """Denote the total number of samples (images x segments per image)."""
        return len(self.labels) * self.segments_per_image

    def __getitem__(self, index):
        """Generate one sample of data: ``(segment, label)`` for ``index``."""
        # Split the flat sample index into (which image, which segment of it).
        image_ID, segment_ID = divmod(index, self.segments_per_image)
        image = self.images[image_ID]
        # Data and label; every segment inherits the label of its image.
        X = self.get_segment(image, segment_ID)
        y = self.labels[image_ID]
        return X, y
# Example wiring of the dataset into PyTorch data loaders.
# NOTE(review): train_images / train_labels / validation_images /
# validation_labels and ImageDataBunch are not defined in this snippet; they
# are assumed to be provided by the surrounding code (ImageDataBunch
# presumably comes from fastai -- confirm and import it accordingly).

# Placeholder values: adjust to your data and to what get_segment produces.
SEGMENTS_PER_IMAGE = 4
BATCH_SIZE = 32

# SegmentedDataset requires segments_per_image; omitting it raises TypeError.
training_set = SegmentedDataset(train_images, train_labels, SEGMENTS_PER_IMAGE)
training_generator = data.DataLoader(
    training_set,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle training samples every epoch
)

validation_set = SegmentedDataset(
    validation_images, validation_labels, SEGMENTS_PER_IMAGE
)
validation_generator = data.DataLoader(
    validation_set,
    batch_size=BATCH_SIZE,
)

data_bunch = ImageDataBunch(
    train_dl=training_generator,
    valid_dl=validation_generator,
    # ... other arguments
)
Hope it helps!