Memory problem in torch dataloader when ImageDataBunch returns tensor

I ran into some odd behavior when building a dataloader from a custom ImageList.
My dataset consists of high-dimensional images, which I saved to disk as torch tensors (one file per image). I overrode the open method of ImageList to use torch.load() instead of open_image and return the tensor directly.
For example:

class ProblematicImageList(ImageList):
    # fn is a path to a .pth file
    def open(self, fn):
        return torch.load(fn)

This leaks memory: CPU RAM usage grows with every batch until the OS kills the dataloader workers and training stops.

However, if I wrap the tensor in an Image, the problem goes away:

class SolvedImageList(ImageList):
    # fn is a path to a .pth file; the loaded tensor is wrapped in an Image
    def open(self, fn):
        return Image(torch.load(fn))

The same problem occurs with ordinary image files if open() returns a tensor instead of an Image:

class TensorImageList(ImageList):
    # fn is a path to an image. This class has a memory problem too
    def open(self, fn):
        return open_image(fn).data
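One way to compare the variants above is to log the process's peak resident memory while iterating over batches. The following is only a minimal sketch, not part of fastai: `peak_rss_mb` and `log_peak_rss` are hypothetical helper names, the `resource` module is Unix-only, and the numbers only reflect the loading work if it happens in the main process (for example with `num_workers=0`).

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_peak_rss(batches, every=10):
    """Iterate over `batches`, recording (batch index, peak RSS in MiB)
    once every `every` batches."""
    history = []
    for i, _batch in enumerate(batches):
        if i % every == 0:
            history.append((i, peak_rss_mb()))
    return history
```

Calling `log_peak_rss(data.train_dl)`, where `data` stands for whatever DataBunch was built from the list, should give a roughly flat curve if memory is being released and a steadily rising one if it is not.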

I did the following experiment:

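import gc
import time
import torch

# Time a single load, then drop the reference and force a collection
t = time.time()
a = torch.load('a.pth')
print(time.time()-t)
a = None
gc.collect()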

For a tensor of size (34, 482, 512), the first run of this cell takes about 0.1 s, while subsequent runs are much faster, dropping to about 0.0097 s. I suspect torch is caching tensors in RAM even after the reference is removed and the garbage collector is called.
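The timing experiment can be repeated in a loop to separate the first load from later ones. This is a rough sketch with a hypothetical helper, `time_repeated_loads`, that accepts any loader function. A faster second load does not by itself show that torch is holding on to the tensor: the operating system's file cache usually speeds up repeated reads of the same file, so that is one possible explanation worth ruling out.

```python
import gc
import time


def time_repeated_loads(load_fn, path, repeats=5):
    """Call load_fn(path) `repeats` times and return the elapsed
    seconds of each call."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        obj = load_fn(path)
        timings.append(time.perf_counter() - start)
        # Drop the only reference and force a collection before the next load.
        del obj
        gc.collect()
    return timings
```

For the case above this would be called as `time_repeated_loads(torch.load, 'a.pth')`. Watching process memory alongside these timings gives a clearer picture than the timings alone.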

Any ideas on what is going on here? And why does the problem go away when I wrap the same tensor in an Image object?

It may be linked to the memory leak discovered in the LMDataLoader. We will profile this and try to find fixes when we’re done with development.
