I often need to analyse the largest losses of a classifier, but only for a specific type of error, i.e., one specific cell of the confusion matrix. Is there an easy way to do this? For now, I extend the `ClassificationInterpretation` class and override the methods `top_losses` and `plot_top_losses` with two extra arguments, `predicted` and `actual`, to select a specific type of error.
# export
class ClassificationInterpretationExtended(ClassificationInterpretation):
    """`ClassificationInterpretation` whose loss ranking can be restricted to one
    cell of the confusion matrix.

    `top_losses` and `plot_top_losses` take two optional extra arguments,
    `predicted` and `actual`: labels looked up in `self.vocab.o2i`. When given,
    only samples whose decoded prediction equals `predicted` and/or whose target
    equals `actual` are ranked.
    """

    def _error_mask(self, predicted=None, actual=None):
        "Boolean mask over samples matching the requested `predicted`/`actual` labels."
        mask = torch.ones_like(self.losses, dtype=torch.bool)
        # Test with `is not None` so falsy labels (e.g. 0 or "") still filter.
        if predicted is not None:
            mask = mask & (self.decoded == self.vocab.o2i[predicted])
        if actual is not None:
            mask = mask & (self.targs == self.vocab.o2i[actual])
        return mask

    def top_losses(self, k=None, largest=True, predicted=None, actual=None):
        """`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`).

        If `predicted` and/or `actual` are given, only samples in that cell of the
        confusion matrix are ranked. In that case `k` is clamped to the number of
        matching samples. The returned indices always refer to positions in
        `self.losses`.
        """
        if predicted is None and actual is None:
            # Default behaviour: rank every sample.
            return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
        # `nonzero(as_tuple=True)[0]` is always 1-D. A plain `.squeeze()` would
        # collapse a single match into a 0-d tensor, on which `len()` fails.
        idxs = self._error_mask(predicted, actual).nonzero(as_tuple=True)[0]
        # `topk` raises if asked for more elements than the subset contains.
        n = len(idxs) if k is None else min(k, len(idxs))
        sub = self.losses[idxs].topk(n, largest=largest)
        # `sub.indices` are positions within `idxs`; map them back to absolute
        # indices and keep the same named-tuple type as the default branch.
        return type(sub)((sub.values, idxs[sub.indices]))

    def plot_top_losses(self, k, largest=True, predicted=None,
                        actual=None, **kwargs):
        "Plot the `k` largest(/smallest) losses, optionally restricted to one `predicted`/`actual` cell."
        losses, idx = self.top_losses(k, largest, predicted, actual)
        if len(idx) == 0:
            # Nothing matched; building a batch from an empty selection would fail below.
            import warnings
            warnings.warn(f"No samples found with predicted={predicted!r} and actual={actual!r}.")
            return
        # Normalise `self.inputs` to a tuple (stored back on `self`).
        if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
        if isinstance(self.inputs[0], torch.Tensor): inps = tuple(o[idx] for o in self.inputs)
        # Non-tensor inputs are re-collated through the DataLoader's `before_batch`/`create_batch`.
        else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            # Resolves to the module-level `plot_top_losses` function, not this method.
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses, **kwargs)