Learn.get_preds() memory inefficiency - quick fix

tl;dr
Under certain conditions, Learn.get_preds() can eat all your RAM and kill your kernel, even on a large machine. This looks easy to fix for a 50% reduction by using an in-place sort, with potentially much greater reductions depending on the use case.
/tl;dr

A problem that appears to have cropped up for many users over many versions of fastai (though I haven’t comprehensively checked) is Learn.get_preds() chewing up memory. This is also the root cause of ImageClassifierCleaner chewing up memory (see the threads “Out of memory when execute class ImagesCleaner” and “Learn.get_preds() running out of RAM at completion” for examples).

The issue is that Learn.get_preds() appears to store all batches in memory, and then produce a sorted copy of all of it using nested_reorder().

To replicate, edit your copy of fastai/learner.py (the auto-generated, “do not edit” version) and insert the following helper at the top:

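def getMemStats(prefix=""):
    "Print system RAM in use (MemTotal - MemAvailable) in GB; Linux only"
    linux_filepath = "/proc/meminfo"
    try:
        with open(linux_filepath) as f:
            # each line looks like "MemTotal:   16314812 kB"
            meminfo = dict((i.split()[0].rstrip(":"), int(i.split()[1])) for i in f)
        # /proc/meminfo reports kB, so divide by 1024**2 (not 1024**3) to get GB
        memused = (meminfo["MemTotal"] - meminfo["MemAvailable"]) / (1024*1024)
        print(prefix + "memused=%0.1fGB" % memused)
    except Exception:
        print(prefix + "getMemStats() failed")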

Then, on every second line of the definition of get_preds(), add a call to getMemStats("xxx"), replacing xxx with a hint about where you are in the code. In particular, do this before and after the lines
self._do_epoch_validate(dl=dl)
and
if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())
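If you would rather not sprinkle getMemStats() calls by hand, a small context manager can report the change in used RAM around any block. This is a minimal sketch, assuming Linux's /proc/meminfo (which reports values in kB); read_meminfo, used_gb, current_used_gb and mem_delta are hypothetical helper names, not part of fastai.

```python
import contextlib

def read_meminfo(text):
    "Parse /proc/meminfo-style text into {field: value_in_kB}"
    info = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) >= 2:
            info[parts[0].rstrip(":")] = int(parts[1])
    return info

def used_gb(info):
    "RAM in use (MemTotal - MemAvailable), converted from kB to GB"
    return (info["MemTotal"] - info["MemAvailable"]) / (1024 * 1024)

def current_used_gb(path="/proc/meminfo"):
    "Read the live figure from /proc/meminfo (Linux only)"
    with open(path) as f:
        return used_gb(read_meminfo(f.read()))

@contextlib.contextmanager
def mem_delta(label):
    "Print how much used RAM changed while the with-block ran"
    before = current_used_gb()
    yield
    after = current_used_gb()
    print(f"{label}: {after - before:+.2f}GB (now {after:.1f}GB used)")
```

Inside get_preds() you could then write, for example:
with mem_delta("validate"): self._do_epoch_validate(dl=dl)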

Save that, and now load 2020 notebook 01, and after the first active cell (where we train on cats-vs-dogs), add a cell with the following:

from fastai.vision.widgets import *
cleaner = ImageClassifierCleaner(learn)

Restart your kernel (to force reload of the edited fastai/learner.py), and run these two cells.

You should see memory use increase by ~3.3GB at the _do_epoch_validate line, and by the same amount again at the nested_reorder line.
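As a sanity check on the ~3.3GB figure, here is a back-of-the-envelope estimate of the memory needed just to hold the input tensors that get_preds() stores. It assumes 224×224 RGB float32 tensors (the notebook resizes to 224) and roughly 5,900 training images (about 80% of the roughly 7,400 Pets images); both numbers are assumptions, not measurements.

```python
def input_bytes(n_items, channels=3, height=224, width=224, bytes_per_value=4):
    "Bytes needed to hold n_items float32 image tensors of shape (C, H, W)"
    return n_items * channels * height * width * bytes_per_value

n_train = 5900  # assumption: ~80% of the Pets images with a 20% validation split
once = input_bytes(n_train)
print(f"inputs held once: {once / 1024**3:.2f} GiB")      # stored during validation
print(f"plus sorted copy: {2 * once / 1024**3:.2f} GiB")  # after nested_reorder
```

Under these assumptions the single copy comes out at roughly 3.3 GiB, and the sorted copy doubles it.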

For larger datasets than the Oxford Pets dataset, this can easily cause even the largest machine to run out of memory.

The easy win here is the nested_reorder line: it appears to make a new sorted copy rather than sorting in place, thus using twice as much memory as this task needs. Performing an in-place sort would therefore immediately halve the memory used.
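To make the memory argument concrete: reordering by indexing builds a second full-size buffer, while applying a permutation in place only needs a small visited mask. Below is a minimal pure-Python sketch of the in-place, cycle-following approach on a plain list. apply_perm_inplace and argsort are hypothetical helpers, not fastai's nested_reorder(); a real fix would have to work on the (possibly nested) tensors that get_preds() returns.

```python
def apply_perm_inplace(items, perm):
    """Reorder `items` so that items[i] becomes the old items[perm[i]].

    Follows each cycle of the permutation, so no second copy of `items`
    is made. `perm` is assumed to be a valid permutation of range(len(items)).
    """
    n = len(items)
    visited = bytearray(n)          # one byte per item, not a copy of each item
    for start in range(n):
        if visited[start]:
            continue
        saved = items[start]        # first element of this cycle
        j = start
        while True:
            visited[j] = 1
            k = perm[j]
            if k == start:          # cycle closed: drop in the saved value
                items[j] = saved
                break
            items[j] = items[k]     # items[k] has not been overwritten yet
            j = k
    return items

def argsort(keys):
    "Indices that would sort `keys` ascending (like tensor.argsort())"
    return sorted(range(len(keys)), key=keys.__getitem__)
```

For example, if preds came back in dataloader order idxs = [2, 0, 1], then apply_perm_inplace(preds, argsort(idxs)) puts them back into dataset order without allocating a second list of predictions.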

The step I am less sure of is how to reduce the memory used by _do_epoch_validate: the question is whether the entire results of _do_epoch_validate need to be kept for all use cases. I don't know the library and its use cases well enough to confirm this.

However, for the case of ImageClassifierCleaner it is clear that this should not be necessary - it should be possible to loop through batches of get_preds and accumulate just the indices (as opposed to the actual images), targets and losses, which should take up far less memory per item. An even greater saving could be to keep only some maximum number of top losses per category, ejecting any items from memory that aren’t in at least one of these category top-losses lists. That would be much more complex to implement, so I assume the former would be the quick win.
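The "keep only the top losses per category" idea can be sketched with a bounded min-heap per category: stream over (index, target, loss) triples batch by batch and retain at most k entries per class, so memory grows with k × number of classes rather than with the dataset size. This is a minimal pure-Python sketch; top_losses_per_class and the (idx, target, loss) batch format are assumptions for illustration, not fastai API.

```python
import heapq
from collections import defaultdict

def top_losses_per_class(batches, k):
    """Keep the k largest losses per target class.

    `batches` yields lists of (idx, target, loss) tuples. Only indices are
    stored; the items themselves can be reloaded later from the dataset.
    Returns {target: [(loss, idx), ...]} sorted by descending loss.
    """
    heaps = defaultdict(list)      # target -> min-heap of (loss, idx)
    for batch in batches:
        for idx, targ, loss in batch:
            heap = heaps[targ]
            if len(heap) < k:
                heapq.heappush(heap, (loss, idx))
            elif loss > heap[0][0]:            # beats the smallest kept loss
                heapq.heapreplace(heap, (loss, idx))
    return {t: sorted(h, reverse=True) for t, h in heaps.items()}
```

Because heap[0] is always the smallest retained loss, each new item costs O(log k), and anything that cannot make the top k is discarded immediately.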

Hopefully someone with more extensive experience editing the library will step in to confirm/refute my suggestions above.


A small followup: the root cause here is using get_preds(… with_input=True …), which then returns (as the name suggests) the input objects - i.e. gigabytes of your training/validation data, all in memory. The fact that it then copies it to create a sorted version is hugely inefficient, but it doesn’t actually look like there’s any need for this to be called in this manner for more than a single object.

A quick grep of the fast.ai source suggests that with_input=True is rarely used:

grep -r "\.get_preds\s*([^)]*with_input\s*=\s*True" /opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/interpret.py:        return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None))
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py:        inp,preds,_,dec_preds = self.get_preds(dl=dl,with_input=True, with_decoded=True)
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/tabular/learner.py:        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/vision/widgets.py:    inp,probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True)
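The same search can be run from Python, which is handy where grep isn't available. This is a minimal sketch: find_with_input_calls is a hypothetical helper, and the pattern mirrors the grep above, so, like grep, it only matches calls whose arguments (up to the first closing parenthesis) sit on one line.

```python
import re
from pathlib import Path

# Same idea as the grep above: a get_preds( call whose argument list,
# up to the first ')', contains with_input=True. Matched line by line.
PATTERN = re.compile(r"\.get_preds\s*\([^)]*with_input\s*=\s*True")

def find_with_input_calls(root):
    "Yield (path, line_number, line) for every matching line under `root`"
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for n, line in enumerate(text.splitlines(), start=1):
            if PATTERN.search(line):
                yield path, n, line.strip()
```

To scan an installed copy instead of hard-coding the conda path, pass Path(fastai.__file__).parent as root and print each hit.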

And both learner.py usages are single-item predictions, so they are not memory hungry.

Therefore the best fix appears to be to modify ImageClassifierCleaner (discussed above) and Interpretation, which is used to display the confusion matrix, e.g. ClassificationInterpretation.from_learner(learn).plot_confusion_matrix().

Neither of these classes needs to keep the actual object data in memory; both could be modified to deal with just the indices of objects (and, if and when needed, use the dataloader to load the object itself). This would remove the memory issue entirely, rather than giving the 50% saving of an in-place sort. However, surgery on the code for these classes is beyond me currently; hopefully someone with more experience on the codebase can step in.


Hey @hushitz! Thank you for such a detailed analysis! I’m going to work on this over the weekend :slight_smile: and will report back with what I did.

So here was my fix for Interpretation:

preds, targs, decoded, losses = learn.get_preds(dl=dl, with_loss=True, with_decoded=True, act=None)
l, idxs = losses.topk(5, largest=True)
items = dl.dataset[idxs]
o = dl.new(items, bs=len(items))
b = o.one_batch()                      # (inputs, targets) for just the top-k items
a = b[0]
x,y,its = o._pre_show_batch(b)
b_out = (a, decoded[idxs])             # inputs paired with their decoded predictions
x1,y1,outs = o._pre_show_batch(b_out)
plot_top_losses(x,y,its,outs.itemgot(slice(1,None)),preds[idxs],l)

At least, that's the code that matters. What we do here instead is gather the indices, index into the dataset, and build a new DataLoader from this data, whose length is only our k (in this case 5). I plan on tackling the widget next @hushitz

This version winds up using only 20MB of memory in total, vs. the original, which blasted through 12GB.


Is this updated in the library, or do we have to fix it manually?

Considering I’m still building this, neither. Once I’ve opened a PR I’ll post the code here as well.


Cool - delighted to see someone with experience dive in so rapidly!

What you propose is also the kind of fix I had in mind for the ImageClassifierCleaner, where we do need to show the actual input object, but could severely restrict the number of such objects being loaded, as we only ever assess the worst 100 or even 1000 objects of each class (5 might be a bit strict). I imagine the complete solution is just to keep all the indices and load the actual objects from them on demand; then we don’t have to limit at all.

For the Interpretation class, I wonder if cutting at e.g. 5 is generally useful? I think that to calculate the confusion matrix, you’ll need to keep all the losses. It looks to me like you’re attacking the Interpretation.plot_top_losses() with your code above, which again can perfectly well only load a very limited number of objects.

I think the solution might be for the class itself, in Interpretation.from_learner(), to keep that list of losses and idxs, then in Interpretation.plot_top_losses() you can load your top 5 as you do above, and that way ClassificationInterpretation.plot_confusion_matrix() can have access to all the losses (without having to have access to the input images).

By the way, my earlier suggestion of in-place sorting for nested_reorder() might be a bad idea: I suspect nested_reorder() is implemented that way specifically so it works on the GPU, where things like in-place sorting and manual memory management would be hugely inefficient. The correct solution is definitely to keep just the idxs and drop memory use from GBs to MBs.

I’m attacking both here, actually :slight_smile: it’s just not written in the format the Interpretation class uses.

How are you getting on with this? Anything I can do to help as a fastai-framework-hacking-newbie?

I’ve been AFK for a bit due to moving, but someone else and I are adding an enhancement to the interp object, which gets us part of the way towards what I was planning. You can see the PR here:

I should be able to migrate towards this over the weekend if all goes well

Specifically, the __getitem__. That’s the main function I’ll be modifying to do our lookup into the dataset and apply things, etc.

Aaah, very nice.

I think I see a further potential issue with the Interpretation class: if it is used for a segmentation task, keeping all the predictions and targets in memory will be a similarly large problem.

This argues strongly for the approach it looks like you’re taking, of making a get_item / show_at, so that we store only the index and the resulting loss in memory. Then we would simply run get_preds() when we want to show an actual value, using a batch size of 1 (or possibly a small list, in the way Learner.predict() calls Datasets.test_set(), which can take a list of items). It’s pointless having the Interpretation class store all the inputs, predictions and targets for the whole dataset in memory if we only ever look at the top_losses anyway (but see below regarding the ClassificationInterpretation subclass).
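A minimal sketch of that idea in plain Python: the object keeps only per-item losses (indices are implicit in their positions) plus a loader callback, and materialises items only for the handful actually shown. LazyInterp and load_item are hypothetical names for illustration; in fastai the callback would have to wrap the dataloader's item and batch transforms.

```python
import heapq

class LazyInterp:
    "Keep only losses in memory; load items on demand by index"
    def __init__(self, losses, load_item):
        self.losses = list(losses)   # one float per dataset item
        self.load_item = load_item   # callable: idx -> item (hypothetical)

    def top_losses(self, k=None, largest=True):
        "Return (losses, idxs) for the k largest (or smallest) losses"
        k = len(self.losses) if k is None else k
        pick = heapq.nlargest if largest else heapq.nsmallest
        idxs = pick(k, range(len(self.losses)), key=self.losses.__getitem__)
        return [self.losses[i] for i in idxs], idxs

    def __getitem__(self, idxs):
        "Load just the requested items, plus their losses"
        if isinstance(idxs, int): idxs = [idxs]
        return [self.load_item(i) for i in idxs], [self.losses[i] for i in idxs]
```

Nothing is loaded at construction time; plotting the top 5 losses touches exactly 5 items.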

The only possible snag with this approach is that we need to assume that generation of predictions/losses is deterministic, which means not using augmentation transforms. I’m not sure if people have used it this way, but the current behavior (allowing use of augmentations) might be useful in identifying when certain augmentations cause issues; I suspect that’s a much less common use case than just seeing which non-transformed items have problems.

All of this also applies to the ClassificationInterpretation subclass - however the generation of a confusion matrix requires that we retain the predictions and targets, and that seems to be commonly used functionality. Therefore I suspect the correct thing to do is to have an option for the Interpretation class that determines if predictions and targets are retained or not (for classification cases, this is very little memory, so won’t be an issue - it’s only segmentation cases where that would be a problem).
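For classification, the data a confusion matrix needs is tiny: one true and one predicted class index per item, and no inputs at all. A pure-Python sketch to show how little is involved; confusion_matrix here is a hypothetical helper, not ClassificationInterpretation.confusion_matrix(), and it assumes integer class indices in [0, n_classes).

```python
def confusion_matrix(targs, preds, n_classes):
    "cm[t][p] counts items whose true class is t and predicted class is p"
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(targs, preds):
        cm[t][p] += 1
    return cm
```

For segmentation, by contrast, each target and prediction is a whole mask, which is why retaining them is the expensive case.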

The solution for ImageClassifierCleaner is probably a little different: it’s the _get_iw_info function that needs modification here, which returns a list of some number of (input, target, loss) tuples. Without significant surgery to the widget code, _get_iw_info could easily be updated to call get_preds() once, storing predictions, targets and losses but not the inputs. Then, since ImagesCleaner._open_thumb() appears to accept filenames instead of the actual data, we just need to keep the filename rather than the actual image in the list that _get_iw_info returns, and I think it should just work.

Not entirely sure what you mean here. ClassificationInterpretation always applies only the validation transforms, so all you ever get is a center crop plus normalization.

I’ll likely add something like a low_memory option, where we just store the losses and their indices so we can lazily grab the proper input/prediction as needed.

After I get done with the Interp stuff, I’ll try and take a look, though it sounds like you’re much more familiar with its code, so you may have a better shot at getting it done faster once I’ve finished the adjustments to Interp :slight_smile:

Aaah, in which case I’m hypothesizing a further use case that doesn’t yet exist :wink: (looking at top losses to work out whether you’re applying transforms too harshly, e.g. cropping regions that are too small)

I’m only familiar with that tiny piece of code from reading it; making changes to the framework as a whole, and potentially breaking dependencies, is something I’m wary of given my lack of experience. Happy to take a shot at the ImageClassifierCleaner though, as my suggested solution is pretty self-contained.


Small correction: this only happens if you choose to use the training DataLoader in interp (normally, and by default, we just use the validation DL). It’s a very minor use, so I’m not too concerned about it with what I’m doing.

OK, it looks like the modification for ImageClassifierCleaner is even easier than I expected, which is to say that the current implementation is even more wasteful than I expected.

Not only do we not need to keep all the input images in memory, but we don’t need to spend CPU time on dl.decode_batch either, because once the inputs are gone all it’s really doing is decoding the targets via the dataloader’s vocab, which for a classifier (and we know this is one) is just a simple index lookup.
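To illustrate why no decode_batch is needed: once the inputs are dropped, decoding a classifier’s targets is just indexing the vocab with each integer target, and the filenames can be zipped alongside. A minimal pure-Python sketch with made-up data; items, targs, losses and vocab are plain lists standing in for dl.dataset.items, the get_preds() outputs and dl.vocab, and iw_info / worst_per_label are hypothetical helpers (the worst-first ordering is an assumption about how a cleaner would consume the rows).

```python
def iw_info(items, targs, losses, vocab):
    "(filename, label, loss) rows, without keeping any decoded images"
    labels = [vocab[t] for t in targs]          # decoding == index lookup
    return list(zip(items, labels, losses))

def worst_per_label(rows, label):
    "Rows for one label, highest loss first"
    return sorted((r for r in rows if r[1] == label), key=lambda r: r[2], reverse=True)
```

Each row holds a path, a short string and a float, so the list stays small no matter how large the images are.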

Overall a 2-line change, resulting in many GB less resource use and faster execution.

It looks like it works, but I would appreciate someone more familiar with the framework than me checking the 2 changed lines of code and testing them properly. In my test, as described at the top of this thread, memory use increases by ~5MB, as opposed to ~3.3GB before the change.

In fastai/vision/widgets.py replace the _get_iw_info function with the below (actually, it’s only the 2nd and 3rd lines that change).

def _get_iw_info(learn, ds_idx=0):
    dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
    probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True)
    targs = [dl.vocab[t] for t in targs]
    return L([dl.dataset.items,targs,losses]).zip()

Oh, wow, that fix maps almost perfectly onto the Interpretation class as well, completely avoiding all the faff we discussed above.

Just open fastai/interpret.py and replace the from_learner method with the following, again a 2-line change:

def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
    "Construct interpretation object from a learner"
    if dl is None: dl = learn.dls[ds_idx]
    dl = dl.new(shuffle=False, drop_last=False)
    return cls(dl, dl.dataset.items, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

This avoids storing the inputs in memory, and (for image dataloaders) just passes around filenames rather than data.

One niggling concern: do I need to do anything when creating those new dataloaders to ensure no transforms are applied?

Ok, bug: interp.plot_top_losses expects a list of image tensors for the input, not filenames, so we need to edit that once we have extracted our k top losses.

Apologies, I neglected to post the fix to interp.plot_top_losses.

Simply replace this line:

else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))

with this line:

else: inps = (first(to_cpu(self.dl.after_batch(to_device(first(self.dl.create_batches(idx)))))),)

I suspect that line is inefficient (can it possibly need that many calls to first and to_device/to_cpu?) but it’s still WAY more efficient than keeping every training/validation item decoded in memory.

@muellerzr Can you check my two 2-line fixes above, and this one-liner, and see if you agree that these are appropriate fixes?

Just putting the actual fixes in one place to make code review easier. The commented lines are the ones being replaced.

# new Interpretation.from_learner in fastai/interpret.py
def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
    "Construct interpretation object from a learner"
    if dl is None: dl = learn.dls[ds_idx]
    #return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None))
    dl = dl.new(shuffle=False, drop_last=False)
    return cls(dl, dl.dataset.items, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

# new Interpretation.plot_top_losses in fastai/interpret.py
def plot_top_losses(self, k, largest=True, **kwargs):
    losses,idx = self.top_losses(k, largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = (first(to_cpu(self.dl.after_batch(to_device(first(self.dl.create_batches(idx)))))),)
    #else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)

# new _get_iw_info in fastai/vision/widgets.py
def _get_iw_info(learn, ds_idx=0):
    dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
    #inp,probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True)
    #inp,targs = L(zip(*dl.decode_batch((inp,targs), max_n=9999)))
    probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True)
    targs = [dl.vocab[t] for t in targs]
    return L([dl.dataset.items,targs,losses]).zip()

Thanks!!! That helped put the final pieces together. So with the addition of this article, https://walkwithfastai.com/interp, we can simplify the code quite a bit.

That article focused on applying a __getitem__ to the Interpretation object, and my real goal was to eventually give us a __getitem__ that can give us inputs without needing much ram. (Jeremy was the one suggesting a __getitem__) As a result we have the following interpretation class (with the information needed):

    def __getitem__(self, idxs):
        "Get the inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        attrs = 'preds,targs,decoded,losses'
        res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
        inps = [self.dl.do_item(o)[:self.dl.n_inp] for o in idxs]
        inps = self.dl.after_batch(to_device(self.dl.do_batch(inps), self.dl.device))
        return inps + res

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner"
        if dl is None: dl = learn.dls[ds_idx]
        return cls(dl, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

    def plot_top_losses(self, k, largest=True, **kwargs):
        losses,idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        if not isinstance(inps, tuple): inps = (inps,)
        b = inps + (targs,)
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + (decoded,)
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses,  **kwargs)

Not only does it clean up the code for plot_top_losses, but now we also don’t run into out of memory issues :slight_smile:

Full implementation below:

class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self, dl, preds, targs, decoded, losses): store_attr()

    def __getitem__(self, idxs):
        "Get the inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        attrs = 'preds,targs,decoded,losses'
        res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
        inps = [self.dl.do_item(o)[:self.dl.n_inp] for o in idxs]
        inps = self.dl.after_batch(to_device(self.dl.do_batch(inps), self.dl.device))
        return inps + res

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner"
        if dl is None: dl = learn.dls[ds_idx]
        return cls(dl, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

    def top_losses(self, k=None, largest=True):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)

    def plot_top_losses(self, k, largest=True, **kwargs):
        losses,idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        if not isinstance(inps, tuple): inps = (inps,)
        b = inps + (targs,)
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + (decoded,)
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses,  **kwargs)

And here is a gist showing this running on 4x the Pets dataset (an unreasonable size that would otherwise crash): https://gist.github.com/muellerzr/497daf4a27239cf691e140b5710f2901
