Hi @sgugger,
@FraPochetti and I have been working together this morning to review the proba-based metrics issue in fastai2 (RocAuc and APScore), and have jointly come up with a proposal we’d like to submit to you.
It manages all possibilities sklearn allows while keeping the API consistent with the rest of fastai2 metrics.
We have tested our proposal vs sklearn’s API using this gist and everything works well.
In sklearn there are 3 scenarios for roc_auc_score (each of them calculated slightly differently):
- Binary:
  - targets: shape = (n_samples,)
  - preds: pass through softmax and then take [:, -1], shape = (n_samples,)
- Multiclass:
  - targets: shape = (n_samples,)
  - preds: pass through softmax, shape = (n_samples, n_classes)
  - multi_class = 'ovr' or 'ovo' (1)
- Multilabel:
  - targets: shape = (n_samples, n_classes)
  - preds: pass through sigmoid, shape = (n_samples, n_classes)
(1) 'ovr': average AUC of each class against the rest. 'ovo': average AUC of all possible pairwise combinations of classes.
sklearn’s average_precision_score implementation is restricted to binary or multilabel classification tasks. So it cannot be used in multiclass cases.
Here’s our proposal:
class AccumMetric(Metric):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, func, dim_argmax=None, sigmoid=False, softmax=False, proba=False, thresh=None, to_np=False, invert_arg=False,
                 flatten=True, **kwargs):
        store_attr(self,'func,dim_argmax,sigmoid,softmax,proba,thresh,flatten')
        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs

    def reset(self):
        # Drop anything accumulated from a previous validation pass.
        self.targs,self.preds = [],[]

    def accumulate(self, learn):
        # Proba-based metrics need the raw scores, so skip argmax when `proba` is set.
        pred = learn.pred.argmax(dim=self.dim_argmax) if (self.dim_argmax and not self.proba) else learn.pred
        if self.sigmoid: pred = torch.sigmoid(pred)
        # FIX: test against None, not truthiness — a threshold of 0.0 is a valid value
        # and the original `if self.thresh:` silently skipped binarization for it.
        if self.thresh is not None: pred = (pred >= self.thresh)
        if self.softmax:
            pred = F.softmax(pred, dim=-1)
            # Binary case: sklearn expects only the positive-class probability.
            if learn.dls.c == 2: pred = pred[:, -1]
        targ = learn.y
        # Accumulate on CPU to avoid holding GPU memory across the whole epoch.
        pred,targ = to_detach(pred),to_detach(targ)
        if self.flatten: pred,targ = flatten_check(pred,targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        "Final metric value: `func` applied to all accumulated predictions/targets."
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        if self.to_np: preds,targs = preds.numpy(),targs.numpy()
        # sklearn metrics take (targs, preds); fastai convention is (preds, targs).
        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)

    @property
    def name(self):
        # Unwrap functools.partial so the metric shows the underlying function's name.
        return self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__
def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, sigmoid=None, softmax=False, proba=False, **kwargs):
    "Convert `func` from sklearn.metrics to a fastai metric"
    # Plain classification metrics (no threshold) take the argmax over `axis`.
    use_argmax = is_class and thresh is None
    if sigmoid is None:
        # Default policy: binarize through a sigmoid whenever a threshold is supplied.
        sigmoid = is_class and thresh is not None
    return AccumMetric(func, dim_argmax=axis if use_argmax else None, sigmoid=sigmoid,
                       softmax=softmax, proba=proba, thresh=thresh,
                       to_np=True, invert_arg=True, **kwargs)
def APScore(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for binary single-label classification problems"
    # sklearn's AP wants probabilities: softmax the raw scores (positive class kept by AccumMetric).
    skm_kwargs = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False,
                         softmax=True, proba=True, **skm_kwargs)
def APScoreMulti(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for multi-label classification problems"
    # Multi-label: each class is an independent binary problem, so use a per-class sigmoid.
    skm_kwargs = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False,
                         sigmoid=True, proba=True, **skm_kwargs)
def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='raise', labels=None):
    """Area Under the Receiver Operating Characteristic Curve for single-label classification problems.

    Use the default `multi_class` ('raise') for binary tasks; for multi-class tasks pass
    'ovr' (average AUC of each class against the rest) or 'ovo' (average AUC over all
    possible pairwise combinations of classes).
    """
    # FIX: the original had a second triple-quoted string after the one-line docstring;
    # that literal is a dead expression statement (only the first string becomes __doc__),
    # so the multi_class guidance was invisible to help()/doc tools. Merged into one docstring.
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr,
                         multi_class=multi_class, labels=labels)
def RocAucMulti(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for multi-label classification problems"
    # Multi-label: per-class sigmoid probabilities, targets shaped (n_samples, n_classes).
    skm_kwargs = dict(average=average, sample_weight=sample_weight, max_fpr=max_fpr)
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False,
                         sigmoid=True, proba=True, **skm_kwargs)
Please, let us know if we can help you in any way with this.