@sgugger need some help here 
So I created a callback which handles the above issues, but I don’t know how best to integrate the monkey patching in terms of user experience.
Here is the implementation:
class AccumulateOptimWrapper(OptimWrapper):
    "Optimizer wrapper whose `step`/`zero_grad` are no-ops so gradients accumulate across batches."

    def step(self):
        # Deliberately do nothing: the training loop's step call must not fire.
        pass

    def zero_grad(self):
        # Deliberately do nothing: grads are only cleared via `real_zero_grad`.
        pass

    def real_step(self):
        "Perform the actual optimizer step (delegates to `OptimWrapper.step`)."
        super().step()

    def real_zero_grad(self):
        "Actually zero the gradients (delegates to `OptimWrapper.zero_grad`)."
        super().zero_grad()
def acc_create_opt(self, lr: Floats, wd: Floats = 0.):
    "Create optimizer with `lr` learning rate and `wd` weight decay."
    # Intended to be monkey-patched over `Learner.create_opt` so the learner
    # builds the accumulating wrapper instead of the stock `OptimWrapper`.
    self.opt = AccumulateOptimWrapper.create(
        self.opt_func, lr, self.layer_groups,
        wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd,
    )
@dataclass
class AccumulateStep(LearnerCallback):
    """Accumulate gradients over `n_step` batches and step the optimizer once.

    Assumes the learner's optimizer has been replaced by `AccumulateOptimWrapper`,
    so the training loop's own `step`/`zero_grad` calls are no-ops and stepping
    happens only through `real_step`/`real_zero_grad` here.
    """

    def __init__(self, learn: Learner, n_step: int = 1):
        super().__init__(learn)
        self.n_step = n_step  # number of batches to accumulate before each real step

    def on_train_begin(self, **kwargs):
        "Warn when the loss reduction averages per batch instead of summing."
        # With 'mean' reduction each batch's grads are already divided by its own
        # size; 'sum' plus the single division below gives the exact mean over
        # the whole accumulation window.
        # NOTE(review): relies on `self.loss_func` being proxied to the learner's
        # loss function by `LearnerCallback` — confirm against the base class.
        if self.loss_func.reduction == "mean":
            print("For better gradients consider 'reduction=sum'")

    def on_epoch_begin(self, **kwargs):
        "Reset the sample and batch counters at the start of each epoch."
        self.acc_samples = 0
        self.acc_batches = 0

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Count the samples and batches seen in the current accumulation window."
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1
        print(f"At batch {self.acc_batches}")

    def on_backward_end(self, **kwargs):
        "Step once every `n_step` batches, normalizing grads by sample count."
        if (self.acc_batches % self.n_step) == 0:
            self._normalize_and_step()

    def on_epoch_end(self, **kwargs):
        "Flush any partially accumulated gradients at the end of the epoch."
        # BUG FIX: the original stepped unconditionally and without dividing by
        # `acc_samples`, so a leftover partial window stepped on unnormalized
        # grads, and a complete window re-stepped on already-zeroed grads
        # (which can still apply weight decay). Only flush a partial window.
        if (self.acc_batches % self.n_step) != 0:
            self._normalize_and_step()

    def _normalize_and_step(self):
        "Divide accumulated grads by the sample count, step, then zero grads."
        for p in self.learn.model.parameters():
            if p.requires_grad:
                p.grad.div_(self.acc_samples)
        print(f"Stepping at batch: {self.acc_batches}")
        self.learn.opt.real_step()
        self.learn.opt.real_zero_grad()
        self.acc_samples = 0
Thanks!