If you are interested, here is a custom dataset implementation for Quick Draw data. It is a somewhat "hacky" way to implement a dataset, because it doesn't have the required attributes expected by fastai code, and it will probably fail when executing some Learner methods.
The code is not directly related to your question, but maybe it can help you figure out how to implement something similar. In your case, you wouldn't inherit from torch.Dataset but from the fastai base class mentioned above. And you don't need the decorator if your dataset will be part of the library itself.
def fastai_dataset(loss_func):
    """A class decorator that makes a custom dataset fastai-compatible.

    fastai's training loop expects a dataset to expose two attributes that a
    plain ``torch.utils.data.Dataset`` does not have:

    * ``c`` -- the number of target classes;
    * ``loss_func`` -- the loss function to use for this dataset.

    This decorator attaches both as read-only properties. The decorated class
    must define a ``classes`` attribute (a sequence of unique class names)
    from which ``c`` is derived.

    Args:
        loss_func: the loss callable (e.g. ``F.cross_entropy``) returned by
            the generated ``loss_func`` property.

    Returns:
        A decorator that mutates and returns the dataset class.
    """
    def class_wrapper(dataset_cls):
        def get_n_classes(self):
            # Number of classes is the length of the dataset's `classes` list.
            return len(self.classes)

        def get_loss_func(self):
            # Closure over the decorator argument; same callable for all instances.
            return loss_func

        dataset_cls.c = property(get_n_classes)
        dataset_cls.loss_func = property(get_loss_func)
        return dataset_cls

    return class_wrapper
@fastai_dataset(F.cross_entropy)
class QuickDraw(Dataset):
    """Quick Draw sketches dataset, rendering stroke data to images on access.

    Loads per-category CSV files from ``root/train`` or ``root/valid`` (or a
    previously built feather cache) into a dataframe, derives integer labels
    from the ``word`` column, and renders the stroke ``points`` into a fastai
    ``Image`` lazily in ``__getitem__``.

    The ``fastai_dataset`` decorator supplies the ``c`` and ``loss_func``
    attributes fastai expects; ``c`` relies on ``self.classes`` set below.
    """

    # Output resolution used when rasterizing stroke points into an image.
    img_size = (256, 256)

    def __init__(self, root: Path, train: bool=True, take_subset: bool=True,
                 subset_size: FloatOrInt=1000, bg_color='white',
                 stroke_color='black', lw=2, use_cache: bool=True):
        # Dataset layout on disk: root/train and root/valid, sibling 'cache' dir.
        subfolder = root/('train' if train else 'valid')
        cache_file = subfolder.parent / 'cache' / f'{subfolder.name}.feather'

        if use_cache and cache_file.exists():
            log.info('Reading cached data from %s', cache_file)
            # Workaround: call feather directly instead of pd.read_feather,
            # which raised an `nthreads` error.
            cats_df = feather.read_dataframe(cache_file)
        else:
            log.info('Parsing CSV files from %s...', subfolder)
            # `None` means "take everything"; `read_parallel` presumably
            # interprets subset_size per file -- TODO confirm against its docs.
            subset_size = subset_size if take_subset else None
            # Single worker in DEBUG mode to keep stack traces readable.
            n_jobs = 1 if DEBUG else None
            cats_df = read_parallel(subfolder.glob('*.csv'), subset_size, n_jobs)
            if train:
                # Shuffle the training rows once at load time (frac=1 = full
                # permutation); validation order is kept deterministic.
                cats_df = cats_df.sample(frac=1)
                cats_df.reset_index(drop=True, inplace=True)
            log.info('Done! Parsed files saved into cache file')
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            cats_df.to_feather(cache_file)

        # Map class names (the 'word' column) to contiguous integer labels.
        targets = cats_df.word.values
        classes = np.unique(targets)
        class2idx = {v: k for k, v in enumerate(classes)}
        labels = np.array([class2idx[c] for c in targets])

        self.root = root
        self.train = train
        self.bg_color = bg_color
        self.stroke_color = stroke_color
        self.lw = lw
        # Raw stroke points per sample; rendered lazily in __getitem__.
        self.data = cats_df.points.values
        self.classes = classes            # consumed by the decorator's `c` property
        self.class2idx = class2idx
        self.labels = labels
        # NOTE(review): populated nowhere in this class -- looks like an
        # unfinished image cache; confirm before removing.
        self._cached_images = {}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return an (Image, int label) pair, rasterizing strokes on the fly."""
        points, target = self.data[item], self.labels[item]
        image = self.to_image_tensor(points)
        return image, target

    def to_image_tensor(self, points):
        """Render stroke `points` to a PIL image and wrap it as a fastai Image."""
        img = to_pil_image(points, self.img_size, self.bg_color, self.stroke_color, self.lw)
        return Image(to_tensor(img))
Then I was able to use my custom dataset with the library in the usual way:
# Usage example: PREPARED (data root), bs (batch size), sz (image size) and
# n (number of epochs) are assumed to be defined earlier in the notebook/script.
train_ds = QuickDraw(PREPARED, train=True)
valid_ds = QuickDraw(PREPARED, train=False)
# Wrap the two datasets into a fastai DataBunch, applying standard transforms.
bunch = ImageDataBunch.create(
    train_ds, valid_ds, bs=bs, size=sz, ds_tfms=get_transforms())
bunch.normalize(imagenet_stats)
# Fine-tune an ImageNet-pretrained ResNet-50 with the one-cycle policy.
learn = create_cnn(bunch, models.resnet50, path='..')
learn.fit_one_cycle(n)