Update: I did take a look at the datasets and that’s indeed the way to go. I’ve managed to load my arrays by inheriting from FilesIndexArrayDataset and overriding the definition of the get_x method. Then I inherited from ImageClassifierData and overrode the from_paths method, essentially copying the original definition but using my class derived from FilesIndexArrayDataset. After that, you can import the new classifier class and use it instead of the usual ImageClassifierData.
Example of use:
# PATH, bs and tfms are whatever you would normally pass to
# ImageClassifierData.from_paths.
from serializedArrayData import SerializedArrayClassifierData

data = SerializedArrayClassifierData.from_paths(PATH, bs=bs, tfms=tfms)
Here is the code, for anyone interested:
#File serializedArrayData.py
import numpy as np
import os
from fastai.dataset import FilesIndexArrayDataset, ImageClassifierData, folder_source
class FilesSerializedArrays(FilesIndexArrayDataset):
    """Dataset that reads each item from disk with ``np.load``."""

    def get_x(self, i):
        """Return the array loaded from ``self.fnames[i]``, resolved under ``self.path``."""
        array_file = os.path.join(self.path, self.fnames[i])
        return np.load(array_file)
class SerializedArrayClassifierData(ImageClassifierData):
    """``ImageClassifierData`` variant whose datasets are ``FilesSerializedArrays``."""

    @classmethod
    def from_paths(cls, path, bs=64, tfms=(None,None), trn_name='train', val_name='valid', test_name=None, test_with_labels=False, num_workers=8):
        """ Read in serialized arrays and their labels given as sub-folder names

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`.
                Both the train and the validation entry must be provided.
            trn_name: a name of the folder that contains training arrays.
            val_name: a name of the folder that contains validation arrays.
            test_name: a name of the folder that contains test arrays (optional).
            test_with_labels: if True, the test folder is read with `folder_source`
                (like train/valid); otherwise it is read with `read_dir`.
            num_workers: number of workers

        Returns:
            An instance of `cls` built on `FilesSerializedArrays` datasets.

        Raises:
            AssertionError: if either element of `tfms` is None.
        """
        assert not(tfms[0] is None or tfms[1] is None), "please provide transformations for your train and validation sets"
        trn, val = [folder_source(path, o) for o in (trn_name, val_name)]

        test = None
        if test_name:
            if test_with_labels:
                test = folder_source(path, test_name)
            else:
                # read_dir is not part of the module-level fastai.dataset import,
                # so bring it into scope here; without this the unlabeled-test
                # path fails with a NameError.
                from fastai.dataset import read_dir
                test = read_dir(path, test_name)

        datasets = cls.get_ds(FilesSerializedArrays, trn, val, tfms, path=path, test=test)
        # The third element of the training folder source is passed on as the class list.
        return cls(path, datasets, bs, num_workers, classes=trn[2])