Thank you for sharing! I modified it a bit, like this:
class AUC(Metric):
    """One-vs-rest ROC AUC for one class in single-label multi-class classification.

    Each call to `accumulate` stores the batch's predictions and targets. `value`
    concatenates them, applies a softmax over dim 1, and takes the probability
    column of `main_class`. It scores that column with `skm.roc_auc_score`
    against a binary target: 1 where the true label equals `main_class`, else 0.

    Args:
        main_class: index of the class treated as the positive class.
        classes: indexable collection of class names (e.g. a vocab), used only to
            build the metric's display name. If left at the default (`noop`) or
            set to None, the numeric class index is used in the name instead.
    """

    def __init__(self, main_class=0, classes=noop):
        super().__init__()
        self.main_class = main_class
        self.classes = classes

    def reset(self):
        # Discard everything accumulated so far.
        self.targs, self.preds = [], []

    def accumulate(self, learn):
        # Keep detached copies so no autograd graph is retained between batches.
        # NOTE(review): learn.pred is assumed to be unnormalised scores of shape
        # (batch, n_classes); `value` applies the softmax itself.
        self.preds.append(to_detach(learn.pred))
        self.targs.append(to_detach(learn.y))

    @property
    def value(self):
        """Return the ROC AUC of `main_class`.

        Returns None if nothing has been accumulated, and NaN if the accumulated
        targets contain only positives or only negatives (AUC is undefined
        there, and sklearn would raise ValueError).
        """
        if not self.preds:
            return None
        # Cast to float32 so half-precision outputs (e.g. from to_fp16) can be softmaxed.
        preds = torch.cat(self.preds).float()
        targs = torch.cat(self.targs)
        # Binary one-vs-rest target, built on the same device as `targs`.
        binary_targs = (targs == self.main_class).long()
        n_pos = int(binary_targs.sum())
        if n_pos == 0 or n_pos == binary_targs.numel():
            return float('nan')
        probs = F.softmax(preds, dim=1)[:, self.main_class]
        return skm.roc_auc_score(binary_targs.cpu().numpy(), probs.cpu().numpy())

    @property
    def name(self):
        # Without class names, fall back to the numeric index instead of
        # indexing into the `noop` function (which raised TypeError).
        if self.classes is noop or self.classes is None:
            return f'{self.main_class} AUC'
        return f'{self.classes[self.main_class]} AUC'
and used it like this:
# Accuracy, followed by one one-vs-rest AUC metric per class index in `databunch`.
metrics = [accuracy, *(AUC(c, databunch.vocab) for c in range(databunch.c))]


def get_learner2():
    """Create an xresnet50 `cnn_learner` on `databunch` and return its fp16 version."""
    return cnn_learner(
        databunch,
        xresnet50,
        opt_func=opt_func,
        metrics=metrics,
    ).to_fp16()