Accumulating Gradients

@sgugger need some help here :slight_smile:

So I created a callback which handles the above issues, but I don’t know how best to integrate the monkey patching in terms of user experience.
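The idea: the wrapper below turns `step` and `zero_grad` into no-ops so the training loop never touches the optimizer, gradients accumulate over `n_step` batches, and the callback divides them by the total sample count before calling the real step. With `reduction='sum'` this reproduces the mean-loss gradient of one big batch exactly, which is what the warning in `on_train_begin` is about. A minimal sanity check of that scaling (plain PyTorch, independent of the callback, sizes picked arbitrarily):

import torch

# two micro-batches of sizes 3 and 5; with summed losses, the accumulated
# gradient divided by the total sample count (8) matches the gradient of
# the mean loss over the full batch of 8
w = torch.randn(4, requires_grad=True)
xs = [torch.randn(3, 4), torch.randn(5, 4)]

for x in xs: (x @ w).pow(2).sum().backward()   # grads accumulate in w.grad
acc_grad = w.grad / 8                          # divide by total samples

w.grad = None
(torch.cat(xs) @ w).pow(2).mean().backward()   # one big batch, mean loss
print(torch.allclose(acc_grad, w.grad))        # True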

Here is the implementation:

class AccumulateOptimWrapper(OptimWrapper):
    "OptimWrapper whose `step`/`zero_grad` are no-ops; the callback calls the real ones."
    def step(self):           pass
    def zero_grad(self):      pass
    def real_step(self):      super().step()
    def real_zero_grad(self): super().zero_grad()

def acc_create_opt(self, lr:Floats, wd:Floats=0.):
    "Create optimizer with `lr` learning rate and `wd` weight decay."
    self.opt = AccumulateOptimWrapper.create(self.opt_func, lr, self.layer_groups,
                                             wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        
class AccumulateStep(LearnerCallback):
    "Accumulate gradients and do an optimizer step every `n_step` batches."
    def __init__(self, learn:Learner, n_step:int=1):
        super().__init__(learn)
        self.n_step = n_step

    def on_train_begin(self, **kwargs):
        "warn if the loss averages instead of summing"
        if getattr(self.loss_func, "reduction", None) == "mean":
            print("For better gradients consider 'reduction=sum'")
        
    def on_epoch_begin(self, **kwargs):
        "reset the sample and batch counters"
        self.acc_samples = 0
        self.acc_batches = 0

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "count samples and batches as they arrive"
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1
        print(f"At batch {self.acc_batches}")   # debug output
        
    def on_backward_end(self, **kwargs):
        "step once the desired number of batches has accumulated, then reset samples"
        if (self.acc_batches % self.n_step) == 0:
            # normalize the summed gradients by the number of samples seen
            for p in self.learn.model.parameters():
                if p.requires_grad and p.grad is not None: p.grad.div_(self.acc_samples)
            print(f"Stepping at batch: {self.acc_batches}")   # debug output
            self.learn.opt.real_step()
            self.learn.opt.real_zero_grad()
            self.acc_samples = 0
    
    def on_epoch_end(self, **kwargs):
        "step the leftover accumulated grads, if any"
        if (self.acc_batches % self.n_step) != 0:
            for p in self.learn.model.parameters():
                if p.requires_grad and p.grad is not None: p.grad.div_(self.acc_samples)
            self.learn.opt.real_step()
            self.learn.opt.real_zero_grad()
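
Right now the only integration I can see is patching `create_opt` by hand before training, roughly like this (a rough sketch assuming the current v1 API; `data`, the resnet18 arch and `n_step=4` are just placeholders):

from fastai.vision import *   # Learner, create_cnn, models, nn, ...

Learner.create_opt = acc_create_opt            # the monkey patch in question
learn = create_cnn(data, models.resnet18,
                   loss_func=nn.CrossEntropyLoss(reduction='sum'))
learn.callbacks.append(AccumulateStep(learn, n_step=4))  # step every 4 batches
learn.fit(1)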

Thanks!
