Hard example mining callback

Hi,

I've written a hard example mining callback for fastai2. Do you find it useful? If so, I'll open a PR to fastai2 (with docs) after asking sgugger or jeremy whether they'd be willing to include it. If not, I may put it in a separate repository.

from fastai2.vision.all import *

class HEM(Callback):
    run_after,run_valid = [Normalize],False
    
    def __init__(self, top_k=0.5): self.top_k = top_k
        
    def begin_fit(self):
        self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf
        
    def after_fit(self):
        self.learn.loss_func = self.old_lf

    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
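        # Number of samples to keep: top_k >= 1 is an absolute count, top_k in (0,1) is a fraction of the batch size
        top_k = int(self.top_k) if self.top_k >= 1 else round(pred.shape[0] * self.top_k)
        # Clamp to [1, batch size] so topk doesn't fail on a smaller last batch
        top_k = max(1, min(top_k, pred.shape[0]))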
        with NoneReduce(self.old_lf) as lf:
            losses = lf(pred,*yb)
            top_losses = losses.topk(top_k, sorted=False)[0]
            
        return reduce_loss(top_losses, getattr(self.old_lf, 'reduction', 'mean'))

Usage: like any other callback :wink: :

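Learner(dls, model, cbs=HEM(top_k=.5)) backpropagates only the 50% of samples in each batch with the highest loss, and Learner(dls, model, cbs=HEM(top_k=5)) backpropagates only the 5 highest-loss samples in each batch.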
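
To make the two meanings of top_k concrete, here is a rough plain-Python sketch of the arithmetic, with no fastai or PyTorch involved. The helper names `resolve_top_k` and `hard_example_mean` and the loss values are made up for illustration, and it only covers the 'mean' reduction case:

```python
import heapq

def resolve_top_k(top_k, batch_size):
    """Hypothetical helper: turn the `top_k` argument into a sample count."""
    # Values >= 1 are read as an absolute count, values below 1 as a fraction of the batch.
    k = int(top_k) if top_k >= 1 else round(batch_size * top_k)
    # Clamping to [1, batch_size] keeps the count valid for a small last batch.
    return max(1, min(k, batch_size))

def hard_example_mean(losses, top_k):
    """Mean of the `top_k` largest per-sample losses (the 'mean' reduction case)."""
    k = resolve_top_k(top_k, len(losses))
    hardest = heapq.nlargest(k, losses)
    return sum(hardest) / k

# Made-up per-sample losses for a batch of four
batch_losses = [0.5, 2.0, 0.25, 1.0]
print(hard_example_mean(batch_losses, 0.5))  # keeps 2.0 and 1.0
print(hard_example_mean(batch_losses, 3))    # keeps 2.0, 1.0 and 0.5
```

The easy samples are simply left out of the reduced loss, so they contribute no gradient for that batch.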

I hope that you find it useful :smiley:
