I can offer a worked example. I use a custom dataset class to accomplish this, but everything else is vanilla fastai. I keep all of my data in a pandas DataFrame, then split it into train, validation, and test sets (for my current use case, this is much easier than shuffling files around into different folders). A small custom dataset class makes this straightforward. For example, here is a classifier dataset derived from my scalar dataset. `merged` is my pandas DataFrame, which contains (at least) a PosixPath column named `file_path` and a column, identified by the `trait` variable, which holds my categorical outcomes of interest.
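For concreteness, a toy `merged` might look like the following. The file paths and the `coat_color` trait here are made up purely for illustration; substitute your own columns:

```python
from pathlib import Path
import pandas as pd

# Hypothetical stand-in for the real data: a 'file_path' column of
# PosixPath objects plus a categorical outcome column named by `trait`.
trait = 'coat_color'
merged = pd.DataFrame({
    'file_path': [Path('images/a.jpg'), Path('images/b.jpg'), Path('images/c.jpg')],
    trait:       ['black', 'brown', 'black'],
})
```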
```python
class ImageCategoricalDataset(ImageDataset):
    def __init__(self, df:DataFrame, path_column:str='file_path', dependent_variable:str=None):
        # set() deduplicates the dependent-variable values; list() turns the
        # result back into an indexable list of unique classes
        self.classes = list(set(df[dependent_variable]))
        self.class2idx = {v:k for k,v in enumerate(self.classes)}
        y = np.array([self.class2idx[o] for o in df[dependent_variable]], dtype=np.int64)
        # The superclass does nice things for us, like tensorizing the numpy input
        super().__init__(df[path_column], y)
        self.loss_func = F.cross_entropy
        self.loss_fn = self.loss_func  # alias kept for fastai API compatibility

    def __getitem__(self, i:int):
        return open_image(self.x[i]), self.y[i]

    def __len__(self)->int:
        return len(self.y)
```
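The class/index bookkeeping in `__init__` can be seen in isolation below. One caveat worth knowing: `set()` iteration order is not guaranteed to be stable across Python processes, so if you need a reproducible class-to-index mapping (e.g. to reload a trained model later), `sorted(set(...))` is a safer variant of the same pattern:

```python
import numpy as np

values = ['cat', 'dog', 'cat', 'bird']
# sorted() makes the mapping deterministic; plain list(set(...)) may
# order the classes differently between runs
classes = sorted(set(values))
class2idx = {v: k for k, v in enumerate(classes)}
y = np.array([class2idx[o] for o in values], dtype=np.int64)
# classes → ['bird', 'cat', 'dog'], y → [1, 2, 1, 0]
```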
I then assign a random value to each row:
```python
np.random.seed(31337)
merged['rand'] = np.random.uniform(low=0.0, high=1.0, size=len(merged))
```
I then generate datasets based on cutpoints of that random value:
```python
dat_train = ImageCategoricalDataset(merged[merged['rand'] < 0.7], 'file_path', trait)
dat_valid = ImageCategoricalDataset(merged[(merged['rand'] >= 0.7) & (merged['rand'] < 0.9)], 'file_path', trait)
dat_test  = ImageCategoricalDataset(merged[merged['rand'] >= 0.9], 'file_path', trait)
```
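A nice property of splitting on cutpoints of a single random column is that the three masks partition the rows exactly: no row lands in two splits, none is dropped, and the split sizes come out near the 70/20/10 ratios. A quick numpy sketch of that behavior:

```python
import numpy as np

np.random.seed(31337)
rand = np.random.uniform(0.0, 1.0, size=10_000)
train = rand < 0.7
valid = (rand >= 0.7) & (rand < 0.9)
test  = rand >= 0.9
# Every row lands in exactly one split, and the observed fractions
# are close to 0.7 / 0.2 / 0.1 for a sample this size
assert train.sum() + valid.sum() + test.sum() == len(rand)
```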
And finally I create my data bunch from these datasets and feed them into a learner:
```python
data = ImageDataBunch.create(dat_train, dat_valid, dat_test,
                             ds_tfms=get_transforms(),
                             bs=128,
                             size=128)
learn = ConvLearner(data,
                    models.resnet50,
                    metrics=[accuracy, dice],
                    ps=0.5,
                    callback_fns=ShowGraph)
```
My background is in Go, not Python, so there are likely more efficient ways to accomplish this.