Running batch norm tweaks

Continuing a discussion from the Lesson 11 thread… The goal is to speed up running batchnorm (rbn) by only recalculating statistics occasionally. Something like this (although this isn’t working):

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0
        self.iter  = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1
        
        if self.batch > 1000 and self.iter%2: return
        
        self.sums.detach_()
        self.sqrs.detach_()

        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)

        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(self.batch-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.means = self.sums/self.count
        self.varns = (self.sqrs/self.count).sub_(self.means.pow(2))
        if self.batch < 20: self.varns.clamp_min_(0.01)

    def forward(self, x):
        if self.training: self.update_stats(x)
        factor = self.mults / (self.varns+self.eps).sqrt()
        offset = self.adds - self.means*factor
        return x*factor + offset
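For reference, here is what update_stats and forward above compute, written per channel c. This is my own transcription of the code, ignoring the early clamp; m_t stands for mom1, gamma for mults, beta for adds, and it uses the fact that a.lerp_(b, m) sets a to (1-m)·a + m·b:

```latex
\begin{aligned}
S_t &= (1-m_t)\,S_{t-1} + m_t \sum_{b,h,w} x_{b,c,h,w}, \\
Q_t &= (1-m_t)\,Q_{t-1} + m_t \sum_{b,h,w} x_{b,c,h,w}^2, \\
C_t &= (1-m_t)\,C_{t-1} + m_t\,\frac{\operatorname{numel}(x)}{n_c}, \\
\mu_t &= \frac{S_t}{C_t}, \qquad \sigma_t^2 = \frac{Q_t}{C_t} - \mu_t^2, \\
y &= \gamma\,\frac{x - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}} + \beta
   = x \cdot \underbrace{\frac{\gamma}{\sqrt{\sigma_t^2 + \epsilon}}}_{\texttt{factor}}
   + \underbrace{\left(\beta - \mu_t \cdot \texttt{factor}\right)}_{\texttt{offset}}
\end{aligned}
```

Because S_t and Q_t are built from the current x, they (and therefore mu_t and sigma_t) carry autograd history back into the network unless something detaches them.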

That’s easy:

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0
        self.iter  = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1
        
        if self.batch > 1000 and not self.iter%9: return
        
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)

        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(self.batch-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.means = self.sums/self.count
        self.varns = (self.sqrs/self.count).sub_(self.means.pow(2))
        if self.batch < 20: self.varns.clamp_min_(0.01)

        self.means.detach_() # added!
        self.varns.detach_() # added!
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        if self.iter < 2:
            l = "mults adds sums sqrs count means varns".split()
            print("not leaf ", list(filter(lambda x: not getattr(self,x).is_leaf, l)))
            print("want grad", list(filter(lambda x: getattr(self,x).requires_grad, l)))
        factor = self.mults / (self.varns+self.eps).sqrt()
        offset = self.adds - self.means*factor
        return x*factor + offset
    

So it runs, but skipping leads to loss == nan, even if I change it to skip only every 9th iteration instead of every 2nd.

Note that if you don’t skip, it works fine (i.e. the detach calls are not the cause).
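To make the two skip conditions concrete, here is a small pure-Python sketch that replays the batch and iter counters from update_stats and lists which iterations return early. It assumes a constant batch size of 32 and 40 iterations, both arbitrary choices for illustration; skipped_iters is a hypothetical helper, not part of the thread’s code.

```python
def skipped_iters(skip, bs=32, n_iters=40):
    """Return the iterations on which update_stats would return early.

    Mirrors the counters in the code above: batch accumulates bs each call,
    iter counts calls, and skipping only starts once batch > 1000.
    """
    batch, skipped = 0, []
    for it in range(1, n_iters + 1):
        batch += bs
        if batch > 1000 and skip(it):
            skipped.append(it)
    return skipped

first_version = skipped_iters(lambda it: it % 2)        # `self.iter%2`: skip odd iterations
second_version = skipped_iters(lambda it: not it % 9)   # `not self.iter%9`: skip every 9th
print(first_version)   # [33, 35, 37, 39]
print(second_version)  # [36]
```

So after warm-up the first version skips half of all updates, while the second skips only one in nine; both still produced NaNs above.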

Right, but it’s not easy, because it doesn’t actually work: it doesn’t give an error, but, as you point out, it doesn’t train usefully. (That’s the same behavior as the code I showed at the top.)

So the challenge here isn’t really a coding one, but a conceptual DL one - how do we do a batch norm that uses running statistics during training, and doesn’t calculate them more than mathematically necessary.

Yes that’s exactly what I saw. Accuracy with bs=32 went from 97%->90% IIRC.

My intent was for the next iteration of calcs to be part of the graph. I did it this way to avoid a continuously growing history in the graph, which, as far as I can tell, is working.

But it’s certainly true that I don’t fully understand the details of all this, and I’m sure there are things that need to be fixed to make this work with the “occasional skipping”.

Gosh that sounds odd. Would love to learn more about this!

Involving which variables specifically? Or all of them?

So RunningBatchNorm relies on calculating weighted averages and also letting the network further tweak them via backprop, i.e. that’s by design?

And most likely if you remove these from the beginning of update_stats:

        self.sums.detach_()
        self.sqrs.detach_()

the outcome would be exactly the same, since, as I mentioned in the long reply, they get re-attached as soon as you calculate them (because x is on the graph), so the code above doesn’t contribute anything.
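A minimal sketch of the re-attachment point, using stand-in tensors rather than the actual module (the names mirror the buffers above; the shapes are arbitrary assumptions). The same snippet shows that count, which is only ever lerped toward a plain constant, never joins the graph.

```python
import torch

# Stand-ins for the module's state (shapes are arbitrary assumptions).
x = torch.randn(8, 3, 4, 4, requires_grad=True)  # an activation that is on the graph
sums = torch.zeros(1, 3, 1, 1)                    # like the 'sums' buffer
count = torch.tensor(0.)                          # like the 'count' buffer
mom1 = 0.1

sums.detach_()                                    # truncate history (there is none yet)
s = x.sum((0, 2, 3), keepdim=True)                # depends on x, so it has a grad_fn
c = s.new_tensor(x.numel() / x.shape[1])          # a plain constant, no history

sums.lerp_(s, mom1)                               # in-place update pulls sums onto the graph
count.lerp_(c, mom1)                              # count only ever sees constants

print(sums.requires_grad, sums.grad_fn is not None)  # True True
print(count.requires_grad, count.grad_fn)            # False None

sums.detach_()                                    # cut the history again...
print(sums.grad_fn)                               # None
sums.lerp_(x.sum((0, 2, 3), keepdim=True), mom1)  # ...and the next update re-attaches it
print(sums.grad_fn is not None)                   # True
```

So detach_() at the top of update_stats only trims what came before; the fresh lerp_ immediately hangs sums off the current batch’s graph again.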

But it’s certainly true that I don’t fully understand the details of all this, and I’m sure there are things that need to be fixed to make this work with the “occasional skipping”.

I think once you sort out which vars should be learnable and which shouldn’t, you could force the optional calculations not to involve anything on the graph, and then there will be no problem skipping them at will. Does that make sense?

I just don’t know which of them you meant to be learnable, so I can’t code for it until that’s known.
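One way to read this suggestion, sketched under my own assumptions: keep the running statistics entirely off the graph by updating them inside torch.no_grad(), so that skipping an update can never leave a stale node from an earlier batch behind. The class name NoGradRunningBN, the every parameter, and the plain fixed-momentum update are hypothetical simplifications. As noted later in the thread, removing these gradients cost accuracy in practice, so this illustrates the graph mechanics rather than a fix.

```python
import torch
from torch import nn

class NoGradRunningBN(nn.Module):
    """Hypothetical sketch: running statistics kept off the autograd graph entirely."""
    def __init__(self, nf, mom=0.1, eps=1e-5, every=2):
        super().__init__()
        self.mom, self.eps, self.every = mom, eps, every
        self.mults = nn.Parameter(torch.ones(nf, 1, 1))
        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))
        self.register_buffer('means', torch.zeros(nf, 1, 1))
        self.register_buffer('varns', torch.ones(nf, 1, 1))
        self.iter = 0

    @torch.no_grad()
    def update_stats(self, x):
        # Nothing computed here gets a grad_fn, so skipping an update never
        # leaves a stale node from an earlier batch in the graph.
        dims = (0, 2, 3)
        m = x.mean(dims).view(-1, 1, 1)
        v = ((x * x).mean(dims).view(-1, 1, 1) - m * m).clamp_min(0)
        self.means.lerp_(m, self.mom)
        self.varns.lerp_(v, self.mom)

    def forward(self, x):
        if self.training:
            if self.iter % self.every == 0:   # recompute only every `every` iterations
                self.update_stats(x)
            self.iter += 1
        factor = self.mults / (self.varns + self.eps).sqrt()
        offset = self.adds - self.means * factor
        return x * factor + offset
```

Gradients still reach mults and adds on every batch; what is lost, compared with the versions in this thread, is any gradient flowing through means and varns, which is exactly the trade-off being debated.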

For example, self.count is not on the graph, so you can verify that you can skip calculating it at will:

        mom1 = self.s.new_tensor(1 - (1-self.mom)/math.sqrt(self.accumulated_bs-1))
        if self.bc % 2:
            self.count.lerp_(self.c , mom1)

gives no error.

@t-v @stas BTW the algorithm I had in mind is a little different to what you implemented - I think it can be faster and skip basically all the compute. I’ve created a new thread to discuss this:

To put this more precisely for those reading along: what happens is that PyTorch freed the graph of the previous batch during that batch’s backward, and now finds a node (from the previous batch) that lacks the graph connections it would need to propagate the gradient.
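A minimal reproduction of that failure, with a toy parameter w and toy "statistics" standing in for the real module (all names here are illustrative assumptions). The first backward frees the saved tensors of batch 1’s graph; reusing means in the next loss then fails, while a detached copy of the same values works.

```python
import torch

torch.manual_seed(0)
w = torch.ones(3, requires_grad=True)    # stands in for a model parameter
x = torch.randn(4, 3)                    # stands in for batch 1

means = (x * w).pow(2).mean(0)           # statistics built on batch 1's graph
(means * w).sum().backward()             # batch 1's backward frees that graph's saved tensors

try:
    (means * w).sum().backward()         # "batch 2" reuses the stale means node
    reused_ok = True
except RuntimeError as err:              # "Trying to backward through the graph a second time ..."
    reused_ok = False
    print(type(err).__name__)            # RuntimeError

stale = means.detach()                   # same values, but no history
(stale * w).sum().backward()             # fine: gradient reaches w only through this new graph
print(reused_ok)                         # False
```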

It’ll all become more clear once Jeremy decides to do “even more impractical deep learning for coders” where he also re-implements autograd. :wink:

Just to clarify - all it can tweak are the model parameters. So we’re not changing what’s learnable when we detach something that’s not a parameter, only what grad history is available to our parameters.

My understanding (which might be wrong) is that detach removes the grad history. So whilst we are indeed immediately bringing these vars back in to the grad calculation, we’re doing so with a truncated history. This is necessary (AFAIK) to avoid the grad history getting too long to be workable (as we have to do in ULMFiT RNNs, for instance).
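The truncation can be made visible by counting autograd nodes. This sketch keeps an exponential moving average of batch means that depend on a parameter, with and without detaching the previous value; graph_size and run are hypothetical helpers that walk grad_fn.next_functions, and the toy shapes are arbitrary.

```python
import torch

def graph_size(t):
    """Count the autograd nodes reachable from t.grad_fn (hypothetical helper)."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)

def run(detach, steps=5):
    """Keep an EMA of batch means that depend on a parameter w; record graph size per step."""
    torch.manual_seed(0)
    w = torch.ones(3, requires_grad=True)        # stands in for a model parameter
    running = torch.zeros(3)                     # stands in for a running statistic
    sizes = []
    for _ in range(steps):
        x = torch.randn(8, 3)
        batch_mean = (x * w).mean(0)
        prev = running.detach() if detach else running
        running = prev.lerp(batch_mean, 0.1)
        sizes.append(graph_size(running))
    return sizes

print(run(detach=False))  # grows every step: history from all earlier batches is kept
print(run(detach=True))   # constant: only the current batch's history remains
```

This is the same idea as truncated backprop through time: the values still carry information from earlier batches, but gradients only flow through the most recent step.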

Figuring out those details is a key question in this mini-project! :slight_smile: The goal is: bn using running statistics that runs fast. The question is - how to do that in practice? I’m pretty sure it’s mathematically a totally reasonable thing to do, but I think that getting the details right to make it work in practice is an interesting exercise.

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0
        self.iter  = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1

        if self.batch > 1000 and self.iter%2:
            # we want to re-use means and stds, but
            # if they're old, they don't depend on things that are
            # still around to propagate gradients to
            self.means.detach_()
            self.stds.detach_()
            return
        
        self.sums.detach_()
        self.sqrs.detach_()

        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)

        mom1 = s.new_tensor(1 - (1-self.mom)**(bs/32)) # edited.
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.means = self.sums/self.count
        self.stds = ((self.sqrs/self.count).sub_(self.means.pow(2)) + self.eps).sqrt()
        if self.batch < 20: self.stds = self.stds.clamp_min(0.1)

    def forward(self, x):
        if self.training: self.update_stats(x)
        factor = self.mults / self.stds
        offset = self.adds - self.means*factor
        return x*factor + offset

(The change from self.varns to self.stds is of course unrelated; it just reduces the number of ops in the shortcut path. :slight_smile: )
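One way to read the edited momentum line, mom1 = 1 - (1-self.mom)**(bs/32) (my interpretation, not something stated in the thread): a single update with a batch of bs samples gets the same total weight as bs/32 consecutive updates of momentum mom toward the same statistic, because k such updates leave (1-mom)^k of the old value. With bs=32 it reduces to mom1 = mom. A pure-Python check of that identity, with an assumed batch size of 128:

```python
def lerp(start, end, weight):
    """Same formula as torch.lerp: start + weight * (end - start)."""
    return start + weight * (end - start)

mom = 0.1
bs = 128                  # assumed batch size; bs / 32 = 4 "virtual" size-32 steps
k = bs / 32
start, target = 5.0, 1.0  # arbitrary running value and batch statistic

stepwise = start
for _ in range(int(k)):
    stepwise = lerp(stepwise, target, mom)              # k updates with the base momentum

one_shot = lerp(start, target, 1 - (1 - mom) ** k)      # one update with the edited mom1

print(abs(stepwise - one_shot) < 1e-12)                 # True
```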

@t-v, it doesn’t train. With your version I’m getting:

learn,run = get_learn_run(nfs, data, 0.4, conv_rbn, cbs=cbfs)
%time run.fit(1, learn)
train: [nan, tensor(0.1166, device='cuda:0')]
valid: [nan, tensor(0.0991, device='cuda:0')]
CPU times: user 6.3 s, sys: 54 ms, total: 6.35 s
Wall time: 6.39 s

I meant to say it was easy to “fix” the error, now that I have the debug tool to tell me what’s on graph and what not.

So the challenge here isn’t really a coding one, but a conceptual DL one - how do we do a batch norm that uses running statistics during training, and doesn’t calculate them more than mathematically necessary.

In addition to sorting out the conceptual part, we need to get a real understanding of PyTorch’s computational graph mechanics. To me the graph is still a black box. I want to be able to see exactly what happens in each iteration, so that I can see the graph and its history. Then we’ll understand how the moving parts work, and then we’ll know how to bolt the conceptual part onto the PyTorch machinery.

@t-v, what debug tools can we use?

These have been helpful so far: requires_grad, is_leaf.
Any more advanced tools, e.g. something that shows me the graph?
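Beyond requires_grad and is_leaf, every non-leaf tensor also has a grad_fn, and its next_functions link to the nodes it was computed from; that chain is what graph visualizers walk. Here is a small report in the spirit of the debug prints earlier in the thread, sketched with a hypothetical helper (graph_report) and a toy module standing in for RunningBatchNorm.

```python
import torch
from torch import nn

def graph_report(module, names):
    """Hypothetical helper: show leaf / requires_grad / grad_fn for tensor attributes."""
    rows = {}
    for name in names:
        t = getattr(module, name, None)
        if not isinstance(t, torch.Tensor):
            rows[name] = None                    # missing or not a tensor (e.g. stats not computed yet)
            print(f"{name:6} -")
            continue
        fn = type(t.grad_fn).__name__ if t.grad_fn is not None else None
        rows[name] = (t.is_leaf, t.requires_grad, fn)
        print(f"{name:6} leaf={t.is_leaf!s:5} requires_grad={t.requires_grad!s:5} grad_fn={fn}")
    return rows

class Toy(nn.Module):
    """Tiny stand-in with one parameter and one buffer, mimicking mults/sums."""
    def __init__(self):
        super().__init__()
        self.mults = nn.Parameter(torch.ones(3))
        self.register_buffer('sums', torch.zeros(3))

toy = Toy()
x = torch.randn(5, 3)
toy.sums.lerp_((x * toy.mults).sum(0), 0.1)      # buffer updated from a grad-requiring value
report = graph_report(toy, ["mults", "sums", "means"])
```

Here sums shows up as a non-leaf with a grad_fn even though it is a registered buffer, which is the situation the thread keeps running into.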

What I have been trying to ask for quite a few messages now: why do you want the accumulating variables sums and sqrs to have grads in the first place? Don’t give them grads in the first place, and then you won’t need to worry about the history.

Because without it I don’t get good accuracy. Without that information in the gradients, the model doesn’t have as much information to know how to update its weights correctly.

If you can find a way to get as good accuracy without those grads, that would be great! :slight_smile:

So what happens is that the parameters blow up at small batch sizes. But my current guess is that it’s more the interplay between learning rate, batch size, how many stats to gather, etc. that isn’t working out yet.

I’m trying this graph visualizer https://github.com/waleedka/hiddenlayer, but it fails to build a graph for this BN, with:

ValueError: only one element tensors can be converted to Python scalars

The error comes from torch.onnx land.

I was able to build part of it:

This is before the code gets to calculate varns and means; it fails to build if I enable that part. I also had to unroll a few in-place functions into their non-in-place versions:

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('means', torch.zeros(nf,1,1))
        self.register_buffer('varns', torch.ones(nf,1,1))
        self.batch = 0
        self.iter  = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1
        
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)

        div = self.batch-1 if self.batch>1 else 1
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(div))
        
        self.sums  = self.sums .lerp(s , mom1)
        self.sqrs  = self.sqrs .lerp(ss, mom1)
        self.count = self.count.lerp(c , mom1)

        return
        self.means = self.sums.div(self.count)

        self.varns = (self.sqrs/self.count).sub(self.means.pow(2))
        if self.batch < 20: self.varns.clamp_min_(0.01)
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        factor = self.mults / (self.varns+self.eps).sqrt()
        offset = self.adds - self.means*factor
        return x*factor + offset
learn,run = get_learn_run(nfs, data, 0.4, conv_rbn, cbs=cbfs)
learn.model = learn.model.cuda()
import hiddenlayer as hl
g = hl.build_graph(learn.model, torch.randn([2, 1, 28, 28]).cuda())
g
#g.save("bn.png", format="png")

I can’t figure out why it stumbles on the means update. Then again, it looks like an early alpha tool.


It doesn’t seem to be. Even a very low LR and very frequent stats updates have the same issue.

I’ve been using this in the 07_batchnorm notebook with batch size 8 and a learning rate of 0.1 and it seems to converge - I get 97.8%, not great, but not NaN/10%. It’s awfully slow at 2 min 42 s.
To be sure, here is my current code:

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.batch = 0
        self.iter  = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.batch += bs
        self.iter  += 1

        if self.batch > 1000 and self.iter%2==0:
            # we want to re-use means and stds, but
            # if they're old, they don't depend on things that have
            # gradients
            self.means.detach_()
            self.stds.detach_()
            return
        
        self.sums.detach_()
        self.sqrs.detach_()

        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)

        mom1 = s.new_tensor(1 - (1-self.mom)**(bs/32))
        if self.batch == bs:
            print ("mom1", mom1, "c", c)
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.means = self.sums/self.count
        self.stds = ((self.sqrs/self.count).sub_(self.means * self.means).clamp_min_(0) + self.eps).sqrt()
        if (torch.isnan(self.stds).any() or torch.isnan(self.means).any()):
            raise Exception("found NaN")
        if self.batch < 20: self.stds = self.stds.clamp_min(0.1)

    def forward(self, x):
        if self.training: self.update_stats(x)
        factor = self.mults / self.stds
        offset = self.adds - self.means*factor
        y =  x*factor + offset
        if (torch.isnan(y).any()):
            raise Exception("found NaN")
        return y

Do your parameters explode, too, before you get NaNs?
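A sketch for answering that question systematically: record every parameter’s norm after each optimizer step and report the first step where one becomes non-finite or very large. The helpers and the toy linear model below are hypothetical stand-ins, and the 1e3 threshold is an arbitrary assumption; in the notebook one could call param_norms on the real model after each batch instead. Separately, torch.autograd.set_detect_anomaly(True) makes backward raise an error that points at the forward operation whose gradient produced a NaN, at a noticeable speed cost.

```python
import math
import torch
from torch import nn

def param_norms(model):
    """Hypothetical helper: {parameter name: L2 norm} at the current step."""
    return {n: p.detach().norm().item() for n, p in model.named_parameters()}

def first_blowup(history, limit=1e3):
    """Index of the first step where any norm is non-finite or exceeds limit, else None."""
    for step, norms in enumerate(history):
        if any(not math.isfinite(v) or v > limit for v in norms.values()):
            return step
    return None

def train(lr, steps=50):
    """Toy linear regression on random data, recording parameter norms after each step."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    history = []
    for _ in range(steps):
        x, y = torch.randn(32, 4), torch.randn(32, 1)
        loss = (model(x) - y).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        history.append(param_norms(model))
    return history

print(first_blowup(train(lr=10.0)))   # an early step index: this learning rate is far too high
print(first_blowup(train(lr=0.01)))   # None
```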


Please note a new reply of mine that ended up before your replies. Discourse has a ridiculous setting that prevents posting more than 3 consecutive replies, so I had to wait until one of you replied before it would let me post. And of course, being “smart” software, it inserted my post in the wrong place :frowning:

Anyways, please have a look:


OK, I found another visualizer that works (it seems the first one I mentioned has problems with the JIT):

# pip install torchviz
from torchviz import make_dot
model = learn.model.cuda()

x = torch.randn([1, 1, 28, 28]).cuda()
g = make_dot(model(x), params=dict(model.named_parameters()))
g.format = 'png'
g.render('bn')
g

And if I run it on RunningBatchNorm without any detach, the first time we get:

So now we should be able to see what’s on the graph and, by re-running the code, where the graph leaks, etc. For example, if I pass the input a second time, the tree doubles in size! That’s expected, since I ran it on the version with no detach and without torch.no_grad().
