Hi all
please see if this the right way to implement grad acc call back
from fastai.torch_core import *
from fastai.callbacks import *
from fastai.basic_train import *
from torch.autograd import Variable
class GradAccumulation(LearnerCallback):
def __init__(self,learn:Learner,num):
super().__init__(learn)
self.total_batch_num=num
print('total_batch_num',num)
def on_epoch_begin(self, **kwargs):
#super().__init__(learn)
self.n=0
self.total_loss=0
#self.loss=0
self.last_n=0
self.steps=2
def on_backward_begin(self,last_loss,**kwargs):
print(self.n)
if (self.n+1)%self.steps==0 :
#print('n',self.n)
print('batch',self.n)
self.total_loss+=last_loss
self.total_loss=self.total_loss/2
skip=False
else :
self.total_loss+=last_loss
if ( self.n==self.total_batch_num):
print('last batch')
skip=False
else :
skip=True
self.last_n=self.n
self.n+=1
#self.total_loss=last_loss/2
#skip=True
return {'last_loss':self.total_loss,'skip_bwd':skip}
def on_backward_end(self,**kwargs):
if (self.last_n+1)%self.steps==0:
self.total_loss=0
skip=False
else:
if ( self.last_n==self.total_batch_num):
print('last batch')
skip=False
else :
skip=True
return {'skip_bwd':skip}
def on_step_end(self,**kwargs):
if (self.last_n+1)%self.steps==0:
#self.total_loss=0
skip=False
else:
if ( self.last_n==self.total_batch_num):
print('last batch')
skip=False
else:
skip=True
return {'skip_bwd':skip}