Validate Gradient Accumulation Implementation

Hi,
I have tried to implement gradient accumulation (every 2 batches) using the Callback class in callback04.ipynb. Please help me validate whether this is a correct implementation.

class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        self.total_loss = 0
        self.nv = len(learn.data.valid_dl)
        self.nt = len(learn.data.train_dl)
        return True

    def after_step(self):
        print('step', self.n_iters)
        if self.n_iters >= 10: self.learn.stop = True  # was `learn.stop`, which relied on a global
        return True

    def after_loss(self, loss):  # invoked from one_batch after the loss is computed
        super().after_loss(loss)
        self.n_iters += 1
        print('l', self.n_iters)
        if self.n_iters % 2 == 0:
            # every second batch: average the stashed loss with the current one
            # and let one_batch carry on to the optimizer step
            self.loss = (self.total_loss + self.loss) / 2
            print('train_loss', self.loss)
            self.total_loss = 0
            return True
        else:
            # odd batch: stash the loss and skip the rest of one_batch
            self.total_loss += self.loss
            return False
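For reference, this is the plain-loop version of what I am trying to reproduce, as I understand it (the names model, opt, loss_func and train_dl are just placeholders, not from the notebook):

accum = 2
opt.zero_grad()
for i, (xb, yb) in enumerate(train_dl):
    loss = loss_func(model(xb), yb) / accum   # scale so the accumulated grads average out
    loss.backward()                           # gradients add up across the accum batches
    if (i + 1) % accum == 0:
        opt.step()                            # update once every `accum` batches
        opt.zero_grad()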

What results are you getting?

Here is what I did. What I am not sure about is: inside one_batch, when I call opt.step(), will it use the gradients from the self.loss.backward() call in TestCallback?
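My understanding (a toy check, not from the notebook) is that .backward() only fills each parameter's .grad, and opt.step() just reads whatever is in .grad, no matter which function called backward():

import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def backward_elsewhere(loss):
    loss.backward()              # grads land in model.weight.grad / model.bias.grad

loss = model(torch.randn(4, 2)).pow(2).mean()
backward_elsewhere(loss)
print(model.weight.grad)         # populated even though backward() ran in another function
opt.step()                       # uses those same grads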
Here is an excerpt of the output. The step happens on even iterations.

train_loss tensor(0.4965, grad_fn=<AddBackward0>)
step 6
l 7
l 8
train_loss tensor(0.5814, grad_fn=<AddBackward0>)
step 8
l 9
l 10
train_loss tensor(0.3792, grad_fn=<AddBackward0>)
step 10
0 tensor(0.3318) tensor(0.9025)

def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb, yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return  # on odd batches the callback returns False, so we stop here
    # loss.backward() commented out here; the callback calls backward itself inside after_loss
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

I am doing the following in TestCallback: I implemented the after_loss function.

class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        self.total_loss = 0
        self.nv = len(learn.data.valid_dl)
        self.nt = len(learn.data.train_dl)
        return True

    def after_step(self):
        print('step', self.n_iters)
        if self.n_iters >= 10: self.learn.stop = True  # was `learn.stop`, which relied on a global
        return True

    def after_loss(self, loss):
        super().after_loss(loss)
        self.n_iters += 1
        print('l', self.n_iters)
        if self.n_iters % 2 == 0:
            # every second batch: combine the stashed previous loss with the current one
            # (0.9/0.1 weighting), backprop through the combined loss, and let one_batch
            # carry on to opt.step()
            self.loss = self.total_loss * 0.9 + self.loss * 0.1
            self.loss.backward()
            print('train_loss', self.loss)
            self.total_loss = 0
            return True
        else:
            # odd batch: stash the loss (its graph stays alive) and skip the rest of one_batch
            self.total_loss += self.loss
            return False
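For comparison, here is a sketch of what I believe is the more usual accumulation pattern with these hooks (a hypothetical GradAccumCallback, assuming the notebook's Callback base class whose hooks return True by default): call backward() on every batch with the loss scaled by 1/accum, and only let one_batch continue to opt.step()/zero_grad() every accum batches.

class GradAccumCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters, self.accum = 0, 2
        return True

    def after_loss(self, loss):
        super().after_loss(loss)
        self.n_iters += 1
        (loss / self.accum).backward()          # accumulate scaled grads on every batch
        # let one_batch reach opt.step()/zero_grad() only every self.accum batches
        return self.n_iters % self.accum == 0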

The fit function:

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)

        if cb.begin_validate():
            with torch.no_grad():
                tot_loss, tot_acc = 0., 0.
                for xb, yb in learn.data.valid_dl:
                    pred = learn.model(xb)
                    tot_loss += learn.loss_func(pred, yb)
                    tot_acc  += accuracy(pred, yb)
            nv = len(learn.data.valid_dl)
            print(epoch, tot_loss/nv, tot_acc/nv)  # moved inside the if, so tot_loss is always defined

        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()
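For completeness, the whole thing is then driven by something like this (assuming learn and accuracy are set up as earlier in the notebook):

fit(1, learn, TestCallback())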