Custom Metrics (FP/TN and FN/TP) in fastai2

Hi all,

I had a question on custom metrics with fastai2. I want to calculate FP/TN and FN/TP as separate metrics that I will then weight and combine.

I started with:

x,y = dls.one_batch()
preds = learn.model(x)

But, I’m not sure how to proceed from here.

I can get these from the confusion matrix, but I’m not sure how to connect the Interpretation object with my Learner object.
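For reference, what I was imagining is something along these lines (just a guess at the hookup, and it happens after training rather than during it):

from fastai2.interpret import ClassificationInterpretation

# guessed hookup: build the Interpretation from the Learner after training
interp = ClassificationInterpretation.from_learner(learn)
cm = interp.confusion_matrix()   # np.ndarray, rows = actual, cols = predicted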


You won’t be able to use the Interpret module, at least not if you want a metric while training. You could do something like what is described in this post, and then scale TP, FP, TN, and FN to your liking:


Thanks Zach! Funny enough, I was looking at that exact post before posting here.
The only thing is that they use for-loops, and I wanted to use numpy/torch for better performance.
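Something like this minimal torch-only sketch is the shape of what I’m after (binary_counts is just a hypothetical name; it assumes 0/1 labels and already-argmax’d predictions):

import torch

def binary_counts(preds, targs):
    "Vectorized TP/FP/TN/FN counts for 0/1 labels; no Python loops."
    tp = ((preds == 1) & (targs == 1)).sum()
    fp = ((preds == 1) & (targs == 0)).sum()
    tn = ((preds == 0) & (targs == 0)).sum()
    fn = ((preds == 0) & (targs == 1)).sum()
    return tp, fp, tn, fn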

I wanted to get my head around building this by using one_batch, so I can take it step by step. I think this is a good tutorial in the making.

If there were a way to get the confusion matrix at this point (during training), I think it would be pretty straightforward from there.


Sure, it’s fairly straightforward to use.

So I made a quick tabular_learner based on the adults dataset to test this on. Let’s walk step by step through what this could look like:

First we’ll get the raw outputs from our model:

batch = next(iter(dls[0]))        # grab one raw batch from the training DataLoader
learn = tabular_learner(dls, layers=[200,100])
out = learn.model(*batch[:2])     # a tabular model takes (categorical, continuous) inputs

Next we’ll want to argmax them, as we want to grab the predicted label itself, not the raw activations:

preds = out.argmax(dim=1)   # [512, 2] -> [512]: the predicted class per row

Finally, our ground truth. We’ll want to reshape it to match preds (i.e. instead of the usual [512,1], we want [512]):

truth = batch[2].view(-1)   # targets are the third element of a tabular batch

Now we have everything in place to use sklearn's confusion_matrix:

from sklearn.metrics import confusion_matrix
tn, fp, fn, tp = confusion_matrix(truth.cpu().numpy(), preds.detach().cpu().numpy()).ravel()
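As a quick sanity check on the ordering: sklearn sorts the labels ascending, so for 0/1 labels ravel() comes back as (tn, fp, fn, tp):

from sklearn.metrics import confusion_matrix

# toy example: y_true = [0,0,1,1], y_pred = [0,1,1,1]
cm = confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1])
# cm == [[1, 1],    row 0: (tn, fp)
#        [0, 2]]    row 1: (fn, tp)
tn, fp, fn, tp = cm.ravel()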

Now let’s talk efficiency. On a batch size of 512, that took 1.26 ms in total, so it’s a start, but it’s not very fast, with the bulk of the time spent in the confusion matrix.
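If you want to reproduce that, a rough notebook measurement of each step (using the variables from above) is something like:

# rough timing of each step with IPython's %timeit magic
%timeit out.argmax(dim=1)
%timeit confusion_matrix(truth.cpu().numpy(), preds.detach().cpu().numpy())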


Thanks Zach!

I did the following for my cnn_learner:

x,y = dls.one_batch()
out = learn.model(x)        # raw activations, shape [bs, n_classes]
preds = out.argmax(dim=1)   # predicted class indices, shape [bs]

from sklearn.metrics import confusion_matrix
tn, fp, fn, tp = confusion_matrix(y.cpu().numpy(), preds.detach().cpu().numpy()).ravel()

I have the values now!

Now, how do I make a metric out of this?

I’m looking at this example:

def _exp_rmspe(inp,targ):
    inp,targ = torch.exp(inp),torch.exp(targ)
    return torch.sqrt(((targ - inp)/targ).pow(2).mean())
exp_rmspe = AccumMetric(_exp_rmspe)

Let me have a go and see if this works.

OK, I did this:

def fp_tn(inp,targ):
  tn, fp, fn, tp = confusion_matrix(targ.cpu().numpy(), inp.detach().cpu().numpy()).ravel()
  return fp/tn

fptn_metric = AccumMetric(fp_tn)
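For context, I attached it alongside my other metrics roughly like this (resnet34 is just a stand-in for my actual arch):

# hypothetical hookup; the arch is a stand-in
learn = cnn_learner(dls, resnet34, metrics=[accuracy, F1Score(), fptn_metric])
learn.fine_tune(1, 3e-2)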

But, I got an error after 1 epoch:

epoch  train_loss  valid_loss  accuracy  f1_score  fp_tn  time
0      0.392312    1.395006    0.796875  0.480000  None   00:51
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-63-dfea2fe8c944> in <module>()
----> 1 learn.fine_tune(1, 3e-2)

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs)
    159     "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
    160     self.freeze()
--> 161     self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
    162     base_lr /= 2
    163     self.unfreeze()

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    114 
    115 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/utils.py in _f(*args, **kwargs)
    429         init_args.update(log)
    430         setattr(inst, 'init_args', init_args)
--> 431         return inst if to_return else f(*args, **kwargs)
    432     return _f
    433 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    202                         self.epoch=epoch;          self('begin_epoch')
    203                         self._do_epoch_train()
--> 204                         self._do_epoch_validate()
    205                     except CancelEpochException:   self('after_cancel_epoch')
    206                     finally:                       self('after_epoch')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _do_epoch_validate(self, ds_idx, dl)
    181         try:
    182             self.dl = dl;                                    self('begin_validate')
--> 183             with torch.no_grad(): self.all_batches()
    184         except CancelValidException:                         self('after_cancel_validate')
    185         finally:                                             self('after_validate')

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    154 
    155     def one_batch(self, i, b):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in one_batch(self, i, b)
    165             self.opt.zero_grad()
    166         except CancelBatchException:                         self('after_cancel_batch')
--> 167         finally:                                             self('after_batch')
    168 
    169     def _do_begin_fit(self, n_epoch):

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in __call__(self, event_name)
    132     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    133 
--> 134     def __call__(self, event_name): L(event_name).map(self._call_one)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    375              else f.format if isinstance(f,str)
    376              else f.__getitem__)
--> 377         return self._new(map(g, self))
    378 
    379     def filter(self, f, negate=False, **kwargs):

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    325     @property
    326     def _xtra(self): return None
--> 327     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    328     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    329     def copy(self): return self._new(self.items.copy())

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     45             return x
     46 
---> 47         res = super().__call__(*((x,) + args), **kwargs)
     48         res._newchk = 0
     49         return res

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    316         if items is None: items = []
    317         if (use_list is not None) or not _is_array(items):
--> 318             items = list(items) if use_list else _listify(items)
    319         if match is not None:
    320             if is_coll(match): match = len(match)

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in _listify(o)
    252     if isinstance(o, list): return o
    253     if isinstance(o, str) or _is_array(o): return [o]
--> 254     if is_iter(o): return list(o)
    255     return [o]
    256 

/usr/local/lib/python3.6/dist-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    218             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    219         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 220         return self.fn(*fargs, **kwargs)
    221 
    222 # Cell

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name)
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.6/dist-packages/fastai2/callback/core.py in __call__(self, event_name)
     22         _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
     23                (self.run_valid and not getattr(self, 'training', False)))
---> 24         if self.run and _run: getattr(self, event_name, noop)()
     25         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     26 

/usr/local/lib/python3.6/dist-packages/fastai2/learner.py in after_batch(self)
    427         if len(self.yb) == 0: return
    428         mets = self._train_mets if self.training else self._valid_mets
--> 429         for met in mets: met.accumulate(self.learn)
    430         if not self.training: return
    431         self.lrs.append(self.opt.hypers[-1]['lr'])

/usr/local/lib/python3.6/dist-packages/fastai2/metrics.py in accumulate(self, learn)
     44         targ = learn.y
     45         pred,targ = to_detach(pred),to_detach(targ)
---> 46         if self.flatten: pred,targ = flatten_check(pred,targ)
     47         self.preds.append(pred)
     48         self.targs.append(targ)

/usr/local/lib/python3.6/dist-packages/fastai2/torch_core.py in flatten_check(inp, targ)
    778     "Check that `out` and `targ` have the same number of elements and flatten them."
    779     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 780     test_eq(len(inp), len(targ))
    781     return inp,targ

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test_eq(a, b)
     30 def test_eq(a,b):
     31     "`test` that `a==b`"
---> 32     test(a,b,equals, '==')
     33 
     34 # Cell

/usr/local/lib/python3.6/dist-packages/fastcore/test.py in test(a, b, cmp, cname)
     20     "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     21     if cname is None: cname=cmp.__name__
---> 22     assert cmp(a,b),f"{cname}:\n{a}\n{b}"
     23 
     24 # Cell

AssertionError: ==:
128
64

OK, for some reason the lengths of my inp and targ are different.
I’m running %debug, and this is what I’m seeing:

> /usr/local/lib/python3.6/dist-packages/fastai2/metrics.py(46)accumulate()
     44         targ = learn.y
     45         pred,targ = to_detach(pred),to_detach(targ)
---> 46         if self.flatten: pred,targ = flatten_check(pred,targ)
     47         self.preds.append(pred)
     48         self.targs.append(targ)

ipdb> pred.shape
torch.Size([64, 2])
ipdb> targ.shape
torch.Size([64])
ipdb> d
> /usr/local/lib/python3.6/dist-packages/fastai2/torch_core.py(780)flatten_check()
    777 def flatten_check(inp, targ):
    778     "Check that `out` and `targ` have the same number of elements and flatten them."
    779     inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
--> 780     test_eq(len(inp), len(targ))
    781     return inp,targ

ipdb> inp.shape
torch.Size([128])
ipdb> targ.shape
torch.Size([64])

Oh, I need to argmax the pred.


You missed the fact that we need to argmax() our predictions before sending them to the confusion matrix (as remember, our model outputs raw activations, not labels)

Just saw you figured that out :wink:


Yes! I had just figured this out as you were replying! :smiley:

I had been using %debug the wrong way all this while; I didn’t know I could debug after the cell had run. The u and d commands really helped!

But I’m still seeing the same issue :confused:

This is my function definition for the metric now, with the argmax added:

def fp_tn(inp,targ):
  preds = inp.argmax(dim=1)
  tn, fp, fn, tp = confusion_matrix(targ.cpu().numpy(), preds.detach().cpu().numpy()).ravel()
  return fp/tn

I think I need to define something with AccumMetric for this, because I see:

> /usr/local/lib/python3.6/dist-packages/fastai2/metrics.py(46)accumulate()
     44         targ = learn.y
     45         pred,targ = to_detach(pred),to_detach(targ)
---> 46         if self.flatten: pred,targ = flatten_check(pred,targ)
     47         self.preds.append(pred)
     48         self.targs.append(targ)

ipdb> pred.shape
torch.Size([64, 2])

Looks like I need to override the accumulate function.

See my edit; I described the process but did not show code. Just tested this myself and it worked:

def fp_tn(inp, targ):
    preds = inp.argmax(dim=1)
    tn, fp, _, _ = confusion_matrix(targ.cpu().view(-1).numpy(), preds.detach().cpu().numpy()).ravel()
    return fp/tn

(There’s also no need to wrap it in an AccumMetric; that will cause an error)
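So you can pass the plain function straight in as a metric; if I remember right, fastai wraps bare callables in an AvgMetric, which calls them on (learn.pred, *learn.yb) each batch without flattening:

# assumed usage on my adults tabular_learner from earlier
learn = tabular_learner(dls, layers=[200,100], metrics=[fp_tn])
learn.fit_one_cycle(1)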


I came to it in a different way.

def fp_tn(inp,targ):
  tn, fp, fn, tp = confusion_matrix(targ.cpu().numpy(), inp.detach().cpu().numpy()).ravel()
  return fp/tn

fptn_metric = AccumMetric(fp_tn, dim_argmax=1)   # argmax the preds before accumulating
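From what I can tell, dim_argmax=1 makes AccumMetric argmax the predictions before flatten_check, which is exactly what was tripping the AssertionError. A self-contained sketch of the per-batch shapes (simulated data, not my real batch):

import torch
from fastai2.torch_core import flatten_check

pred_raw = torch.randn(64, 2)            # raw activations, like learn.pred
targ = torch.randint(0, 2, (64,))        # targets, like learn.y
pred = pred_raw.argmax(dim=1)            # what dim_argmax=1 does: [64,2] -> [64]
pred, targ = flatten_check(pred, targ)   # shapes now match; no AssertionError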

OK, now I am wondering what the pos_label (similar to how we define it for F1Score) is for the confusion matrix.

I have two classes [0,1]

And, I want my pos_label=0, so that if something is labeled as 0 and it is not, it is a FP.

I’m poring over the sklearn docs right now for this.


I’m not 100% sure that will translate easily, considering the confusion matrix can be n×n, so I’m curious what you find :slight_smile:

There is of course the easiest answer: fastai’s version (where decoded is the argmax’d predictions), which of course I forgot can be used independently:

def confusion_matrix(self):
    "Confusion matrix as an `np.ndarray`."
    x = torch.arange(0, len(self.vocab))
    d,t = flatten_check(self.decoded, self.targs)
    cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)
    return to_np(cm)
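If you want it outside the Interpretation class, a standalone adaptation could look like this (torch_confusion_matrix is a made-up name; it assumes preds are already argmax’d):

import torch

def torch_confusion_matrix(preds, targs, n_classes=2):
    "cm[i,j] = count of samples with actual class i predicted as class j"
    x = torch.arange(0, n_classes)
    d, t = preds.view(-1), targs.view(-1)
    return ((d == x[:,None]) & (t == x[:,None,None])).long().sum(2)

# binary case, same ordering as sklearn's ravel():
# tn, fp, fn, tp = torch_confusion_matrix(preds, y).flatten()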

Yes, I had looked at this originally, which was why I was asking about the Interpretation class; I was not sure what decoded was or how to produce it.

I just tested them both on a single batch, and the fastai version comes in about 0.10 ms quicker, all three times I ran it. :slight_smile:

And, on the pos_label part, I confirmed that the default label order is [0,1], and it considers 1 the positive label.

And, to invert it, I just passed labels=[1,0] to the confusion_matrix call for sklearn. Not sure how to do that for the fastai version, though. Also, I haven’t yet extended this to the multi-class scenario.

tn, fp, fn, tp = confusion_matrix(y.cpu().numpy(), preds.detach().cpu().numpy(), labels=[1,0]).ravel()
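One thought for the fastai version (untested): since passing labels=[1,0] just reverses both axes of the matrix, flipping it should do the same thing:

import numpy as np

# toy [0,1]-ordered matrix (rows = actual, cols = predicted), e.g. from to_np(cm)
cm = np.array([[50, 3],
               [ 7, 40]])
# reversing both axes is equivalent to sklearn's labels=[1,0],
# so class 0 becomes the positive label
tn, fp, fn, tp = cm[::-1, ::-1].ravel()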

Thanks a lot for your help, Zach!
