Accumulating Gradients

I’ve made a first attempt at a BatchNorm module that aggregates statistics over several batches and only updates its running mean and variance when we tell it to (using the mean and variance across all the batches it has seen). It’s not optimized at all, and pretrained models would need to be converted to use it if it works, but first we can check whether it fixes the problem.
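Converting a pretrained model could look roughly like this minimal sketch. It walks the module tree and swaps each BatchNorm layer for whatever a factory builds from it. `swap_batchnorms` and `make_replacement` are hypothetical names, not part of PyTorch or fastai. The demo uses `nn.Identity` as a placeholder factory so that the snippet runs on its own.

```python
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def swap_batchnorms(module, make_replacement):
    """Hypothetical helper: recursively replace every BatchNorm child of `module`
    with make_replacement(old_bn). Modifies `module` in place and returns it."""
    # list() so we never iterate over the children dict while reassigning entries
    for name, child in list(module.named_children()):
        if isinstance(child, BN_TYPES):
            setattr(module, name, make_replacement(child))
        else:
            swap_batchnorms(child, make_replacement)
    return module

# Demo: nested model with two BatchNorm2d layers, replaced by a placeholder factory
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8),
                      nn.Sequential(nn.ReLU(), nn.BatchNorm2d(8)))
swap_batchnorms(model, lambda bn: nn.Identity())
```

With the class below, the factory would presumably build `AccumulateBatchNorm(type(bn), bn.num_features, bn.eps, bn.momentum, bn.affine, bn.track_running_stats)` and copy the pretrained weights and running stats with `new.bn.load_state_dict(bn.state_dict())`. Swapping after loading the pretrained weights avoids having to remap state-dict keys.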

Here is the class:

```python
import torch
import torch.nn as nn

class AccumulateBatchNorm(nn.Module):
    "Wraps a BatchNorm layer, accumulates stats over several batches, updates running stats only in `update_stats`."
    def __init__(self, bn_class, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.track_running_stats, self.momentum = track_running_stats, momentum
        self.bn = bn_class(num_features, eps=eps, momentum=momentum, affine=affine,
                           track_running_stats=track_running_stats)
        # Per-channel sum, sum of squares and element count since the last update_stats()
        self.running_mean, self.running_square, self.iterations = None, None, None

    def reset_running_stats(self):
        self.running_mean, self.running_square, self.iterations = None, None, None
        self.bn.reset_running_stats()

    @torch.no_grad()
    def update_stats(self):
        # Fold everything accumulated since the last call into the wrapped BN's buffers
        if self.training and self.track_running_stats and self.iterations:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            mean = self.running_mean / self.iterations
            var = self.running_square / self.iterations - mean.pow(2)
            var = var * self.iterations / max(self.iterations - 1, 1)  # unbiased, like nn.BatchNorm's running_var
            self.bn.running_mean = self.bn.running_mean * (1 - eaf) + mean * eaf
            self.bn.running_var  = self.bn.running_var  * (1 - eaf) + var  * eaf
            self.running_mean, self.running_square, self.iterations = None, None, None

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.training and self.track_running_stats:
            with torch.no_grad():  # the accumulators must not keep the autograd graph alive
                flat = input.transpose(0, 1).reshape(self.num_features, -1)  # (C, N * spatial)
                if self.iterations is None:
                    self.running_mean   = flat.new_zeros(self.num_features)  # works even if affine=False
                    self.running_square = flat.new_zeros(self.num_features)
                    self.iterations     = 0
                self.running_mean   += flat.sum(1)
                self.running_square += flat.pow(2).sum(1)
                self.iterations     += flat.size(1)
        # training=False: normalize with the wrapped BN's running stats, not the current batch's stats
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var,
                                False, 0., self.bn.eps, torch.backends.cudnn.enabled)
```

To use it, we pass the BatchNorm class we want, followed by the usual parameters. For instance:

```python
tst = AccumulateBatchNorm(nn.BatchNorm2d, 64)
```

Then we can pass batches through it like a regular module until we want to update the stats, which is done with `tst.update_stats()` (the other methods are only there to make it compatible with the other BatchNorm modules):

```python
y = tst(x1)
y = tst(x2)
y = tst(x3)
tst.update_stats()
```

The running stats are then updated as if we had passed the single batch `torch.cat([x1,x2,x3],0)`.
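One way to check the pooled statistics, independently of the module, is a minimal sketch like this. It accumulates per-channel sums, sums of squares and element counts batch by batch, turns them into a mean and variance, applies the BatchNorm momentum update, and compares the result with an `nn.BatchNorm2d` that saw `torch.cat(xs, 0)` in a single training-mode forward. `nn.BatchNorm` stores the *unbiased* variance (factor n/(n-1), with n the number of elements per channel) in `running_var`, so the pooled variance needs that correction to match exactly. `pooled_stats` is a hypothetical helper. Float64 is used only to keep the comparison tight.

```python
import torch
import torch.nn as nn

def pooled_stats(batches):
    """Per-channel mean and unbiased variance over all batches, from running sums only."""
    total, sq, count = 0.0, 0.0, 0
    for x in batches:
        flat = x.transpose(0, 1).reshape(x.size(1), -1)  # (C, N*H*W)
        total = total + flat.sum(1)
        sq = sq + flat.pow(2).sum(1)
        count += flat.size(1)
    mean = total / count
    var = (sq / count - mean.pow(2)) * count / (count - 1)  # unbiased, as nn.BatchNorm stores it
    return mean, var

torch.manual_seed(0)
xs = [torch.randn(4, 8, 5, 5, dtype=torch.float64) for _ in range(3)]

# Reference: one BatchNorm2d forward on the concatenated batch, in training mode
ref = nn.BatchNorm2d(8, momentum=0.1).double().train()
with torch.no_grad():
    ref(torch.cat(xs, 0))

# Accumulated version: BatchNorm's initial buffers (zeros / ones) and momentum update,
# fed with the pooled stats
momentum = 0.1
mean, var = pooled_stats(xs)
run_mean = torch.zeros(8, dtype=torch.float64) * (1 - momentum) + mean * momentum
run_var = torch.ones(8, dtype=torch.float64) * (1 - momentum) + var * momentum

print(torch.allclose(run_mean, ref.running_mean), torch.allclose(run_var, ref.running_var))  # True True
```

This only covers a single update. With several `update_stats()` calls, the exponential moving average then proceeds exactly as in a regular BatchNorm fed with the concatenated batches.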

So @kcturgutlu, the Callback would need to be updated to scan through the model and call this function on every AccumulateBatchNorm layer when you do a step. If you can run an experiment to check that it yields the same results (a model with AccumulateBatchNorm and gradient accumulation vs. a model with regular BatchNorm and no gradient accumulation), we can move on to the next step and try to optimize it.
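The callback logic can be sketched in plain PyTorch. This is not the fastai Callback API, just an assumed standalone training loop. It accumulates gradients over `accum_steps` batches, and after each optimizer step it calls `update_stats()` on every layer that has it. `update_accumulated_stats`, `train_with_accumulation` and `CountingLayer` are hypothetical names. The layers are found by duck typing rather than an `isinstance` check, so the sketch runs without the class above.

```python
import torch
import torch.nn as nn

def update_accumulated_stats(model):
    """Call update_stats() on every submodule that defines it (e.g. AccumulateBatchNorm)."""
    for m in model.modules():
        if hasattr(m, "update_stats"):
            m.update_stats()

def train_with_accumulation(model, batches, loss_fn, opt, accum_steps):
    """Plain PyTorch gradient-accumulation loop: one optimizer step every `accum_steps` batches.
    A trailing group shorter than `accum_steps` is left without a step."""
    model.train()
    opt.zero_grad()
    for i, (x, y) in enumerate(batches, 1):
        # Dividing by accum_steps makes the summed gradients equal the gradient of the
        # mean loss over the whole group (for equal-size batches and a mean-reduced loss)
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()
        if i % accum_steps == 0:
            opt.step()
            opt.zero_grad()
            update_accumulated_stats(model)  # BN running stats move once per optimizer step

class CountingLayer(nn.Module):
    """Hypothetical stand-in for AccumulateBatchNorm that only counts update_stats() calls."""
    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, x):
        return x

    def update_stats(self):
        self.calls += 1

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), CountingLayer(), nn.Linear(4, 1), CountingLayer())
opt = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(6)]
train_with_accumulation(model, batches, nn.MSELoss(), opt, accum_steps=3)
print([m.calls for m in model if isinstance(m, CountingLayer)])  # [2, 2]
```

Calling the update right after `opt.step()` keeps the running stats in sync with the weights. Each update then reflects exactly the batches that contributed to that step's gradient.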
