Top_losses for a specific type of error (classification)

I often need to analyse the largest losses of a classifier, but only for a specific type of error, i.e., one specific cell of the confusion matrix. Is there an easy way to do this? For now, what I am doing is extending the ClassificationInterpretation class and overriding the methods top_losses and plot_top_losses with two extra arguments, predicted and actual, to select a specific type of error.

# export
# Assumes fastai v2; the wildcard import should provide ClassificationInterpretation,
# ifnone, tensor, is_listy, torch and the plot_top_losses helper used below.
from fastai.vision.all import *

class ClassificationInterpretationExtended(ClassificationInterpretation):
    def top_losses(self, k=None, largest=True, predicted=None, actual=None):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        if predicted is None and actual is None:
            # Default behaviour
            return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
        else:
            # Subset losses by the conditions given in predicted and actual arguments
            cond_preds = (self.decoded == self.vocab.o2i[predicted]) if predicted is not None else tensor(True)
            cond_actuals = (self.targs == self.vocab.o2i[actual]) if actual is not None else tensor(True)
            # squeeze(-1) keeps a 1-d index tensor even when only one sample matches
            idxs = (cond_preds & cond_actuals).nonzero().squeeze(-1)
            loss_subset = self.losses[idxs].topk(ifnone(k, len(idxs)), largest=largest)
            # The indices in loss_subset are relative to `idxs`. We have to
            # return the absolute indices with respect to self.losses.
            # TODO: It's returning a pair instead of a topk object
            return (loss_subset.values, idxs[loss_subset.indices])
        
    
    def plot_top_losses(self, k, largest=True, predicted=None, 
                        actual=None, **kwargs):
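        # The body below is essentially fastai's own ClassificationInterpretation.plot_top_losses;
        # the only change is the filtered top_losses call on the next line.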
        losses,idx = self.top_losses(k, largest, predicted, actual)
        if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
        if isinstance(self.inputs[0], torch.Tensor): inps = tuple(o[idx] for o in self.inputs)
        else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
        x,y,its = self.dl._pre_show_batch(b, max_n=k)
        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)
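
For context, this is a minimal sketch of how I call it; learn and the class names 'cat' / 'dog' are just placeholders for a trained learner and two labels from its vocab:

interp = ClassificationInterpretationExtended.from_learner(learn)
# Nine largest losses among samples predicted as 'cat' whose true label is 'dog'
interp.plot_top_losses(9, predicted='cat', actual='dog')
# With neither argument set, it falls back to the default behaviour (overall top losses)
interp.plot_top_losses(9)

Even though top_losses returns a plain (values, indices) pair rather than a topk object (see the TODO above), plot_top_losses only unpacks it into losses and idx, so the plotting still works.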