Please validate the grad acc call back

Hi all
please see if this the right way to implement grad acc call back

from fastai.torch_core import *
from fastai.callbacks import *
from fastai.basic_train import *
from torch.autograd import Variable

class GradAccumulation(LearnerCallback):
    def __init__(self,learn:Learner,num):
        super().__init__(learn)
        self.total_batch_num=num
        print('total_batch_num',num)
    def on_epoch_begin(self, **kwargs):
        #super().__init__(learn)
        self.n=0
        self.total_loss=0
        #self.loss=0
        self.last_n=0
        self.steps=2
    def on_backward_begin(self,last_loss,**kwargs):
        print(self.n)
        if (self.n+1)%self.steps==0 :
            #print('n',self.n)
            print('batch',self.n)
            self.total_loss+=last_loss
            self.total_loss=self.total_loss/2
            skip=False
        else :
            
            self.total_loss+=last_loss
            if ( self.n==self.total_batch_num):
                print('last batch')
                skip=False
            else :
                
                skip=True
        
       
        self.last_n=self.n    
        self.n+=1
        #self.total_loss=last_loss/2
        #skip=True
        return {'last_loss':self.total_loss,'skip_bwd':skip}
    def on_backward_end(self,**kwargs):
        if (self.last_n+1)%self.steps==0:
            self.total_loss=0
            
            skip=False
        else:
            
            if ( self.last_n==self.total_batch_num):
                print('last batch')
                skip=False
            else :
                
                
                skip=True
            
        return {'skip_bwd':skip}
    
    def on_step_end(self,**kwargs):
        if (self.last_n+1)%self.steps==0:
            
            #self.total_loss=0
            skip=False
        else:
            
            
            if ( self.last_n==self.total_batch_num):
                print('last batch')
                skip=False
            else:
                
                
                skip=True
        return {'skip_bwd':skip}