I ended up taking a slightly different approach: building an ImageDataBunch from scratch. That means building the training and validation datasets directly from the files in each directory, while limiting the number of classes (for both sets) and the number of images per class (for the training set only) before the training set is built.
This is how it looks in code, in this case limiting to 7 classes and 600 images in each class for the training set:
# Scene classes to keep; each name must match a sub-folder of path/'train' and path/'val'.
my_classes = [
    "amphitheater",
    "art_gallery",
    "art_studio",
    "artists_loft",
    "attic",
    "auditorium",
    "balcony-interior",
]
# 600 training images per class (randomly drawn, since shuffle defaults to True);
# the remaining keyword arguments are forwarded unchanged to ImageDataBunch.create.
data = SampledImageDataBunch(
    path / "train",
    path / "val",
    my_classes,
    num_samples=600,
    ds_tfms=get_transforms(),
    size=224,
    bs=128,
)
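The code to achieve this is below. It’s probably not very Pythonic, but it works. (It needs `random` and fastai’s `ImageClassificationDataset` and `ImageDataBunch` to be imported, and these functions must be defined before the snippet above is run.)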
def training_set(train_dir, classes_to_train, num_samples=50, shuffle=True):
    """Build a training dataset restricted to some classes and samples per class.

    Every sub-directory of ``train_dir`` whose name is in ``classes_to_train``
    is treated as one class; the directory name is used as the label of each
    file selected from it.

    Args:
        train_dir: folder with one sub-folder per class (any object with
            ``iterdir()``, e.g. ``pathlib.Path``).
        classes_to_train: names of the class sub-folders to include.
        num_samples: number of files to take from each class.
        shuffle: if True, draw ``num_samples`` files uniformly at random from
            each class (every class must then contain at least that many
            files); if False, take the first ``num_samples`` files in sorted
            filename order (fewer if the class is smaller).

    Returns:
        An ``ImageClassificationDataset`` over the selected files and labels.

    Raises:
        ValueError: if ``shuffle`` is True and ``num_samples`` is negative or
            larger than the number of files in some class.
    """
    wanted = set(classes_to_train)
    # iterdir() yields entries in arbitrary, filesystem-dependent order; sort so
    # the selection is reproducible (and random.sample is, under a fixed seed).
    class_dirs = sorted(d for d in train_dir.iterdir() if d.name in wanted)

    # A typo in a class name would otherwise silently drop that class.
    missing = wanted - {d.name for d in class_dirs}
    if missing:
        print("warning: no folder found for classes: " + ", ".join(sorted(missing)))

    files, labels = [], []
    for class_dir in class_dirs:
        candidates = sorted(class_dir.iterdir())
        print(f"{class_dir.name} has {len(candidates)} samples")
        if shuffle:
            # Verify there are enough samples before sampling, so the error
            # names the offending class instead of random.sample's generic message.
            if not 0 <= num_samples <= len(candidates):
                raise ValueError(
                    f"class {class_dir.name!r} has {len(candidates)} samples, "
                    f"cannot draw {num_samples}"
                )
            chosen = random.sample(candidates, num_samples)
        else:
            chosen = candidates[:num_samples]
        files.extend(chosen)
        labels.extend([class_dir.name] * len(chosen))
    return ImageClassificationDataset(files, labels)
def validation_set(val_dir, classes_to_train):
    """Build a validation dataset from all files of the selected classes.

    Every sub-directory of ``val_dir`` whose name is in ``classes_to_train``
    is treated as one class. Unlike ``training_set``, no per-class limit is
    applied: every file in each selected folder is included, labelled with
    the folder name.

    Args:
        val_dir: folder with one sub-folder per class (any object with
            ``iterdir()``, e.g. ``pathlib.Path``).
        classes_to_train: names of the class sub-folders to include.

    Returns:
        An ``ImageClassificationDataset`` over the files and their labels.
    """
    wanted = set(classes_to_train)
    # Sort because iterdir() order is filesystem-dependent; this keeps the
    # dataset order reproducible across runs and machines.
    class_dirs = sorted(d for d in val_dir.iterdir() if d.name in wanted)
    files, labels = [], []
    for class_dir in class_dirs:
        class_files = sorted(class_dir.iterdir())
        files.extend(class_files)
        labels.extend([class_dir.name] * len(class_files))
    return ImageClassificationDataset(files, labels)
# classes_to_train = [ 'fire_escape','lake-natural', 'cliff']
# ts = training_set(train_dir, classes_to_train, num_samples=200, shuffle=False)
# vs = validation_set(path/'val',classes_to_train)
# print(ts)
# print(vs)
def SampledImageDataBunch(train_dir, val_dir, classes_to_train, num_samples=200, shuffle=True, **kwargs):
    """Create an ImageDataBunch from a per-class subsample of the training data.

    The training set holds ``num_samples`` files per class from ``train_dir``
    (see ``training_set`` for how ``shuffle`` picks them); the validation set
    holds every file of the same classes from ``val_dir``. Any extra keyword
    arguments (e.g. ``ds_tfms``, ``size``, ``bs``) are passed straight through
    to ``ImageDataBunch.create``.
    """
    # Arguments are evaluated left to right: training set first, then validation.
    return ImageDataBunch.create(
        training_set(train_dir, classes_to_train, num_samples, shuffle),
        validation_set(val_dir, classes_to_train),
        **kwargs,
    )