Discrepancy with proba-based metrics between fastai2 and sklearn

Woops, fixed now.

Thanks Sylvain for being so fast!! :smiley:

I’ve tested it using the same gist again, but I get the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-4-d2f47d35092f> in <module>()
      1 learn = cnn_learner(dls, resnet18, pretrained=False, metrics=[accuracy, APScore(), RocAuc()])
----> 2 learn.fit_one_cycle(5, 0.1)

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    199                         self.epoch=epoch;          self('begin_epoch')
    200                         self._do_epoch_train()
--> 201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')
    203                     finally:                       self('after_epoch')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    181         try:
    182             self.dl = dl;                                    self('begin_validate')
--> 183             with torch.no_grad(): self.all_batches()
    184         except CancelValidException:                         self('after_cancel_validate')
    185         finally:                                             self('after_validate')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    154 
    155     def one_batch(self, i, b):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in one_batch(self, i, b)
    165             self.opt.zero_grad()
    166         except CancelBatchException:                         self('after_cancel_batch')
--> 167         finally:                                             self('after_batch')
    168 
    169     def _do_begin_fit(self, n_epoch):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    133 
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    374              else f.format if isinstance(f,str)
    375              else f.__getitem__)
--> 376         return self._new(map(g, self))
    377 
    378     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]
    256 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    221 
    222 # Cell

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/callback/core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     26 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in after_batch(self)
    421         if len(self.yb) == 0: return
    422         mets = self._train_mets if self.training else self._valid_mets
--> 423         for met in mets: met.accumulate(self.learn)
    424         if not self.training: return
    425         self.lrs.append(self.opt.hypers[-1]['lr'])

/usr/local/lib/python3.6/dist-packages/fastai2/metrics.py in accumulate(self, learn)
     33         targ = learn.y
     34         pred,targ = to_detach(pred),to_detach(targ)
---> 35         if self.flatten: pred,targ = flatten_check(pred,targ)
     36         self.preds.append(pred)
     37         self.targs.append(targ)

/usr/local/lib/python3.6/dist-packages/fastai2/torch_core.py in flatten_check(inp, targ)
    778     "Check that `out` and `targ` have the same number of elements and flatten them."
    779     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 780     test_eq(len(inp), len(targ))
    781     return inp,targ

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test_eq(a, b)
     30 def test_eq(a,b):
     31     "`test` that `a==b`"
---> 32     test(a,b,equals, '==')
     33 
     34 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test(a, b, cmp, cname)
     20     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     21     if cname is None: cname=cmp.__name__
---> 22     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     23 
     24 # Cell

AssertionError: ==:
128
64

Ah yes, I forgot to pass along flatten=False (since preds and targets now have different shapes, flattening them triggers that error).
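
The numbers in the assertion line up with that: flatten_check views both tensors as 1-D and compares their lengths, so a (64, 2) preds batch against 64 targets compares 128 vs 64. A minimal sketch of that failure (using flatten_check from fastai2.torch_core, as seen in the traceback):

import torch
from fastai2.torch_core import flatten_check

pred = torch.randn(64, 2)           # per-class activations, shape (bs, 2)
targ = torch.randint(0, 2, (64,))   # binary targets, shape (bs,)
flatten_check(pred, targ)           # AssertionError: ==: 128 vs 64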


Now that this is solved, does it work properly?

Thanks again!
I’ve just tested it again installing from master (same nb as before - binary dataset - URLs.MNIST_TINY), but still get an error :slightly_frowning_face::

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-d2f47d35092f> in <module>()
      1 learn = cnn_learner(dls, resnet18, pretrained=False, metrics=[accuracy, APScore(), RocAuc()])
----> 2 learn.fit_one_cycle(5, 0.1)

29 frames
/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    199                         self.epoch=epoch;          self('begin_epoch')
    200                         self._do_epoch_train()
--> 201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')
    203                     finally:                       self('after_epoch')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    183             with torch.no_grad(): self.all_batches()
    184         except CancelValidException:                         self('after_cancel_validate')
--> 185         finally:                                             self('after_validate')
    186 
    187     @log_args(but='cbs')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    133 
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    374              else f.format if isinstance(f,str)
    375              else f.__getitem__)
--> 376         return self._new(map(g, self))
    377 
    378     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]
    256 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    221 
    222 # Cell

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/callback/core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     26 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in after_validate(self)
    436     def begin_validate(self): self._valid_mets.map(Self.reset())
    437     def after_train   (self): self.log += self._train_mets.map(_maybe_item)
--> 438     def after_validate(self): self.log += self._valid_mets.map(_maybe_item)
    439     def after_cancel_train(self):    self.cancel_train = True
    440     def after_cancel_validate(self): self.cancel_valid = True

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    374              else f.format if isinstance(f,str)
    375              else f.__getitem__)
--> 376         return self._new(map(g, self))
    377 
    378     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]
    256 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    221 
    222 # Cell

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _maybe_item(t)
    392 
    393 def _maybe_item(t):
--> 394     t = t.value
    395     return t.item() if isinstance(t, Tensor) and t.numel()==1 else t
    396 

/usr/local/lib/python3.6/dist-packages/fastai2/metrics.py in value(self)
     42         preds,targs = torch.cat(self.preds),torch.cat(self.targs)
     43         if self.to_np: preds,targs = preds.numpy(),targs.numpy()
---> 44         return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)
     45 
     46     @property

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in average_precision_score(y_true, y_score, average, pos_label, sample_weight)
    213                                 pos_label=pos_label)
    214     return _average_binary_score(average_precision, y_true, y_score,
--> 215                                  average, sample_weight=sample_weight)
    216 
    217 

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_base.py in _average_binary_score(binary_metric, y_true, y_score, average, sample_weight)
     75 
     76     if y_type == "binary":
---> 77         return binary_metric(y_true, y_score, sample_weight=sample_weight)
     78 
     79     check_consistent_length(y_true, y_score, sample_weight)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_uninterpolated_average_precision(y_true, y_score, pos_label, sample_weight)
    194             y_true, y_score, pos_label=1, sample_weight=None):
    195         precision, recall, _ = precision_recall_curve(
--> 196             y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
    197         # Return the step function integral
    198         # The following works because the last entry of precision is

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in precision_recall_curve(y_true, probas_pred, pos_label, sample_weight)
    671     fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred,
    672                                              pos_label=pos_label,
--> 673                                              sample_weight=sample_weight)
    674 
    675     precision = tps / (tps + fps)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    538     check_consistent_length(y_true, y_score, sample_weight)
    539     y_true = column_or_1d(y_true)
--> 540     y_score = column_or_1d(y_score)
    541     assert_all_finite(y_true)
    542     assert_all_finite(y_score)

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in column_or_1d(y, warn)
    795         return np.ravel(y)
    796 
--> 797     raise ValueError("bad input shape {0}".format(shape))
    798 
    799 

ValueError: bad input shape (699, 2)

So if scikit-learn is not happy with the shape, it means it doesn’t really want the predictions. I don’t really understand what it wants then.

I’m waaaaay out of my comfort zone here and this is just pure speculation, but because ROC curves only work with binary labels, I think sklearn wants only the probs of the item being a 1.
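
If so, that matches the traceback: average_precision_score funnels the scores through column_or_1d, which rejects anything that isn’t a single column. A toy sketch (made-up arrays standing in for the MNIST_TINY preds):

import numpy as np
import sklearn.metrics as skm

y_true = np.array([0, 1, 1, 0])
probas = np.array([[0.8, 0.2], [0.3, 0.7],
                   [0.4, 0.6], [0.6, 0.4]])        # softmax output, shape (n_samples, 2)

# skm.average_precision_score(y_true, probas)      # ValueError: bad input shape (4, 2)
skm.average_precision_score(y_true, probas[:, 1])  # OK: proba of the positive class only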

Ok. Let me try to clarify what’s needed for a binary classification problem.

I’ve trained the learner without RocAuc or APScore.

To get the metrics calculated manually I need this:

import sklearn.metrics as skm

valid_probas, valid_targets, valid_preds = learn.get_preds(dl=dls.valid, with_decoded=True)
skm.average_precision_score(valid_targets, valid_probas[:, 1])
skm.roc_auc_score(valid_targets, valid_probas[:, 1])

So it needs the proba for the positive label.

This returns the correct values. What we were getting before was:

skm.average_precision_score(valid_targets, valid_preds)
skm.roc_auc_score(valid_targets, valid_preds)

That was passing the decoded (argmax) predictions instead of the probabilities, which loses the ranking information these metrics need. The requirement comes from the sklearn roc_auc_score documentation:

Target scores. In the binary and multilabel cases, these can be either probability estimates or non-thresholded decision values (as returned by decision_function on some classifiers). In the multiclass case, these must be probability estimates which sum to 1. The binary case expects a shape (n_samples,), and the scores must be the scores of the class with the greater label. The multiclass and multilabel cases expect a shape (n_samples, n_classes).

But what do you do when there are more than two labels, then? This is what I don’t get.

I’ve only used it in binary classification.
@FraPochetti worked on a multi-class one. He may be able to provide a better reply.
My understanding is that for multi-class and multi-label it requires a shape (n_samples, n_classes), plus passing multi_class='ovr' to the API:
skm.roc_auc_score(y_true, y_score, average='macro', sample_weight=None, max_fpr=None, multi_class='raise', labels=None), as shown before.

This is getting way too problem-specific for a single API. It seems like the binary case will need to be handled by a special metric, and the current RocAuc and APScore only work for the multi-label case.

I don’t know if it’s relevant, but there was a working multi-class version: Fastai v2 vision

Hi @sgugger,

@FraPochetti and I have been working together this morning to review the proba-based metrics issue in fastai2 (RocAuc and APScore), and have jointly come up with a proposal we’d like to submit to you.
It handles all the possibilities sklearn allows while keeping the API consistent with the rest of the fastai2 metrics.
We have tested our proposal against sklearn’s API using this gist and everything works well.

In sklearn there are 3 scenarios for roc_auc_score (each of them calculated slightly differently):

  • Binary:
    • targets: shape = (n_samples,)
    • preds: pass through softmax, then take [:, -1]; shape = (n_samples,)
  • Multiclass:
    • targets: shape = (n_samples,)
    • preds: pass through softmax; shape = (n_samples, n_classes)
    • multi_class = ‘ovr’ or ‘ovo’ (1)
  • Multilabel:
    • targets: shape = (n_samples, n_classes)
    • preds: pass through sigmoid; shape = (n_samples, n_classes)

(1) ‘ovr’: average AUC of each class against the rest; ‘ovo’: average AUC of all possible pairwise combinations of classes.
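
For concreteness, a toy sketch of the three call patterns (the arrays are made up; only the shapes and kwargs matter):

import numpy as np
import sklearn.metrics as skm

# Binary: targets (n_samples,), preds (n_samples,) = proba of the positive class
skm.roc_auc_score(np.array([0, 1, 1, 0]), np.array([0.2, 0.7, 0.6, 0.4]))

# Multiclass: targets (n_samples,), preds (n_samples, n_classes) that sum to 1, plus multi_class
y_score = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1],
                    [0.2, 0.2, 0.6], [0.3, 0.5, 0.2]])
skm.roc_auc_score(np.array([0, 1, 2, 1]), y_score, multi_class='ovr')

# Multilabel: targets (n_samples, n_classes) indicators, preds (n_samples, n_classes) sigmoids
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_sig  = np.array([[0.8, 0.3, 0.7], [0.2, 0.9, 0.1],
                   [0.6, 0.7, 0.2], [0.1, 0.4, 0.9]])
skm.roc_auc_score(y_true, y_sig, average='macro')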

sklearn’s average_precision_score implementation is restricted to binary and multilabel classification tasks, so it cannot be used in the multiclass case.
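
A quick check of that restriction, with the same toy multiclass arrays:

import numpy as np
import sklearn.metrics as skm

y_true  = np.array([0, 1, 2, 1])                     # three classes => 'multiclass' target type
y_score = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1],
                    [0.2, 0.2, 0.6], [0.3, 0.5, 0.2]])
try:
    skm.average_precision_score(y_true, y_score)
except ValueError as e:
    print(e)                                         # sklearn refuses the multiclass format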

Here’s our proposal (a drop-in modification of AccumMetric and skm_to_fastai from fastai2.metrics, so it assumes that module’s imports):

class AccumMetric(Metric):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, func, dim_argmax=None, sigmoid=False, softmax=False, proba=False, thresh=None, to_np=False, invert_arg=False,
                 flatten=True, **kwargs):
        store_attr(self,'func,dim_argmax,sigmoid,softmax,proba,thresh,flatten')
        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs

    def reset(self): self.targs,self.preds = [],[]

    def accumulate(self, learn):
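        # with proba=True, skip the argmax decode and keep the raw activations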
        pred = learn.pred.argmax(dim=self.dim_argmax) if (self.dim_argmax and not self.proba) else learn.pred
        if self.sigmoid: pred = torch.sigmoid(pred)
        if self.thresh:  pred = (pred >= self.thresh)
        if self.softmax: 
            pred = F.softmax(pred, dim=-1)
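            # binary case (dls.c == 2): sklearn expects only the positive-class proba, shape (n_samples,)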
            if learn.dls.c == 2: pred = pred[:, -1]
        targ = learn.y
        pred,targ = to_detach(pred),to_detach(targ)
        if self.flatten: pred,targ = flatten_check(pred,targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        if self.to_np: preds,targs = preds.numpy(),targs.numpy()
        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)

    @property
    def name(self):  return self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__

def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, sigmoid=None, softmax=False, proba=False, **kwargs):
    "Convert `func` from sklearn.metrics to a fastai metric"
    dim_argmax = axis if is_class and thresh is None else None
    sigmoid = sigmoid if sigmoid is not None else (is_class and thresh is not None)
    return AccumMetric(func, dim_argmax=dim_argmax, sigmoid=sigmoid, softmax=softmax, proba=proba, thresh=thresh,
                       to_np=True, invert_arg=True, **kwargs)

def APScore(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for binary single-label classification problems"
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, pos_label=pos_label, sample_weight=sample_weight)
    
def APScoreMulti(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for multi-label classification problems"
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False, sigmoid=True, proba=True,
                         average=average, pos_label=pos_label, sample_weight=sample_weight)
    
def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='raise', labels=None):
    "Area Under the Receiver Operating Characteristic Curve for single-label classification problems"
    """use default multi_class ('raise') for binary-class, and 'ovr'(average AUC of each class against the rest) 
    or 'ovo' (average AUC of all possible pairwise combinations of classes) for multi-class tasks"""
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class, labels=labels)
    
def RocAucMulti(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for multi-label classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, sigmoid=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)
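
A usage sketch against the same setup as the gist (the ImageDataLoaders.from_folder call is our assumption since the gist isn’t inlined here; MNIST_TINY is binary, so this exercises the softmax + [:, -1] path):

from fastai2.vision.all import *

path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)   # binary dataset: 3s vs 7s
learn = cnn_learner(dls, resnet18, pretrained=False,
                    metrics=[accuracy, APScore(), RocAuc()])
learn.fit_one_cycle(5, 0.1)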

Please, let us know if we can help you in any way with this.


This introduces a bit too much magic. I think there should be two names: BinaryRocAuc and RocAuc for the two separate metrics (that handle things differently).

Hi @sgugger,

Yes, @FraPochetti and I also discussed how the different cases should be grouped and named.

If we understand you correctly, you are proposing to split RocAuc into 2 to avoid the multi_class kwarg. That makes sense.

This would be our proposal for the 3 scenarios (gist with full code):

def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)

def RocAucMultiClass(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr', labels=None):
    "Area Under the Receiver Operating Characteristic Curve for single-label multi-class classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class, labels=labels)
    
def RocAucMulti(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for multi-label classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, sigmoid=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)

As to the names, we have several options:

  • binary case: RocAuc or RocAucBinary, and APScore
  • multi-class case: RocAucMultiClass (average precision is not available in sklearn for multiclass)
  • multi-label case: RocAucMulti or RocAucMultiLabel, and APScoreMulti

We believe RocAuc and RocAucMulti are consistent with all the other fastai2 metrics. The new one would be RocAucMultiClass, since the multiclass case of roc_auc_score requires different behavior.


I disagree with the multi-class terminology. All metrics for single-label work with any number of labels, so the base RocAuc/APScore should work for the multi-label case. Since the binary case requires special behavior, it should be BinaryRocAuc and BinaryAPScore.

I think you meant:
“All metrics for single-label work with any number of classes, so the base RocAuc/APScore should work for the multi-class case.”
Right?

If so, it makes sense.
May I suggest just one thing: can we use Binary as a suffix instead of a prefix? It makes it easier to find the different RocAuc variants when you start typing with code completion.

This way it’d be:

  • RocAuc: for single-label multi-class
  • RocAucBinary or BinaryRocAuc / APScoreBinary or BinaryAPScore: for single-label binary
  • RocAucMulti / APScoreMulti: for multi-label

But it’s your call.


Yes I wanted to say multi-class, sorry.
No problem with having Binary as a suffix (since Multi is also a suffix).


Ok, good. So we agreed :sweat_smile:.

Here’s a gist with the code and the tests we used.

Here’s the code with the agreed naming (AccumMetric and skm_to_fastai are unchanged from the proposal above):

def APScore(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for binary single-label classification problems"
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, pos_label=pos_label, sample_weight=sample_weight)
    
def APScoreMulti(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for multi-label classification problems"
    return skm_to_fastai(skm.average_precision_score, axis=axis, flatten=False, sigmoid=True, proba=True,
                         average=average, pos_label=pos_label, sample_weight=sample_weight)
    
def RocAucBinary(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)

def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr', labels=None):
    "Area Under the Receiver Operating Characteristic Curve for single-label multi-class classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, softmax=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class, labels=labels)
    
def RocAucMulti(axis=-1, average='macro', sample_weight=None, max_fpr=None):
    "Area Under the Receiver Operating Characteristic Curve for multi-label classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis, flatten=False, sigmoid=True, proba=True,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)
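
To recap which constructor goes with which task under this naming (dls/learner setup as in the gist above):

# single-label binary:      metrics=[accuracy, RocAucBinary(), APScore()]
# single-label multi-class: metrics=[accuracy, RocAuc()]         # multi_class='ovr' by default
# multi-label:              metrics=[accuracy_multi, RocAucMulti(), APScoreMulti()]
learn = cnn_learner(dls, resnet18, pretrained=False,
                    metrics=[accuracy, RocAucBinary(), APScore()])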

Will you update this in fastai2 then? Is there anything else you need from @FraPochetti or me?
