Accumulating Gradients

It seems PyTorch already has a built-in class for Group Normalization:
https://pytorch.org/docs/stable/nn.html#groupnorm

1 Like

I've made a first attempt at a BatchNorm module that can aggregate results over a few batches and only updates its running mean/var when we tell it to (with the mean and var across all the batches it saw). It's not at all optimized, and pretrained models would need to be converted to use it if this works, but we can see if it fixes the problem first.

Here is the class:

import torch
from torch import nn

class AccumulateBatchNorm(nn.Module):
    
    def __init__(self, bn_class, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.track_running_stats,self.momentum = track_running_stats,momentum
        self.bn = bn_class(num_features, eps=eps, momentum=momentum, affine=affine,
                 track_running_stats=track_running_stats)
        self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_running_stats(self):
        self.running_mean,self.running_square,self.iterations = None,None,None
        self.bn.reset_running_stats()
    
    def update_stats(self):
        if self.training and self.track_running_stats:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            self.bn.running_mean = self.bn.running_mean * (1-eaf) + self.running_mean * eaf / self.iterations
            var = self.running_square/self.iterations - (self.running_mean/self.iterations).pow(2)
            self.bn.running_var  = self.bn.running_var  * (1-eaf) + var  * eaf
            self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_parameters(self):
        self.bn.reset_parameters()
        
    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.track_running_stats:
            if self.iterations is None:
                self.running_mean   = self.bn.weight.new_zeros(self.num_features)
                self.running_square = self.bn.weight.new_zeros(self.num_features)
                self.iterations   = 0
            # Accumulate per-channel sums of the per-sample means and mean squares
            self.running_mean += input.view(input.size(0), input.size(1), -1).mean(2).sum(0)
            self.running_square += input.view(input.size(0), input.size(1), -1).pow(2).mean(2).sum(0)
            self.iterations += input.size(0)
        # Normalize with the stored running stats (training=False, momentum=0.)
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var, 
            False, 0., self.bn.eps, torch.backends.cudnn.enabled)

The usage is to pass the BatchNorm class we want and then the regular parameters. For instance:

tst = AccumulateBatchNorm(nn.BatchNorm2d, 64)

Then we can pass the batches through it like a regular module, until we want to update the stats, which is done with tst.update_stats() (the other functions are just there to make this compatible with other BatchNorm modules):

y = tst(x1)
y = tst(x2)
y = tst(x3)
tst.update_stats()

The result is an update as if we had passed the batch torch.cat([x1,x2,x3],0).
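
A quick way to sanity-check that claim (just a rough sketch, using the class above; note the running variances may differ slightly since BatchNorm's running update uses the unbiased variance):

x1, x2, x3 = [torch.randn(4, 64, 8, 8) for _ in range(3)]

acc = AccumulateBatchNorm(nn.BatchNorm2d, 64).train()
for x in (x1, x2, x3): acc(x)
acc.update_stats()

ref = nn.BatchNorm2d(64).train()
ref(torch.cat([x1, x2, x3], 0))

# The running means should agree up to floating point error
print(torch.allclose(acc.bn.running_mean, ref.running_mean, atol=1e-5))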

So @kcturgutlu, the Callback would need to be updated to scan through the model and call this function on every AccumulateBatchNorm layer when you do a step. If you can experiment to check it yields the same results (when using a model with AccumulateBatchNorm with gradient accumulation vs a model with BatchNorm and no gradient accumulation), we can go on to the next step and try to optimize that thing.
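
As a rough illustration of that scan (just a sketch, not the final callback code), the step could do something like:

def update_accumulated_bn(model):
    # Flush the accumulated stats of every wrapped BN layer
    for module in model.modules():
        if isinstance(module, AccumulateBatchNorm):
            module.update_stats()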

6 Likes

Thanks! I will experiment with this. I can confirm that the problem is indeed batchnorm: vgg without BN gives identical results, with a diff of ±0.01. I believe this diff probably comes from floating point operations like summing many batches.

2 Likes

If, along the way, you come across some minimal example that gives worse results due to BN, we can experiment on it to see how best to fix it (a change in momentum, or some other way BN works).

Sure, this is the with-BN notebook I am currently working on: https://github.com/KeremTurgutlu/whale/blob/master/experimental/Every%20nth%20Batch%20Accumulate%20Update%20-%20BatchNorm.ipynb

In the more-momentum section, note that PyTorch uses the opposite convention from the momentum in SGD. More momentum means setting the momentum in BN to 0.05 or 0.01 (equivalent to 0.95 or 0.99 in SGD conventions). The default is 0.1 (0.9 in SGD conventions).
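
In other words, the update rule (as described in the PyTorch docs) is:

# `momentum` weights the NEW batch statistic, the opposite of the SGD convention
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var  = (1 - momentum) * running_var  + momentum * batch_var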

2 Likes

I'm still trying to understand how y'all made the connection between momentum and the proper functioning of batchnorm when accumulating gradients; I have much to learn.

Upon digging deeper into https://github.com/pudae/kaggle-hpa I found this on batchnorm and momentum in https://github.com/pudae/kaggle-hpa/blob/master/utils/swa.py
Maybe it will be helpful?

def moving_average(net1, net2, alpha=1):
  for param1, param2 in zip(net1.parameters(), net2.parameters()):
    param1.data *= (1.0 - alpha)
    param1.data += param2.data * alpha


def _check_bn(module, flag):
  if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
    flag[0] = True


def check_bn(model):
  flag = [False]
  model.apply(lambda module: _check_bn(module, flag))
  return flag[0]


def reset_bn(module):
  if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.zeros_like(module.running_var)


def _get_momenta(module, momenta):
  if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
    momenta[module] = module.momentum


def _set_momenta(module, momenta):
  if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
    module.momentum = momenta[module]


def bn_update(loader, model):
    """
        BatchNorm buffers update (if any).
        Performs 1 epoch to estimate buffers average using train dataset.
        :param loader: train dataset loader for buffers average estimation.
        :param model: model being updated
        :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    for input_dict in tqdm.tqdm(loader):
        input = input_dict['image'].cuda(non_blocking=True)  # `async` is a reserved word in Python 3.7+
        input_var = torch.autograd.Variable(input)
        b = input_var.data.size(0)

        momentum = b / (n + b)
        for module in momenta.keys():
            module.momentum = momentum

        model(input_var)
        n += b

    model.apply(lambda module: _set_momenta(module, momenta))
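
If I understand it correctly, the trick in bn_update is the momentum = b / (n + b) schedule: it turns the exponential moving average into an exact cumulative average of the batch statistics over the whole loader. A quick sanity check of that identity (a rough sketch, not from the repo):

import torch

batch_means = [torch.tensor(1.0), torch.tensor(3.0), torch.tensor(5.0)]
running, n, b = torch.tensor(0.0), 0, 1
for m in batch_means:
    momentum = b / (n + b)
    running = (1 - momentum) * running + momentum * m
    n += b
print(running)  # approximately 3.0, the plain mean of the three batch stats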

Any ideas on how to wrap all the BN layers of the model within the callback? We would need to reconstruct the model by converting every BN to AccumulateBatchNorm.

When we know if this works, I'll deal with that part. For now, the idea is just to wrap existing BatchNorm layers in the module above. I think you can easily tweak it to take a BN layer and not a class.

The 1st-place winner of the same competition (@bestfitting) tried to use gradient accumulation, but it did not work for him…

It seems he did not modify BN when he used grad. accum., and this is why it did not work for him, whereas it worked well for the 3rd-place winner (@pudae), who was aware of this issue.

I think BN is a very important factor in using gradient accumulation properly, and it is almost always ignored since it is tricky…

With the efforts of the great fastai developers and community, we have a big chance to gain a leading edge using this library, especially since increasing the image size can give a boost in accuracy even with the usual GPU memory size of 10-12 GB.

3 Likes

I am interested in testing your modified BN class too. I will try it on the pets notebook that showed a difference in accuracy when I used grad. accum…

Can you elaborate more on how to do it?

I thought the same way as @kcturgutlu, by reconstructing the model (changing each BN to call your modified class), even for just our current test…

It seems you have a better idea :slight_smile:

You may hard-code the change for each BN layer after constructing learn. Take vgg16_bn for example. I tried but got a backward error; I will look again.

1 Like

I didn't get that error, but the error rate is worse when bs=2 with 16 steps in comparison to bs=2 without accumulation, let alone bs=32… I will try to find where the issue is…

Here is how I changed the BN layers:

for i, g in enumerate(learn.layer_groups):
    for k, l in enumerate(g):
        if isinstance(l, nn.modules.batchnorm.BatchNorm1d): 
            learn.layer_groups[i][k]=AccumulateBatchNorm(nn.modules.batchnorm.BatchNorm1d,num_features=l.num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        if isinstance(l, nn.modules.batchnorm.BatchNorm2d): 
            learn.layer_groups[i][k]=AccumulateBatchNorm(nn.modules.batchnorm.BatchNorm2d,num_features=l.num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)

I can see that the change has happened by checking the model via learn.layer_groups:

learn.layer_groups

[Sequential(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): AccumulateBatchNorm(
     (bn): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
   )
.
.
.
.

But the change is not seen if I check it with learn.model:

learn.model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
.
.
.
.

Why are the layer groups not the same as in learn.model?

1 Like

What I meant was to change the previous class to

class AccumulateBatchNorm(nn.Module):
    
    def __init__(self, bn):
        super().__init__()
        self.bn,self.num_features = bn,bn.num_features
        self.track_running_stats,self.momentum = bn.track_running_stats,bn.momentum
        self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_running_stats(self):
        self.running_mean,self.running_square,self.iterations = None,None,None
        self.bn.reset_running_stats()
    
    def update_stats(self):
        if self.training and self.track_running_stats:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            self.bn.running_mean = self.bn.running_mean * (1-eaf) + self.running_mean * eaf / self.iterations
            var = self.running_square/self.iterations - (self.running_mean/self.iterations).pow(2)
            self.bn.running_var  = self.bn.running_var  * (1-eaf) + var  * eaf
            self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_parameters(self):
        self.bn.reset_parameters()
        
    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.track_running_stats:
            if self.iterations is None:
                self.running_mean   = self.bn.weight.new_zeros(self.num_features)
                self.running_square = self.bn.weight.new_zeros(self.num_features)
                self.iterations   = 0
            self.running_mean += input.view(input.size(0), input.size(1), -1).mean(2).sum(0)
            self.running_square += input.view(input.size(0), input.size(1), -1).pow(2).mean(2).sum(0)
            self.iterations += input.size(0)
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var, 
            False, 0., self.bn.eps, torch.backends.cudnn.enabled) 

and then replace the existing batchnorm layers with AccumulateBatchNorm(bn).

You should do this in learn.model. learn.layer_groups is only there to put parameters in groups, and we don't introduce any new trainable parameters.
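
For example, a recursive replacement over learn.model could look roughly like this (untested sketch):

from torch import nn

def wrap_all_bn(module):
    # Recursively swap every BatchNorm child for the wrapper, in place
    for name, child in module.named_children():
        if isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            setattr(module, name, AccumulateBatchNorm(child))
        else:
            wrap_all_bn(child)

wrap_all_bn(learn.model)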

3 Likes

@kcturgutlu
I am fiddling with your notebook…

The function find_active_bn is missing…
I suppose it is in data_utils, right?

Can you upload it to github?

So here is how I wrapped all BN layers.
Is there a more efficient way to find all BN layers in a model other than multiple nested for loops?

def change_all_BN(module):
    for i in range(5):
        atr = 'bn'+str(i)
        if hasattr(module, atr):
            setattr(module,atr,AccumulateBatchNorm(getattr(module,atr)))


def wrap_BN(model):
    for i in range(len(model)):
        for j in range(len(model[i])):
            if isinstance(model[i][j], bn_types):
                model[i][j] = AccumulateBatchNorm(model[i][j])
            elif model[i][j].__class__.__name__ == "Sequential":
                for k in range(len(model[i][j])):
                    if isinstance(model[i][j][k], bn_types):
                        model[i][j][k] = AccumulateBatchNorm(model[i][j][k])
                    elif model[i][j][k].__class__.__name__ == "BasicBlock":
                        change_all_BN(model[i][j][k])
                        if hasattr(model[i][j][k],'downsample'):
                            if model[i][j][k].downsample is not None:
                                for l in range(len(model[i][j][k].downsample)):
                                     if isinstance(model[i][j][k].downsample[l], bn_types):
                                        model[i][j][k].downsample[l] = AccumulateBatchNorm(model[i][j][k].downsample[l])
                               

So with the BN layers wrapped, GPU memory keeps increasing until an OOM error. I suspect that the wrapped BN statistics are kept and accumulated on the GPU for all forward-pass iterations (see the note after the code below).
I will look further into it later.

Here is my code:

pat = re.compile(r'/([^/]+)_\d+.jpg$')
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=BS
                                  ).normalize(imagenet_stats)
def get_learner():
    turn_on_accumulation()
    learn = create_cnn(data=data, arch=models.resnet34, metrics=error_rate,
                       callback_fns=[partial(AccumulateStep, n_step=N_STEP)])
    learn.loss_func = CrossEntropyFlat(reduction="sum")
    return learn

learn = get_learner() 
wrap_BN(learn.model)
learn.fit_one_cycle(1)
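
One possible cause of the memory growth (just a guess on my part): the accumulators in forward are built from input without detach(), so the autograd graph of every forward pass is kept alive until update_stats() runs. Detaching the batch statistics before accumulating them might fix it, e.g.:

# inside AccumulateBatchNorm.forward
flat = input.detach().view(input.size(0), input.size(1), -1)
self.running_mean   += flat.mean(2).sum(0)
self.running_square += flat.pow(2).mean(2).sum(0)
self.iterations     += input.size(0)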

You can use the new callback in master, which doesn't require turn_on_.... Also, I tried several things to make it work, but it's either a backward error or a pickle error. I am updating the repo; the link is already shared here so I'm not sharing it again.

I am also trying to understand the torch.batch_norm code, which only takes running_mean and running_var but not the batch mean and var. I guess that means the batch mean and var are calculated inside that function if training=True, and otherwise the running stats are used. But this is probably not something we want. We want to normalize each sample with the batch stats, but since we don't know the upcoming samples in accumulation and can't temporarily hold them due to memory issues, there might be a couple of solutions: either do 2x each epoch to calculate the batch stats, as is done in the Kaggle competition shared here, or accumulate the batch stats (batch mean and var) to be used for the current batch (this would be an approximation that gets better as we come closer to the end of the accumulated batch). I tried implementing the latter in my repo but there are some issues to be fixed. Appreciate everyone's help :slight_smile:
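
For reference, the same training flag is visible in the functional API (a small sketch):

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 4, 4)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# training=True: normalize with the current batch stats and update the running buffers in place
y_train = F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)

# training=False: normalize with the running buffers only; batch stats are ignored
y_eval = F.batch_norm(x, running_mean, running_var, training=False)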

2 Likes

So I conducted some experiments as following:

model = vgg16_bn (chosen since it's sequential - easy to manipulate)
data = MNIST_SAMPLE

Experiment Results

  1. No Accumulation
    batch_size = 64
    acc = 0.94

  2. Naive Accumulation
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.49

  3. Accumulation + BnFreeze
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.59

  4. Increase BN Momentum (on current batch stat)
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.49

  5. Decrease BN Momentum (on current batch stat)
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.57

  6. Replace BN with Instance Norm
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.57

  7. Replace BN with Group Norm
    num_groups=4
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.98

  8. ResNet18 + Replace BN with Group Norm
    num_groups=4
    effective_batch_size = 64
    step = 32
    bs = 2
    acc = 0.99

  9. ResNet18 + Replace BN with Group Norm no Accumulation
    num_groups=4
    bs = 2
    acc = 0.99

GroupNorm seems to work pretty well, but we can't be sure without trying it out on ResNet variants and on different datasets.

notebook : https://github.com/KeremTurgutlu/experimental/blob/master/Accumulating_Batchnorm.ipynb

Here is the group norm paper: https://arxiv.org/abs/1803.08494

@hwasiti maybe you can do the same experiments, this time converting all BN layers to GroupNorm like I did in my notebook.
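
In case it helps, a rough sketch of the BN-to-GroupNorm swap (not necessarily exactly what the notebook does):

from torch import nn

def bn_to_gn(module, num_groups=4):
    # Recursively replace each BatchNorm2d with a GroupNorm over the same number of channels
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            bn_to_gn(child, num_groups)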

3 Likes

Please note that v1.0.47 will have a breaking change that affects this callback (see announcement in the developer chat). To skip the step and the grad zeroing, just return:

return {'skip_step': True, 'skip_zero': True}

wherever it is most convenient (probably in on_backward_end).
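
For example, something along these lines (a rough sketch; the condition is hypothetical, not the actual implementation):

from fastai.basic_train import LearnerCallback

class AccumulateStep(LearnerCallback):
    def on_backward_end(self, **kwargs):
        # Only step/zero the grads once every n_step batches; otherwise skip both
        if not self.time_to_step:  # hypothetical flag tracking how many batches were accumulated
            return {'skip_step': True, 'skip_zero': True}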

1 Like