Load the whole dataset in RAM

To avoid idle GPU cycles and make minibatches available faster, I’d like to load the entire datasets I work on into RAM: so far, none of the datasets I have handled has exceeded the capacity of my main memory.
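Whether a dataset fits is better judged by its decoded size than by its size on disk, because compressed JPEG/PNG files grow when decoded. The sketch below is a minimal estimate. It assumes every image decodes to the same height, width and channel count. It also assumes 4 bytes per value, as when images are kept as float32 tensors (fastai v1 `Image` objects hold float tensors). The function names are hypothetical.

```python
def decoded_size_bytes(n_images, height, width, channels=3, bytes_per_value=4):
    """Estimate RAM needed to hold n_images decoded images of identical shape.

    bytes_per_value=4 assumes float32 tensors; use 1 for raw uint8 pixels.
    """
    return n_images * height * width * channels * bytes_per_value


def human_readable(n_bytes):
    """Format a byte count using binary (1024-based) units."""
    for unit in ("B", "KiB", "MiB", "GiB"):
        if n_bytes < 1024:
            return f"{n_bytes:.1f} {unit}"
        n_bytes /= 1024
    return f"{n_bytes:.1f} TiB"


# 60,000 images of 28x28 (the size of the MNIST training set),
# converted to 3-channel RGB and stored as float32.
size = decoded_size_bytes(60_000, 28, 28, channels=3, bytes_per_value=4)
print(size, human_readable(size))  # 564480000 538.3 MiB
```

The same images stored as single-channel uint8 would need 12 times less memory. The data type and channel conversion matter as much as the image count.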

I’m particularly interested in vision and tabular datasets.

The topic has been discussed before, but the threads I found are either quite old (2017), a bit vague, or both.

You would need to customize ImageList a little to do that (I’m guessing this is about images, since in the other applications the whole dataset is already loaded in RAM).
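Customizing the list essentially means changing where each item comes from. It is read once into memory instead of being fetched from disk on every access. The sketch below shows eager preloading without any framework. `InMemoryFiles` and its `decode` parameter are hypothetical names, and nothing here is fastai API.

```python
from pathlib import Path


class InMemoryFiles:
    """Read every file's raw bytes into RAM once, then serve items by index.

    Hypothetical helper for illustration: turning bytes into an image or
    tensor is left to a user-supplied `decode` callable.
    """

    def __init__(self, paths, decode=lambda b: b):
        self.paths = [Path(p) for p in paths]
        self.decode = decode
        # Eager load: one disk read per file, paid before training starts.
        self.cache = [p.read_bytes() for p in self.paths]

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, i):
        # No disk access here; only the (optional) decode step runs per item.
        return self.decode(self.cache[i])
```

Caching raw bytes keeps memory use close to the on-disk size but still pays the decode cost on every access. Caching decoded images instead, as the ImageList versions in this thread do, uses more RAM but removes the decoding work as well.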

I’ll try to delve into ImageList. If I succeed, I may submit a PR adding a new optional argument for this. If not, I’ll ask for help.
On the other hand, I wasn’t aware (my own negligence) that this was already the default behaviour for the other applications. Thanks!

Has anyone figured this out? For small image datasets, loading the entire dataset into RAM seems like it would remove a major bottleneck.

Yes, you could use this:

It lazily loads the images into a dict, which speeds up every epoch after the first one.

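from fastai.vision import *  # fastai v1: ImageList, Image, untar_data, URLs, cnn_learner, ...

class MemoryImageList(ImageList):
    # Class-level dict: shared by every MemoryImageList (train and valid),
    # keyed by file path, so each image is decoded at most once per process.
    _map = {}

    def open(self, fn):
        item = self._map.get(str(fn))
        if isinstance(item, Image):   # cache hit: image already in RAM
            return item
        item = super().open(fn)       # cache miss: read and decode from disk
        self._map[str(fn)] = item
        return item

data = (MemoryImageList.from_folder(untar_data(URLs.MNIST)/'training')
        .split_by_rand_pct(.2, seed=1)
        .label_from_folder()
        .databunch(bs=128)
        .normalize(imagenet_stats))
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn.fit_one_cycle(3)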

Demo Notebook
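Two properties of MemoryImageList are worth keeping in mind. Because `_map` is a class attribute, every instance shares one dict for the life of the process. Because the cache fills lazily, the first epoch still reads from disk. The plain-Python sketch below shows the same lazy-caching pattern with a per-instance dict and an explicit `clear()`. `LazyCache` and `fake_loader` are hypothetical names, and `fake_loader` stands in for `ImageList.open`.

```python
class LazyCache:
    """Memoize an expensive loader: the first access reads from disk,
    later accesses return the object kept in RAM.

    Hypothetical stand-in for the MemoryImageList idea; `loader` plays
    the role of ImageList.open.
    """

    def __init__(self, loader):
        self.loader = loader
        self._map = {}  # per-instance cache, not shared across lists

    def open(self, key):
        key = str(key)
        if key not in self._map:        # cache miss: load once
            self._map[key] = self.loader(key)
        return self._map[key]           # cache hit: no disk access

    def clear(self):
        self._map.clear()


calls = []

def fake_loader(path):
    calls.append(path)  # record every simulated disk read
    return f"image:{path}"

cache = LazyCache(fake_loader)
for epoch in range(3):
    for p in ["a.png", "b.png"]:
        cache.open(p)
print(len(calls))  # 2: each file was read only during the first epoch
```

The sketch uses a membership test instead of the `isinstance` check. With a membership test, a loader that returns any type of object is still cached correctly.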

Here’s a version that loads or purges the full dataset on command.

import gc
from fastai.vision import *  # fastai v1: ImageList, ImageDataBunch, open_image

class ImageListInMemory(ImageList):
    _bunch,_square_show,_square_show_res = ImageDataBunch,True,True

    def __init__(self, *args, convert_mode='RGB', after_open=None, **kwargs):
        super().__init__(*args, convert_mode=convert_mode, after_open=after_open, **kwargs)
        self.loaded = False

    def load_data(self):
        # Decode every item once and keep the resulting Images in RAM.
        if not self.loaded:
            self.images = [open_image(item, convert_mode=self.convert_mode, after_open=self.after_open)
                           for item in self.items]
            self.loaded = True

    def purge(self):
        # Drop the cached Images so their memory can be reclaimed.
        if self.loaded:
            self.images = None
            gc.collect()
            self.loaded = False

    def get(self, i):
        # Serve from RAM when loaded, otherwise fall back to the normal disk read.
        res = self.images[i] if self.loaded else super().get(i)
        self.sizes[i] = res.size
        return res
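Here is the load/purge logic stripped of fastai so it can be run on its own. `PurgeableList` and `fake_open` are hypothetical names, and `fake_open` stands in for `open_image`. Since `loaded` lives on each instance, the preload presumably has to be triggered on every list object separately, for example on the training and validation inputs after the split. That is an inference from the class, not something stated in the thread.

```python
import gc


class PurgeableList:
    """Holds file paths; can preload every item into RAM or drop them again.

    Hypothetical plain-Python analogue of ImageListInMemory: `open_fn`
    stands in for fastai's open_image.
    """

    def __init__(self, items, open_fn):
        self.items = list(items)
        self.open_fn = open_fn
        self.images = None  # None means "not loaded"

    @property
    def loaded(self):
        return self.images is not None

    def load_data(self):
        if not self.loaded:
            self.images = [self.open_fn(item) for item in self.items]

    def purge(self):
        if self.loaded:
            # Dropping the last reference lets reference counting free the list;
            # gc.collect() additionally reclaims any reference cycles.
            self.images = None
            gc.collect()

    def get(self, i):
        # Serve from RAM when loaded, otherwise read from "disk" every time.
        return self.images[i] if self.loaded else self.open_fn(self.items[i])


reads = []

def fake_open(path):
    reads.append(path)  # record each simulated disk read
    return path.upper()

pl = PurgeableList(["x", "y"], fake_open)
pl.get(0)             # not loaded yet: one disk read
pl.load_data()        # reads both items once
pl.get(0), pl.get(1)  # served from RAM, no new reads
pl.purge()            # back to disk reads from now on
print(reads)          # ['x', 'x', 'y']
```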

I’m actually trying to figure this out for a GAN ItemList. The difference is that, for a GAN, the image items are set as labels. Does anyone know which function loads the labels? I’d want to alter it so that it pulls from a preloaded list of images.

You set the label class with the label_cls argument of your labelling call in the data block API.
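In a GAN setup the images sit on the label side. The caching therefore has to live in whatever class produces the labels, which is why the label class is the thing to customize. Below is a framework-free sketch of that split: random noise as the input and a real image fetched from RAM as the target. `InMemoryTargets` and `NoiseToImageDataset` are hypothetical names and do not mirror fastai’s internals.

```python
import random


class InMemoryTargets:
    """Hypothetical label container: target images are already held in RAM."""

    def __init__(self, images):
        self.images = list(images)  # preloaded once, before training

    def __len__(self):
        return len(self.images)

    def get(self, i):
        return self.images[i]       # no disk read when fetching a label


class NoiseToImageDataset:
    """GAN-style pairs: x is random noise, y is a real image taken from RAM."""

    def __init__(self, targets, noise_dim=100, seed=0):
        self.targets = targets
        self.noise_dim = noise_dim
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        noise = [self.rng.gauss(0.0, 1.0) for _ in range(self.noise_dim)]
        return noise, self.targets.get(i)


targets = InMemoryTargets(["img0", "img1", "img2"])  # stand-ins for decoded images
ds = NoiseToImageDataset(targets, noise_dim=4)
x, y = ds[2]
print(len(x), y)  # 4 img2
```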