It seems Pytorch already has a builtin class for Group Normalization:
https://pytorch.org/docs/stable/nn.html#groupnorm
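For reference, it takes the number of groups and the number of channels, and since it normalizes within each sample it does not depend on the batch size at all; a tiny example:

import torch
from torch import nn

# GroupNorm with 4 groups over 64 channels; batch size 1 behaves the same as batch size 64
gn = nn.GroupNorm(num_groups=4, num_channels=64)
y = gn(torch.randn(1, 64, 8, 8))   # same call pattern as nn.BatchNorm2d(64)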
I've made a first try: a wrapper around a BatchNorm module that aggregates results over a few batches and only updates its running mean/var when we tell it to (using the mean and var across all the batches it saw). It's not at all optimized, and pretrained models would need to be converted to use it, but we can see if it fixes the problem first.
Here is the class:
import torch
from torch import nn

class AccumulateBatchNorm(nn.Module):
    def __init__(self, bn_class, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.track_running_stats,self.momentum = track_running_stats,momentum
        self.bn = bn_class(num_features, eps=eps, momentum=momentum, affine=affine,
                           track_running_stats=track_running_stats)
        self.running_mean,self.running_square,self.iterations = None,None,None

    def reset_running_stats(self):
        self.running_mean,self.running_square,self.iterations = None,None,None
        self.bn.reset_running_stats()

    def update_stats(self):
        if self.training and self.track_running_stats:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            # the accumulated sums are divided by the total number of samples seen
            self.bn.running_mean = self.bn.running_mean * (1-eaf) + self.running_mean * eaf / self.iterations
            var = self.running_square/self.iterations - (self.running_mean/self.iterations).pow(2)
            self.bn.running_var = self.bn.running_var * (1-eaf) + var * eaf
        self.running_mean,self.running_square,self.iterations = None,None,None

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.track_running_stats:
            if self.iterations is None:
                self.running_mean = self.bn.weight.new_zeros(self.num_features)
                self.running_square = self.bn.weight.new_zeros(self.num_features)
                self.iterations = 0
            # accumulate per-channel sums of the per-sample means and squared means
            self.running_mean += input.view(input.size(0), input.size(1), -1).mean(2).sum(0)
            self.running_square += input.view(input.size(0), input.size(1), -1).pow(2).mean(2).sum(0)
            self.iterations += input.size(0)
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var,
                                False, 0., self.bn.eps, torch.backends.cudnn.enabled)
The use is to pass the BatchNorm class we want and then the regular params. For instance:
tst = AccumulateBatchNorm(nn.BatchNorm2d, 64)
Then we pass batches through it like a regular module, until we want to update the stats, which is done with tst.update_stats() (the other functions are just there to make it compatible with other BatchNorm modules):
y = tst(x1)
y = tst(x2)
y = tst(x3)
tst.update_stats()
The result is an update as if we had passed the single batch torch.cat([x1,x2,x3], 0).
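A quick way to sanity-check that claim (my own untested sketch, not part of the original post) is to compare the resulting running stats with a plain BatchNorm2d that sees the concatenated batch in one go; PyTorch's running_var uses the unbiased variance estimate, so expect a tiny difference there:

import torch
from torch import nn

x1, x2, x3 = [torch.randn(2, 64, 8, 8) for _ in range(3)]

tst = AccumulateBatchNorm(nn.BatchNorm2d, 64)
for x in (x1, x2, x3): tst(x)
tst.update_stats()

ref = nn.BatchNorm2d(64)          # same default momentum (0.1)
ref(torch.cat([x1, x2, x3], 0))   # one big batch, updates ref's running stats

# Differences should be ~0 for the mean and very small for the var
# (the wrapper accumulates the biased variance, BatchNorm's buffer the unbiased one).
print((tst.bn.running_mean - ref.running_mean).abs().max())
print((tst.bn.running_var  - ref.running_var ).abs().max())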
So @kcturgutlu, the Callback would need to be updated to scan through the model and call this function on every AccumulateBatchNorm layer when you do a step. If you can experiment to check it yields the same results (a model with AccumulateBatchNorm and gradient accumulation vs a model with BatchNorm and no gradient accumulation), we can go on to the next step and try to optimize that thing.
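The scan itself can stay tiny, something like this helper (the name is just illustrative, assuming the callback can reach learn.model):

def update_accumulated_bn_stats(model):
    # walk every submodule and flush the accumulated stats into the wrapped BN buffers
    for m in model.modules():
        if isinstance(m, AccumulateBatchNorm):
            m.update_stats()

The callback would then call update_accumulated_bn_stats(learn.model) at each real optimizer step.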
Thanks! I will experiment with this. I can confirm that the problem is indeed batchnorm: vgg without BN gives identical results, with a diff of ±0.01. I believe this diff probably comes from floating point operations like summing many batches.
If, along the way, you come across some minimal example that gives worse results due to BN, we can experiment on it to see how best to fix it (a change in momentum, or some other way BN works).
Sure, this is the notebook with BN that I am currently working on: https://github.com/KeremTurgutlu/whale/blob/master/experimental/Every%20nth%20Batch%20Accumulate%20Update%20-%20BatchNorm.ipynb
In the more momentum section, note that PyTorch uses the opposite convention from the momentum in SGD. More momentum means setting the BN momentum to 0.05 or 0.01 (i.e. 0.95 or 0.99 in SGD conventions). The default is 0.1 (i.e. 0.9 in SGD conventions).
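Concretely, the update PyTorch applies to its running stats is running = (1 - momentum) * running + momentum * batch_stat, so for example:

momentum = 0.01                    # "more momentum" in the SGD sense (i.e. 0.99)
running_mean, batch_mean = 0.0, 1.0
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(running_mean)                # 0.01: the running estimate moves only slowly toward the batch stat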
I'm still trying to understand how y'all made the connection between momentum and the proper functioning of batchnorm when accumulating gradients; I have much to learn.
Upon digging deeper into https://github.com/pudae/kaggle-hpa, I found this on batchnorm and momentum in https://github.com/pudae/kaggle-hpa/blob/master/utils/swa.py. Maybe it will be helpful?
def moving_average(net1, net2, alpha=1):
    for param1, param2 in zip(net1.parameters(), net2.parameters()):
        param1.data *= (1.0 - alpha)
        param1.data += param2.data * alpha


def _check_bn(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def check_bn(model):
    flag = [False]
    model.apply(lambda module: _check_bn(module, flag))
    return flag[0]


def reset_bn(module):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.zeros_like(module.running_var)


def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]


def bn_update(loader, model):
    """
    BatchNorm buffers update (if any).
    Performs 1 epochs to estimate buffers average using train dataset.
    :param loader: train dataset loader for buffers average estimation.
    :param model: model being update
    :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    for input_dict in tqdm.tqdm(loader):
        input = input_dict['image'].cuda(async=True)
        input_var = torch.autograd.Variable(input)
        b = input_var.data.size(0)

        momentum = b / (n + b)
        for module in momenta.keys():
            module.momentum = momentum

        model(input_var)
        n += b

    model.apply(lambda module: _set_momenta(module, momenta))
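If I read it right, the idea is to recompute the BN buffers with one full pass over the training data instead of relying on the per-batch running averages. Usage would be roughly (my sketch, with assumed names for the loader and model):

# After training with gradient accumulation, re-estimate the BN buffers in one pass.
# train_loader is assumed to yield dicts with an 'image' key, as in the code above.
bn_update(train_loader, model)
model.eval()   # then evaluate with the freshly estimated buffers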
Any ideas on how to wrap all the BN layers of the model within the callback? We would need to reconstruct the model by converting all BN layers to AccumulateBatchNorm.
Once we know if this works I'll deal with that part. For now, the idea is just to wrap existing BatchNorm layers in the module above. I think you can easily tweak it to take a BN layer instead of a class.
The 1st-place winner of the same competition (@bestfitting) tried to use gradient accumulation but it did not work for him…
It seems he did not care about modifying BN when he used grad. accum., and this is why it did not work for him, while it worked well for the 3rd-place winner (@pudae), who was aware of this issue.
I think this BN issue is a very important factor in using gradient accumulation properly, and it is almost always ignored since it is tricky…
With the efforts of the great fastai developers and community, we have a big chance to gain a leading edge with this library, especially since increasing image size can give a boost in accuracy even with the usual GPU memory size of 10-12GB.
I am interested to test your modified BN class too. I will try it on the pets notebook that showed a difference in accuracy when I used grad. accum…
Can you elaborate more on how to do it?
I thought the same way as @kcturgutlu, i.e. reconstructing the model (changing each BN to call your modified class), even for just our current test…
It seems you have a better idea.
You may hard-code the change for each BN layer after constructing learn. Take vgg16_bn for example. I tried but got a backward error; I will look again.
I didn't get that error, but the error rate is worse with bs=2 and 16 steps than with bs=2 without accumulation, let alone bs=32… I will try to find where the issue is…
Here is how I changed the BN layers:
for i, g in enumerate(learn.layer_groups):
    for k, l in enumerate(g):
        if isinstance(l, nn.modules.batchnorm.BatchNorm1d):
            learn.layer_groups[i][k] = AccumulateBatchNorm(nn.modules.batchnorm.BatchNorm1d,
                num_features=l.num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        if isinstance(l, nn.modules.batchnorm.BatchNorm2d):
            learn.layer_groups[i][k] = AccumulateBatchNorm(nn.modules.batchnorm.BatchNorm2d,
                num_features=l.num_features, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
I can see that the change has happened when I check the model via learn.layer_groups:
learn.layer_groups
[Sequential(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): AccumulateBatchNorm(
     (bn): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
   )
   ...
But the change is not seen if I check it with learn.model:
learn.model
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        ...
Why are the layer groups not the same as those in learn.model?
What I meant was to change the previous class to
class AccumulateBatchNorm(nn.Module):
    def __init__(self, bn):
        super().__init__()
        self.bn,self.num_features = bn,bn.num_features
        self.track_running_stats,self.momentum = bn.track_running_stats,bn.momentum
        self.running_mean,self.running_square,self.iterations = None,None,None

    def reset_running_stats(self):
        self.running_mean,self.running_square,self.iterations = None,None,None
        self.bn.reset_running_stats()

    def update_stats(self):
        if self.training and self.track_running_stats:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            self.bn.running_mean = self.bn.running_mean * (1-eaf) + self.running_mean * eaf / self.iterations
            var = self.running_square/self.iterations - (self.running_mean/self.iterations).pow(2)
            self.bn.running_var = self.bn.running_var * (1-eaf) + var * eaf
        self.running_mean,self.running_square,self.iterations = None,None,None

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.track_running_stats:
            if self.iterations is None:
                self.running_mean = self.bn.weight.new_zeros(self.num_features)
                self.running_square = self.bn.weight.new_zeros(self.num_features)
                self.iterations = 0
            self.running_mean += input.view(input.size(0), input.size(1), -1).mean(2).sum(0)
            self.running_square += input.view(input.size(0), input.size(1), -1).pow(2).mean(2).sum(0)
            self.iterations += input.size(0)
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var,
                                False, 0., self.bn.eps, torch.backends.cudnn.enabled)
then replace existing batchnorm layers with AccumulateBatchNorm(bn).
You should do this in learn.model. learn.layer_groups is only there to put parameters in groups, and we don't introduce any new trainable parameter.
@kcturgutlu I am fiddling with your notebook… The function find_active_bn is missing… I suppose it is in data_utils, right? Can you upload it to github?
So here is how I wrapped all BN layers.
Is there a more efficient way to find all BN layers in a model other than multiple nested for loops?
def change_all_BN(module):
    for i in range(5):
        atr = 'bn'+str(i)
        if hasattr(module, atr):
            setattr(module, atr, AccumulateBatchNorm(getattr(module, atr)))

def wrap_BN(model):
    for i in range(len(model)):
        for j in range(len(model[i])):
            if isinstance(model[i][j], bn_types):
                model[i][j] = AccumulateBatchNorm(model[i][j])
            elif model[i][j].__class__.__name__ == "Sequential":
                for k in range(len(model[i][j])):
                    if isinstance(model[i][j][k], bn_types):
                        model[i][j][k] = AccumulateBatchNorm(model[i][j][k])
                    elif model[i][j][k].__class__.__name__ == "BasicBlock":
                        change_all_BN(model[i][j][k])
                        if hasattr(model[i][j][k], 'downsample'):
                            if model[i][j][k].downsample is not None:
                                for l in range(len(model[i][j][k].downsample)):
                                    if isinstance(model[i][j][k].downsample[l], bn_types):
                                        model[i][j][k].downsample[l] = AccumulateBatchNorm(model[i][j][k].downsample[l])
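(On the question above about avoiding the nested loops: a recursive sweep over named_children might do it. An untested sketch, with a function name of my own choosing:)

def wrap_bn_recursive(module):
    # recursively swap every BatchNorm (1d/2d/3d) child for the AccumulateBatchNorm wrapper
    for name, child in module.named_children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm):
            setattr(module, name, AccumulateBatchNorm(child))
        else:
            wrap_bn_recursive(child)

wrap_bn_recursive(learn.model) would then cover BasicBlock, downsample and any other nesting automatically.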
So with the BN layers wrapped, GPU memory keeps increasing until an out-of-memory error. I suspect that the wrapped BN tensors are kept and accumulated across all forward-pass iterations on the GPU.
I will look further into it later.
Here is my code:
pat = re.compile(r'/([^/]+)_\d+.jpg$')
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=BS
                                   ).normalize(imagenet_stats)

def get_learner():
    turn_on_accumulation()
    learn = create_cnn(data=data, arch=models.resnet34, metrics=error_rate,
                       callback_fns=[partial(AccumulateStep, n_step=N_STEP)])
    learn.loss_func = CrossEntropyFlat(reduction="sum")
    return learn

learn = get_learner()
wrap_BN(learn.model)
learn.fit_one_cycle(1)
You can use the new callback in master, which doesn't require turn_on_.... Also, I tried several things to make it work but it's either a backward error or a pickle error. I am updating the repo; the link is already shared here so I'm not sharing it again.
I am also trying to understand the torch.batch_norm code, which only takes running_mean and running_var but not the batch mean and var. I guess that means the batch mean and var are calculated inside that function if training=True, and otherwise the running stats are used. But this is probably not what we want. We want to normalize each sample with the batch stats, but since we don't know the upcoming samples during accumulation and can't temporarily hold them due to memory issues, there are a couple of possible solutions: either run 2x each epoch to calculate batch stats, as is done in the Kaggle solution shared here, or accumulate batch stats (batch mean/var) to use for the current batch (an approximation that gets better as we come closer to the end of the effective batch). I tried implementing the latter in my repo but there are some issues to be fixed. I appreciate everyone's help.
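To check that reading of torch.batch_norm, here is a throwaway sketch of mine (same positional argument order as in the class above):

import torch

x = torch.randn(4, 3, 8, 8) * 5 + 2
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# training=False: normalize with the running stats passed in (here mean 0, var 1, so output ~ x)
y_eval = torch.batch_norm(x, weight, bias, running_mean, running_var,
                          False, 0.1, 1e-5, torch.backends.cudnn.enabled)
# training=True: normalize with the current batch stats (and update the passed buffers in place)
y_train = torch.batch_norm(x, weight, bias, running_mean.clone(), running_var.clone(),
                           True, 0.1, 1e-5, torch.backends.cudnn.enabled)
print(y_eval.mean().item(), y_eval.std().item())    # roughly the raw stats (~2, ~5)
print(y_train.mean().item(), y_train.std().item())  # roughly 0, 1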
So I conducted some experiments as follows:
model = vgg16_bn (chosen since it's sequential, hence easy to manipulate)
data = MNIST_SAMPLE
Experiment Results
- No Accumulation: batch_size = 64, acc = 0.94
- Naive Accumulation: effective_batch_size = 64, step = 32, bs = 2, acc = 0.49
- Accumulation + BnFreeze: effective_batch_size = 64, step = 32, bs = 2, acc = 0.59
- Increase BN Momentum (on current batch stat): effective_batch_size = 64, step = 32, bs = 2, acc = 0.49
- Decrease BN Momentum (on current batch stat): effective_batch_size = 64, step = 32, bs = 2, acc = 0.57
- Replace BN with Instance Norm: effective_batch_size = 64, step = 32, bs = 2, acc = 0.57
- Replace BN with Group Norm (num_groups=4): effective_batch_size = 64, step = 32, bs = 2, acc = 0.98
- ResNet18 + Replace BN with Group Norm (num_groups=4): effective_batch_size = 64, step = 32, bs = 2, acc = 0.99
- ResNet18 + Replace BN with Group Norm, no Accumulation (num_groups=4): bs = 2, acc = 0.99
GroupNorm seems to work pretty well. But we can't be sure without trying it out on resnet variants and on different datasets.
notebook : https://github.com/KeremTurgutlu/experimental/blob/master/Accumulating_Batchnorm.ipynb
Here is the group norm paper: https://arxiv.org/abs/1803.08494
@hwasiti maybe you can do the same experiments, this time converting all BN layers to GroupNorm like I did in my notebook.
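For reference, one rough way that conversion might be scripted (my sketch, not necessarily what the notebook does; num_groups has to divide every layer's channel count, and the GroupNorm starts with fresh affine parameters):

def bn_to_gn(module, num_groups=4):
    # recursively replace every BatchNorm layer with a GroupNorm over its channels
    for name, child in module.named_children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            bn_to_gn(child, num_groups)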
Edit: Please note that v1.0.47 will have a breaking change that affects this callback (see the announcement in the developer chat). To skip the step and the grad zeroing, just return:
return {'skip_step': True, 'skip_zero': True}
wherever is more convenient (probably in on_backward_end).
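For example, an accumulation callback adapted to that convention might look roughly like this (my sketch; class and argument names are illustrative, not necessarily what landed in master):

class AccumulateStep(LearnerCallback):
    "Accumulate gradients for n_step batches before stepping (sketch for fastai >= 1.0.47)."
    def __init__(self, learn, n_step=2):
        super().__init__(learn)
        self.n_step = n_step

    def on_epoch_begin(self, **kwargs):
        self.count = 0

    def on_backward_end(self, **kwargs):
        self.count += 1
        if self.count % self.n_step != 0:
            # keep accumulating: skip both the optimizer step and the gradient zeroing
            return {'skip_step': True, 'skip_zero': True}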