Learn.get_preds() memory inefficiency - quick fix

Small correction: this only happens if you choose the training DataLoader to use in interp (but normally/by default we just use the validation DL). It’s a very minor use case, so I'm not too concerned about it for what I’m doing.

Ok, it looks like the modification for ImageClassifierCleaner is even easier than I expected, which is to say that the current implementation is even more wasteful than I expected.

Not only do we not need to keep all the input images in memory, we also don’t need to spend CPU time on dl.decode_batch, because all that’s really doing is decoding via the dataloader's vocab, and for a classifier (which we know this is) that’s just a simple index lookup.
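To illustrate that claim with a toy sketch (no fastai needed; the plain list below stands in for a real dataloader's vocab/CategoryMap), "decoding" a classifier target is just indexing:

```python
# Toy illustration: for a classifier, decoding a target is an index lookup.
# 'vocab' here is a stand-in for dl.vocab on a real DataLoader.
vocab = ['cat', 'dog', 'horse']

# Encoded targets as they come back from get_preds (class indices).
targs = [2, 0, 1, 1]

# Recovers the same labels dl.decode_batch would, without ever
# touching the input images:
decoded = [vocab[t] for t in targs]
print(decoded)  # → ['horse', 'cat', 'dog', 'dog']
```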

Overall it's a 2-line change, resulting in several GB less memory use and faster execution.

It looks like it works, but I'd appreciate someone more familiar with the framework than me checking the 2 lines of code changed, and testing it properly. Measuring memory use before/after in the test described in the first post, I see an increase of ~5MB, as opposed to ~3.3GB before the change.
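For anyone wanting to reproduce that before/after comparison, here is one rough way to read peak memory from the standard library (a sketch only; the `resource` module is Unix-only, and `ru_maxrss` is reported in kilobytes on Linux but bytes on macOS):

```python
import resource

def peak_rss():
    """Peak resident set size of this process.
    Units are platform-dependent: KB on Linux, bytes on macOS."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

before = peak_rss()
# ... run learn.get_preds(...) or ImageClassifierCleaner(learn) here ...
after = peak_rss()
print(f"peak RSS grew by roughly {after - before} (platform-dependent units)")
```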

In fastai/vision/widgets.py, replace the _get_iw_info function with the version below (actually, only the 2nd and 3rd lines change).

def _get_iw_info(learn, ds_idx=0):
    dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
    probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True)
    targs = [dl.vocab[t] for t in targs]
    return L([dl.dataset.items,targs,losses]).zip()

Oh, wow, that fix maps almost perfectly onto the Interpretation class as well, completely avoiding all the faff we discussed above.

Just open fastai/interpret.py and replace the from_learner method with the following, again a 2-line change:

def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
    "Construct interpretation object from a learner"
    if dl is None: dl = learn.dls[ds_idx]
    dl = dl.new(shuffle=False, drop_last=False)
    return cls(dl, dl.dataset.items, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

This avoids storing the inputs in memory, and (for image dataloaders) just passes around filenames rather than data.

One niggling concern : do I need to do anything when creating those new dataloaders to ensure no transforms are applied?

Ok, bug: interp.plot_top_losses expects a list of image tensors for the input, not filenames, so we need to edit that method once we have extracted our k top losses.

Apologies, I neglected to post the fix to interp.plot_top_losses.

Simply replace this line:

else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))

with this line:

else: inps = (first(to_cpu(self.dl.after_batch(to_device(first(self.dl.create_batches(idx)))))),)

I suspect that line is inefficient (can it possibly need that many calls to first and to_device/to_cpu?) but it’s still WAY more efficient than keeping every training/validation item decoded in memory.
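To see why all those `first` calls are there, here is a pure-Python mock of the pipeline (everything below is a stand-in: `first` mimics fastcore's helper, and `create_batches`/`after_batch` are fakes standing in for the dataloader methods). The outer `first` takes the single batch that `create_batches(idx)` yields, `after_batch` applies the batch transforms, and the inner `first` pulls the inputs out of the (inputs, targets) batch tuple:

```python
def first(x):
    # Mimics fastcore's first(): the first element of any iterable.
    return next(iter(x))

# Stand-ins for the dataloader pieces used in the one-liner:
def create_batches(idx):
    # Yields batches; each batch is a tuple of (inputs, targets).
    yield (['img_%d' % i for i in idx], [i % 2 for i in idx])

def after_batch(b):
    # Batch transforms (e.g. normalization); identity in this mock.
    return b

idx = [3, 1, 4]
batch = first(create_batches(idx))   # the one batch for our k indices
batch = after_batch(batch)           # batch tfms (between to_device/to_cpu in fastai)
inps = (first(batch),)               # first element of the batch = the inputs
print(inps)  # → (['img_3', 'img_1', 'img_4'],)
```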

@muellerzr Can you check my two 2-line fixes above, and this one-liner, and see if you agree that these are appropriate fixes?

Putting all the actual fixes in one place to make code review easier. The commented lines are the ones being replaced.

# new Interpretation.from_learner in fastai/interpret.py
def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
    "Construct interpretation object from a learner"
    if dl is None: dl = learn.dls[ds_idx]
    #return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None))
    dl = dl.new(shuffle=False, drop_last=False)
    return cls(dl, dl.dataset.items, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

# new Interpretation.plot_top_losses in fastai/interpret.py
def plot_top_losses(self, k, largest=True, **kwargs):
    losses,idx = self.top_losses(k, largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = (first(to_cpu(self.dl.after_batch(to_device(first(self.dl.create_batches(idx)))))),)
    #else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)

# new _get_iw_info in fastai/vision/widgets.py
def _get_iw_info(learn, ds_idx=0):
    dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
    #inp,probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True)
    #inp,targs = L(zip(*dl.decode_batch((inp,targs), max_n=9999)))
    probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True)
    targs = [dl.vocab[t] for t in targs]
    return L([dl.dataset.items,targs,losses]).zip()

Thanks!!! That helped put the final pieces together. So with the addition of this article, https://walkwithfastai.com/interp, we can simplify the code quite a bit.

That article focused on applying a __getitem__ to the Interpretation object, and my real goal was to eventually give us a __getitem__ that can return inputs without needing much RAM. (Jeremy was the one who suggested a __getitem__.) As a result we have the following Interpretation class (with the information needed):

    def __getitem__(self, idxs):
        "Get the inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        attrs = 'preds,targs,decoded,losses'
        res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
        inps = [self.dl.do_item(o)[:self.dl.n_inp] for o in idxs]
        inps = self.dl.after_batch(to_device(self.dl.do_batch(inps), self.dl.device))
        return inps + res

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner"
        if dl is None: dl = learn.dls[ds_idx]
        return cls(dl, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

    def plot_top_losses(self, k, largest=True, **kwargs):
        losses,idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        if not isinstance(inps, tuple): inps = (inps,)
        b = inps + (targs,)
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + (decoded,)
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses,  **kwargs)

Not only does it clean up the code for plot_top_losses, but now we also don’t run into out of memory issues :slight_smile:

Full implementation below:

class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self, dl, preds, targs, decoded, losses): store_attr()

    def __getitem__(self, idxs):
        "Get the inputs, preds, targs, decoded outputs, and losses at `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        attrs = 'preds,targs,decoded,losses'
        res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
        inps = [self.dl.do_item(o)[:self.dl.n_inp] for o in idxs]
        inps = self.dl.after_batch(to_device(self.dl.do_batch(inps), self.dl.device))
        return inps + res

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner"
        if dl is None: dl = learn.dls[ds_idx]
        return cls(dl, *learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True, act=None))

    def top_losses(self, k=None, largest=True):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)

    def plot_top_losses(self, k, largest=True, **kwargs):
        losses,idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        if not isinstance(inps, tuple): inps = (inps,)
        b = inps + (targs,)
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + (decoded,)
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses,  **kwargs)

And also a gist showing this running on 4x of PETs (something unreasonable that would crash): https://gist.github.com/muellerzr/497daf4a27239cf691e140b5710f2901
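As an aside, the `top_losses` method above is just `Tensor.topk` over the stored losses; a plain-Python equivalent (a sketch only, no torch needed) looks like:

```python
def top_losses(losses, k=None, largest=True):
    # Mirrors Interpretation.top_losses without torch: the k largest
    # (or smallest) losses together with their original indices.
    k = len(losses) if k is None else k
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)[:k]
    return [losses[i] for i in order], order

vals, idx = top_losses([0.2, 1.5, 0.7, 0.1], k=2)
print(vals, idx)  # → [1.5, 0.7] [1, 2]
```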


Much nicer than my hacky fix - I was trying for minimal changes as I don’t yet grok the framework.

However I’m pretty sure my fix to fastai/vision/widgets.py is safe - can you confirm?

Happy for you to put it all into one PR at once if you’re happy with it.

Ideally we’d put this in 2 PRs: one to the Interp notebook and .py, and one to the widgets notebook and .py. Jeremy likes incremental PRs :slight_smile: Haven’t checked on the ImageCleaner yet to see if that can be adjusted at all; should be able to review that in the next day or two

Also, I’ve only tested this on images. We need to test it on all supported applications to make sure it works in all scenarios

This does break on tabular, so we need to rethink a few things

You mean the new Interpretation class you’ve built, or the small-fixes approach I tried?

I’ve only ever done the video walk-throughs on Tabular, so can’t possibly claim any expertise there - would not be at all surprised if my fixes broke :wink:

The one I built. It should do something similar to what you did, so I’ll need to sit down and figure out why it’s not quite working

Ok - if you have a simple test-case that breaks yours, send it over and I’ll see if it breaks on mine as well.

For the testing I’m using this notebook: https://github.com/fastai/fastai/blob/master/nbs/examples/app_examples.ipynb

And simply doing:

dls = dls
learn = application_learner()
interp = Interpretation.from_learner(learn)
interp[0] # or interp.plot_top_losses(1)

Ok, my fixes above seem to work for everything for which an Interpretation class is implemented (e.g. plot_top_losses is not implemented for segmentation or image regression). Tabular definitely works though.


I should confirm that for testing Tabular and Text, I used
interp = ClassificationInterpretation.from_learner(learn)
interp.most_confused()

@muellerzr How is your improved Interpretation class looking?
I made a PR for ImageClassifierCleaner as you suggested, to keep them separate; it's a fix that shouldn’t have any deeper consequences. I’ll leave the Interpretation class fixes to you, as you’re clearly doing more major surgery.


Hello @muellerzr.
The memory inefficiency still exists in the fastai version released three days ago (2.5.0). When I try to predict on a dataset with 500K samples using a tabular model, the 128GB of system RAM fills up after 74% of the prediction and the computer freezes. The prediction works with 200K samples. I noticed that the Interpretation class in the current version of fastai is not using the __getitem__ strategy proposed here…
When I copy the Interpretation class proposed here into the library and run interp = ClassificationInterpretation.from_learner(learn), I get the following error:
“TypeError: __init__() missing 1 required positional argument: ‘losses’”
Hope you can help address this issue…
BR

Yes, it is not in fastai, as the method proposed does not work in every situation.

Regarding your error, please post the versions of fastai and fastcore you tried with, as I was not able to recreate it with:

from fastai.vision.all import *

set_seed(99, True)
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))

learn = cnn_learner(dls, resnet34, metrics=error_rate)
interp = ClassificationInterpretation.from_learner(learn)

Appreciate your quick feedback. Here are the versions used:
fastai: 2.5.0
fastcore: 1.3.25
BR