Discrepancy with proba-based metrics between fastai2 and sklearn

Thanks for your follow up @FraPochetti!
I’ve created a simple gist to explain the issue.
I noticed it because the metrics I was getting when I started working with fastai2 were significantly worse than with v1.

Please, let me know if this is not clear.

@muellerzr, could you please take a quick look at this? Have you noticed any issue when using RocAuc in v2?


This is indeed super weird…
The difference in my case (using this in your gist: learn.fit_one_cycle(3, 0.2)) is huge!


This seems like an sklearn issue though.
As you can see we are passing 2 almost identical arrays in terms of values.
Of course valid_preds contains either 0 or 1 (I also tried with valid_preds.float() and the result is the same), whereas valid_probas contains actual floats. Rather extreme ones, as they are either very close to 0 or very close to 1, but still floats.
Can this be driven just by a rounding issue?
At the end of the day, ROC AUC (and Precision) is calculated by computing TP, TN, FP and FN at varying thresholds. The thing is, with such extreme values, these thresholds rarely change the underlying counts, especially when the threshold is compared to 0 or 1.
It could return something funny, though, when compared to 2.288320e-42, hence the rounding hypothesis.
I have to admit I am not sure, though.
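One quick way to test the rounding theory (just a sketch): ROC AUC is rank-based, so only the ordering of the scores should matter, not their magnitude.

from sklearn.metrics import roc_auc_score

y = [0, 1, 0, 1]
scores = [2.3e-42, 0.4, 0.6, 1.0]   # extreme but distinct values
ranks  = [1, 2, 3, 4]               # same ordering, sane magnitudes
print(roc_auc_score(y, scores))     # 0.75
print(roc_auc_score(y, ranks))      # 0.75 -- identical, so rounding alone can't explain it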
@lgvaz we need your expertise too!
Are we missing anything super dumb here?

Btw, have you tried stepping out of fastai and just trying to reproduce the same issue with a toy sklearn example (make_classification)?
I am planning to test that tomorrow.
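Roughly along these lines (a minimal sketch, pure sklearn; the LogisticRegression model is just for illustration):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, random_state=42)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, random_state=42)
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)

print(roc_auc_score(y_va, clf.predict_proba(X_va)[:, 1]))  # probabilities
print(roc_auc_score(y_va, clf.predict(X_va)))              # hard 0/1 labels: typically lower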

@oguiza, I made a quick test in a separate env and got the same results.
This was just to prove nothing was somehow screwed up with the sklearn installation in fastai, as I had claimed this could be a sklearn-related problem.


I think you are right in saying this is a fastai bug.
It is rather obvious that passing hard 0/1 predictions versus float probabilities (alongside the targets) returns different results.
It seems fastai uses the former, while it should use the latter.

I think this is driven by the fact that sigmoid=None in skm_to_fastai (here), which means pred is calculated with argmax (here) instead of torch.sigmoid (here).
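If I read it right, the logic is roughly this (a paraphrased sketch of the behavior, not the actual fastai2 source):

import torch

def pred_for_metric(logits, sigmoid=None, dim_argmax=-1):
    # With sigmoid=None, predictions are collapsed with argmax, so sklearn's
    # roc_auc_score ends up receiving hard class ids instead of probabilities.
    if sigmoid: return torch.sigmoid(logits)   # probabilities: what roc_auc_score needs
    return logits.argmax(dim=dim_argmax)       # hard labels: what it currently gets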

@sgugger, are we getting anything wrong here?


Great analysis @FraPochetti! Fully agree with what you’ve done.

This is the v1 code for ROC AUC, and it uses probas in the correct way (for binary classification). That’s why my results were worse with v2.

@dataclass
class AUROC(Callback):
    "Computes the area under the curve (AUC) score based on the receiver operator characteristic (ROC) curve. Restricted to binary classification tasks."
    def on_epoch_begin(self, **kwargs):
        self.targs, self.preds = LongTensor([]), Tensor([])
        
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = F.softmax(last_output, dim=1)[:,-1]
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))
    
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, auc_roc_score(self.preds, self.targs))

As a bonus, it’d be great to also use the native support sklearn provides for multiclass and multilabel (for roc_auc only).
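For reference, sklearn’s native multiclass support looks roughly like this (a sketch with made-up values; the rows of y_probas must sum to 1 across classes):

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 1, 2, 2])
y_probas = np.array([[0.7, 0.2, 0.1],
                     [0.2, 0.6, 0.2],
                     [0.1, 0.3, 0.6],
                     [0.2, 0.2, 0.6]])
print(roc_auc_score(y_true, y_probas, multi_class='ovr', average='macro'))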


This is how I (well, @ilovescience did :smiley:) fixed this problem (in a multi-class context). Now that I think about it, it also works in the binary case :slight_smile:, e.g. it should fix your problem too.

def _accumulate(self, learn):
    pred = learn.pred
    if self.sigmoid: pred = torch.nn.functional.softmax(pred, dim=1) #hack for roc_auc_score
    if self.thresh:  pred = (pred >= self.thresh)
    targ = learn.y
    pred,targ = to_detach(pred),to_detach(targ)
    if self.flatten: pred,targ = flatten_check(pred,targ)
    self.preds.append(pred)
    self.targs.append(targ)

AccumMetric.accumulate = _accumulate

def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr'):
    "Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems"
    return skm_to_fastai(skm.roc_auc_score, axis=axis,
                         average=average, sample_weight=sample_weight, max_fpr=max_fpr,
                         flatten=False, multi_class=multi_class, sigmoid=True)
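With the patch applied, usage stays the same, e.g. (a sketch; dls and the model are assumed):

learn = cnn_learner(dls, resnet18, metrics=[accuracy, RocAuc()])
learn.fit_one_cycle(3, 0.2)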

Yep, that makes sense! I will use it in the interim until we have a solution to this :slight_smile:
I’d prefer not to modify the AccumMetric.accumulate function as it might have an impact on other metrics.
Let’s see what Sylvain thinks of all this!

I don’t understand what you are saying. The _accumulate function is wrong: why use softmax when you ask for sigmoid? That doesn’t make any sense.

The accumulate function in fastai has a dim_argmax you can pass for softmax. Maybe what you are saying is that this argument should be used to wrap the roc metric?

First, thanks for taking the time to reply.

The _accumulate function does exactly what I want and it returns the correct outputs. It was “hacked” in the context of ROC AUC for a multi-class problem (hence softmax) in which I needed probabilities for all classes, not just argmax. It is very likely not the most elegant, nor the most efficient way of achieving the final goal, but this is a whole different story. It was also my very first interaction with Callbacks and custom Metrics in fastai2, so I was definitely not aware of all my options. I just wanted it done.

As for the (potential) problem @oguiza originally reported, what we are saying is that the default RocAuc implementation for binary classification seems not to apply sigmoid under the hood, passing sklearn predictions which are not probabilities but hard 0s and 1s (i.e. the result of predict instead of predict_proba in sklearn jargon).

So, in a nutshell, our question is: how can we call RocAuc inside a learner and make it apply sigmoid to the model’s outputs before passing them to sklearn’s roc_auc_score? Maybe cnn_learner(dls, arch, metrics=RocAuc(sigmoid=True))?

As suggested in the gist below, it seems the default is to pass predictions and not probabilities.

Once again, it might be we are getting this all wrong.
If this is the case, we apologize.


I think I understand a bit better. Basically you need some behavior where, instead of taking the argmax, you get the softmax that returns all probabilities. What confuses me in your posts is that you keep talking about sigmoid, but you don’t want to apply that; you want a softmax on a certain dimension. This means we need to add a softmax argument to skm_to_fastai rather than change the current behavior (otherwise it would break all the multi-label metrics).

While I’m at it, which other metrics take probabilities instead of predictions?

I could definitely have phrased the whole thing better :smiley:. Sorry about that.

I think ROC AUC, Precision and Recall expect probabilities, as those metrics are based on applying different thresholds and checking how False Positives, True Positives, etc. change.

The AUROC implementation below, from fastai1, makes sense to me, as it feeds targets and probabilities (i.e. F.softmax(last_output, dim=1)[:,-1]) to auc_roc_score.
Isn’t feeding predictions just wrong?
auc_roc_score would “see” 0 and 1 and use them as probabilities, messing everything up.

@dataclass
class AUROC(Callback):
    "Computes the area under the curve (AUC) score based on the receiver operator characteristic (ROC) curve. Restricted to binary classification tasks."
    def on_epoch_begin(self, **kwargs):
        self.targs, self.preds = LongTensor([]), Tensor([])
        
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = F.softmax(last_output, dim=1)[:,-1]
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))
    
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, auc_roc_score(self.preds, self.targs))

The whole point of roc_auc_score is to check how FPR and TPR change with varying probability thresholds. So roc_auc_score needs probas, not predictions.
I hope I did not mess up here :smiley:
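A tiny hand-checkable example (sketch): the same scores, passed as probabilities vs thresholded to 0/1, give different AUCs.

from sklearn.metrics import roc_auc_score

y_true = [0, 1, 0, 1]
probas = [0.3, 0.45, 0.55, 0.9]
preds  = [int(p >= 0.5) for p in probas]   # [0, 0, 1, 1]
print(roc_auc_score(y_true, probas))       # 0.75
print(roc_auc_score(y_true, preds))        # 0.5 -- thresholding destroyed the ranking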

As far as I know, it’s only RocAuc and average precision. Precision and Recall expect preds, not probas.
You can check this here and look for those metrics that indicate y_score instead of y_pred.
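For example (a sketch of the two signature families):

from sklearn.metrics import precision_score, roc_auc_score

y_true = [0, 1, 0, 1]
print(roc_auc_score(y_true, [0.2, 0.7, 0.4, 0.9]))  # takes y_score: probabilities/scores
print(precision_score(y_true, [0, 1, 0, 1]))        # takes y_pred: hard labels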


Ok, will work on a fix and apply it to RocAuc and average precision then. I’ll post here once it’s ready.


Great! Thanks Sylvain and Francesco for your support!

Thanks Sylvain. Please let us know if we can provide any support!

Ok, I pushed something on master. Could you check it works properly now?


Wow super fast!
I think you got a typo here though: softamx instead of softmax

Woops, fixed now.

Thanks Sylvain for being so fast!! :smiley:

I’ve tested it using the same gist again, but I get the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-4-d2f47d35092f> in <module>()
      1 learn = cnn_learner(dls, resnet18, pretrained=False, metrics=[accuracy, APScore(), RocAuc()])
----> 2 learn.fit_one_cycle(5, 0.1)

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    199                         self.epoch=epoch;          self('begin_epoch')
    200                         self._do_epoch_train()
--> 201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')
    203                     finally:                       self('after_epoch')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    181         try:
    182             self.dl = dl;                                    self('begin_validate')
--> 183             with torch.no_grad(): self.all_batches()
    184         except CancelValidException:                         self('after_cancel_validate')
    185         finally:                                             self('after_validate')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    154 
    155     def one_batch(self, i, b):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in one_batch(self, i, b)
    165             self.opt.zero_grad()
    166         except CancelBatchException:                         self('after_cancel_batch')
--> 167         finally:                                             self('after_batch')
    168 
    169     def _do_begin_fit(self, n_epoch):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    133 
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    374              else f.format if isinstance(f,str)
    375              else f.__getitem__)
--> 376         return self._new(map(g, self))
    377 
    378     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]
    256 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    221 
    222 # Cell

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/callback/core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     26 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in after_batch(self)
    421         if len(self.yb) == 0: return
    422         mets = self._train_mets if self.training else self._valid_mets
--> 423         for met in mets: met.accumulate(self.learn)
    424         if not self.training: return
    425         self.lrs.append(self.opt.hypers[-1]['lr'])

/usr/local/lib/python3.6/dist-packages/fastai2/metrics.py in accumulate(self, learn)
     33         targ = learn.y
     34         pred,targ = to_detach(pred),to_detach(targ)
---> 35         if self.flatten: pred,targ = flatten_check(pred,targ)
     36         self.preds.append(pred)
     37         self.targs.append(targ)

/usr/local/lib/python3.6/dist-packages/fastai2/torch_core.py in flatten_check(inp, targ)
    778     "Check that `out` and `targ` have the same number of elements and flatten them."
    779     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 780     test_eq(len(inp), len(targ))
    781     return inp,targ

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test_eq(a, b)
     30 def test_eq(a,b):
     31     "`test` that `a==b`"
---> 32     test(a,b,equals, '==')
     33 
     34 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test(a, b, cmp, cname)
     20     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     21     if cname is None: cname=cmp.__name__
---> 22     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     23 
     24 # Cell

AssertionError: ==:
128
64