Hey @hushitz! Thank you for such a detailed analysis! I'm going to work on this over the weekend and will report back with what I did.
So here was my fix for `Interpretation`:
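```python
from fastai.vision.all import *  # provides L and plot_top_losses; swap for your task's .all module

# `learn` is the trained Learner and `dl` the DataLoader being interpreted
preds, targs, decoded, losses = learn.get_preds(dl=dl, with_loss=True, with_decoded=True, act=None)
l, idxs = losses.topk(5, largest=True)   # the 5 largest losses and their indices
items = dl.dataset[idxs]                 # pull only those samples out of the dataset
o = dl.new(items, bs=len(items))         # tiny DataLoader holding just the top-k items
b = o.one_batch()                        # fetch the batch once and reuse it
a, _ = b
x,y,its = o._pre_show_batch(b)
b_out = L(a) + L(decoded[idxs])          # pair each input with its own decoded prediction
x1,y1,outs = o._pre_show_batch(b_out)
plot_top_losses(x1,y1,its,outs.itemgot(slice(1,None)),preds[idxs],l)
```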
At least, that's the part of the code that matters. What we do here instead is gather the indices of the top losses, index into the dataset with them, and build a new DataLoader from just that data, whose length is only k (5 in this case). I plan on tackling the widget next, @hushitz.
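Stripped of fastai, the idea is: rank the per-item losses, keep the k worst indices, and only then materialize those k items. Below is a framework-free sketch in plain Python. `LazyDataset`, `top_k_losses` and `top_k_batch` are hypothetical names for illustration, not fastai APIs, and `heapq.nlargest` stands in for `losses.topk(k, largest=True)`.

```python
import heapq
from typing import Callable, List, Sequence, Tuple

def top_k_losses(losses: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest losses and their indices, largest first."""
    pairs = heapq.nlargest(k, enumerate(losses), key=lambda p: p[1])
    return [loss for _, loss in pairs], [i for i, _ in pairs]

class LazyDataset:
    """Hypothetical dataset that only builds an item when it is indexed."""

    def __init__(self, n: int, load: Callable[[int], list]):
        self.n, self.load, self.loads = n, load, 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, i: int) -> list:
        self.loads += 1  # count how many items actually get materialized
        return self.load(i)

def top_k_batch(ds: LazyDataset, losses: Sequence[float], k: int):
    """Collect only the k worst items into one small batch of size k."""
    vals, idxs = top_k_losses(losses, k)
    batch = [ds[i] for i in idxs]  # plays the role of dl.dataset[idxs] + dl.new(items, bs=k)
    return vals, idxs, batch

if __name__ == "__main__":
    losses = [0.1, 2.5, 0.3, 4.0, 1.2, 0.05, 3.3]
    ds = LazyDataset(len(losses), load=lambda i: [float(i)] * 3)
    vals, idxs, batch = top_k_batch(ds, losses, k=3)
    print(idxs)      # [3, 6, 1]
    print(ds.loads)  # 3, not 7
```

The batch keeps the same order as `idxs`, so item `j` of the batch lines up with loss `vals[j]`; any per-item predictions shown next to it have to be indexed with the same `idxs` to stay paired.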
This version winds up using only 20 MB of memory in total, vs. the other version, which blasted through 12 GB.
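The saving comes from how many items get built: picking the k indices first keeps memory proportional to k instead of to the dataset size. A rough way to see that effect in plain Python is `tracemalloc`. The sizes and the two `decode_*` functions below are made-up stand-ins, not the fastai code paths, so the byte counts they print are not comparable to the 20 MB / 12 GB figures above.

```python
import tracemalloc
from typing import Callable, List

N, K, ITEM_SIZE = 10_000, 5, 100            # made-up dataset size, k, and per-item size
LOSSES = [float(i % 97) for i in range(N)]   # hypothetical per-item losses

def top_k_indices() -> List[int]:
    # indices of the K largest losses, largest first
    return sorted(range(N), key=LOSSES.__getitem__, reverse=True)[:K]

def decode_everything() -> List[List[float]]:
    # build every item up front, then keep only the worst K
    items = [[float(i)] * ITEM_SIZE for i in range(N)]
    return [items[i] for i in top_k_indices()]

def decode_top_k_only() -> List[List[float]]:
    # pick the worst K indices first, then build only those items
    return [[float(i)] * ITEM_SIZE for i in top_k_indices()]

def peak_bytes(fn: Callable[[], object]) -> int:
    """Peak memory traced by tracemalloc while fn() runs, in bytes."""
    tracemalloc.start()
    try:
        fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak

if __name__ == "__main__":
    print(f"build all, then pick: {peak_bytes(decode_everything):,} bytes")
    print(f"pick, then build k:   {peak_bytes(decode_top_k_only):,} bytes")
```

Both functions return the same K items; only the peak allocation differs.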