Lesson 10 Discussion & Wiki (2019)

@jeremy, between BatchNorm and RunningBatchNorm the torch.no_grad() was removed - is it no longer needed because you instead used detach on the large tensors inside update_stats?

@t-v, what about 0D variables inside RunningBatchNorm.update_stats - shouldn’t those be detached too?

And is this still the case that one needs to detach if a variable is not a buffer or parameter but say just doing:

__init__:
    self.counter = 0
forward:
    self.counter += 1

I thought normal variables won’t attach themselves to the graph, unless they were created with requires_grad=True or are used in a calculation of a variable that is already part of a graph.

torch.no_grad() will not connect new variables to the graph, but it won’t do anything to x (which you probably had before).
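
A small illustration of both points (my own example, not from the lesson): a plain Python attribute such as a counter never touches the autograd graph, while a tensor computed from x is connected unless it is created under torch.no_grad() or detached.

import torch

x = torch.randn(4, requires_grad=True)
counter = 0
counter += 1                 # plain int: autograd never sees it
m = x.mean()                 # created normally: part of x's graph
with torch.no_grad():
    m_ng = x.mean()          # created under no_grad: not connected
print(m.requires_grad, m_ng.requires_grad)   # True False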

So you’re suggesting list append and then a single cat, correct?

Yes. The “obvious” pattern is correct here.
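
For concreteness, a rough sketch of the append-then-cat pattern (my own code; stash and min_elems are made-up names): keep detached copies of the small batches and only compute per-channel stats once enough elements have accumulated.

import torch

stash = []
def accumulate_stats(x, min_elems=8):
    stash.append(x.detach())                  # detach so no graph is kept alive
    buf = torch.cat(stash, dim=0)             # a single cat over the whole list
    if buf.numel() / buf.shape[1] < min_elems:
        return None, None                     # not enough data per channel yet
    stash.clear()
    m = buf.mean((0,2,3), keepdim=True)                # per-channel mean
    v = (buf*buf).mean((0,2,3), keepdim=True) - m*m    # biased per-channel variance
    return m, v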

This doesn’t work, since this is exactly the problem we are trying to solve - where mean is often nan because there is not enough data to calculate it on, and my idea was to gather enough data to do it on.

mean should be non-nan once you have tensors with more than 0 elements, but if your input sizes vary, you’d probably want sum and keep track of numel.
var will not be defined with bs * w * h = 1, but then you’re doing something fundamentally wrong, probably. Non-centered moments should work just as mean does.
Maybe I don’t quite understand what you’re trying to do, though.

Best regards

Thomas


With any luck, they’re not requiring grad, so no detaching needed. :slight_smile:


I think I’m a little bit lost, and putting some context back will help. So if I want to save a copy of x inside a layer, so that I could refer to it in later forward passes, like so:

forward:
with torch.no_grad(): l.append(x)

I can’t detach x or it’d mess up the original x, no? So do I need to clone x instead and set it not to require grad?

In other words, what’s the correct way to stash away some data flowing through the layer without affecting it? i.e. don’t mess up input and output in forward/backward.

var will not be defined with bs * w * h = 1, but then you’re doing something fundamentally wrong, probably. Non-centered moments should work just as mean does.
Maybe I don’t quite understand what you’re trying to do, though.

We are trying to solve a problem where there is not enough RAM and a user uses bs=2 or bs=1. You can’t calculate var with bs=1. So I save that single input, do nothing in this forward pass, then concatenate it with a new bs=1 input from the following pass, and then I might be able to calculate variance (though more likely I need at least 4-8 data points, so I’d need to aggregate 4-8 mini-batches if bs=1 or 2). Is this helpful?

This doesn’t work, since this is exactly the problem we are trying to solve - where mean is often nan

oh, I see I didn’t write what I was intending. I meant to say variance instead of mean.

No. There are six or so cases:

  • x.detach_() changes the tensor in place to not require grad --> you don’t want this.
  • x.detach() new tensor with the same memory(!) and no requires_grad, unconnected to the graph --> this is what you want if you save x (which you should not) or for use in calculating mean/std.
  • x.clone() new tensor with new memory, but grad-connected if x requires grad.
  • x.clone().detach_() new tensor, new memory, no requires_grad, unconnected to the graph.
  • x.detach().requires_grad_() new tensor, same memory, requires grad, but not connected (i.e. a leaf).
  • x.clone().detach_().requires_grad_() oh well, you’re bored by now.

no_grad might be odd to use here.
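
A tiny check of the cases above (my own example): detach() shares memory and drops the graph connection, clone() copies memory but stays connected.

import torch

x  = torch.randn(3, requires_grad=True) * 2   # non-leaf tensor that requires grad
d  = x.detach()            # same storage, requires_grad=False, off the graph
c  = x.clone()             # new storage, still connected to x's graph
cd = x.clone().detach_()   # new storage, off the graph
print(d.requires_grad, c.requires_grad, cd.requires_grad)  # False True False
print(d.data_ptr() == x.data_ptr())   # True  -- detach shares memory
print(c.data_ptr() == x.data_ptr())   # False -- clone copies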

You cannot calculate the unbiased var of a single-element tensor (you would get a biased one). But usually you have h > 1 and w > 1, so that isn’t a problem. Even for a single-element-per-channel tensor, you can track (x**2).mean((0,2,3)).
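
As a sketch of that idea (my own code, with stand-in data): unlike the append-and-cat version earlier, this tracks running non-centered moments per channel and never stores the activations.

import torch

batches = [torch.randn(2, 3, 4, 4) for _ in range(5)]   # stand-in mini-batches (bs, nc, h, w)
mean_x, mean_x2, n = 0., 0., 0.
for xb in batches:
    k = xb.numel() / xb.shape[1]             # elements per channel in this batch
    mean_x  = (mean_x*n  + xb.sum((0,2,3)))      / (n + k)   # running E[x]
    mean_x2 = (mean_x2*n + (xb*xb).sum((0,2,3))) / (n + k)   # running E[x**2]
    n += k
var = mean_x2 - mean_x**2                    # biased per-channel variance, even if each bs is tiny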

To be honest, I’m skeptical of BN when you only have a few features. “Traditional” BN is completely bogus with feature planes of 1 (because after normalizing, x, the input, will be 0); running BN will be a bit better, but will it be good?

Best regards

Thomas


@jeremy, looking at the latest incarnation of RunningBatchNorm, why are we recalculating everything for inference? Here is a refactored version:

#export
import math
import torch
from torch import nn, tensor
class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds  = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('factor', tensor(0.))
        self.register_buffer('offset', tensor(0.))
        self.batch = 0
        
    def update_stats(self, x):
        bs,nc,*_ = x.shape
        # detach the buffers from the previous batch's graph before updating them
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s    = x    .sum(dims, keepdim=True)
        ss   = (x*x).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        # exponential moving averages of the sum, sum of squares and count
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs
        means = self.sums/self.count
        vars = (self.sqrs/self.count).sub_(means*means)
        # clamp the variance early on, while the running stats are still noisy
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        # precompute factor/offset so forward (and inference) is a single multiply-add
        self.factor = self.mults / (vars+self.eps).sqrt()
        self.offset = self.adds - means*self.factor
        
    def forward(self, x):
        if self.training: self.update_stats(x)
        return x*self.factor + self.offset

The only thing I can’t figure out is how to get rid of the first 3 buffers. They no longer need to be saved in the model and could be normal vars, but if I replace them with normal vars I hit the CUDA vs. CPU device issue, e.g. if I replace:

        #self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.sums = torch.zeros(1,nf,1,1)

I get:

---> 24         self.sums .lerp_(s , mom1)
     25         self.sqrs .lerp_(ss, mom1)
     26         self.count.lerp_(c , mom1)

RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for CPU_tensor_apply)

So I then have to do an explicit cuda() or to() when assigning a tensor to those vars, but I don’t know how to do it so that it works transparently regardless of the user’s setup. It seems that register_buffer does the right thing.
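
A minimal illustration of why register_buffer "does the right thing" (my own example): buffers follow module.cuda()/to(device), while plain tensor attributes do not; buffers also end up in state_dict, which is the serialization trade-off discussed below.

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('sums', torch.zeros(1, 3, 1, 1))

class WithAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.sums = torch.zeros(1, 3, 1, 1)   # plain attribute, invisible to .cuda()/.to()

if torch.cuda.is_available():
    a, b = WithBuffer().cuda(), WithAttr().cuda()
    print(a.sums.device)   # cuda:0 -- moved along with the module
    print(b.sums.device)   # cpu    -- left behind, hence the lerp_ backend error
print(list(WithBuffer().state_dict()))   # ['sums'] -- buffers are serialized
print(list(WithAttr().state_dict()))     # []       -- plain attributes are not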

p.s. RunningBatchNorm uses a variable named vars, which shadows a built-in function, so that’s probably not a good idea :wink:

Yeah that’s the other reason to use buffers.


TIL. Will change it.


But it sounds like we are then using it for its side effect - why store in the model something that is a temp variable?

There must be a better way.

I didn’t think of it as a side-effect - it was the main reason I did it that way. But if there’s a way that avoids serializing unnecessary data, I’d be happy to switch to it. (But which doesn’t require significant extra complexity.)

Originally it was used out of necessity, since we wanted those vars to be stored in the model so that they could be used during inference. In the refactored version, if you think it’s valid, these are no longer needed.

I think that’s reasonable.

The thing I’d really like is to change that if self.training: to if self.training and (self.steps<100 or self.steps%4==0):, so that once things have stabilized it doesn’t recalc stats so often. Last time I tried to get this to work I had trouble figuring out the detach details. If anyone gets this working please let me know! :slight_smile:
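
One way the gating could look, as a sketch under my own assumptions rather than a working version of Jeremy's code: lerp the buffers with fully detached batch stats so they never hold a graph, add a self.steps counter in __init__, and recompute factor/offset from the small per-channel buffers on every forward (cheap, but it gives up the cached factor/offset of the refactor above). The trade-off is that gradients then flow only through mults/adds, not through the batch statistics.

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        dims = (0,2,3)
        xd   = x.detach()                    # fully detached: buffers never hold a graph
        s    = xd     .sum(dims, keepdim=True)
        ss   = (xd*xd).sum(dims, keepdim=True)
        c    = s.new_tensor(x.numel()/nc)
        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(bs-1))
        self.sums .lerp_(s , mom1)
        self.sqrs .lerp_(ss, mom1)
        self.count.lerp_(c , mom1)
        self.batch += bs

    def forward(self, x):
        if self.training:
            if self.steps < 100 or self.steps % 4 == 0: self.update_stats(x)
            self.steps += 1                  # assumes self.steps = 0 in __init__
        means  = self.sums/self.count
        vars   = (self.sqrs/self.count) - means*means
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        factor = self.mults / (vars + self.eps).sqrt()
        return x*factor + (self.adds - means*factor)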

Nope, originally I used it so it would be moved to CUDA automatically. I don’t know how to do it otherwise in a convenient way, either in the existing or the refactored version.

Although AFAICT it also needs to be stored still, since otherwise fine-tuning won’t work.

On a related note, I’ve also seen variable names like “input” used in functions. For example, the definition of nll here.

Yes, it always feels odd to me when I do that, but I stay consistent with pytorch, so I use it in loss functions.


Have you tried running nll and then doing an actual input() call, as intended by Python? If the latter breaks, then it’s a bug in pytorch and fastai.

A quick test shows that it should break:

vars()
vars = 5
vars() # fails

edit: as @amanmadaan replies later, this is not a problem since it’s a local variable, so it’s ok.

A small clarification…

At 1:52:24, “The variance of a batch of one is infinite.” I think what’s meant here is that the variance calculates as zero, and you would be dividing by zero to normalize the batch to standard deviation 1.

I understand there’s a difference between population variance and sample variance, and that PyTorch var() returns NaN. But for this explanation what is pertinent is why the filter would be scaled to infinity.
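
For reference, a quick check of that difference (my own example):

import torch

x = torch.tensor([3.0])
print(torch.var(x))                  # tensor(nan) -- unbiased (sample) variance divides by n-1 = 0
print(torch.var(x, unbiased=False))  # tensor(0.)  -- biased (population) variance divides by n = 1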

I’ll experiment with it since I also want to understand that detach thing.

How would you recommend “measuring” any regression in such fine-tuning? I know you usually recommend keeping the randomness, but this situation feels to me like it should call for a fixed seed, at least in the initial steps, so that any regressions can be seen immediately. Does that make sense?

Also, I think your suggested 100 should be different depending on the bs, no? It’d be a very different measurement with bs=2 vs bs=512.

Right, it will fail if input is used in the function (or the same scope).

The following is legal:

def times5(input):
    print(int(input) * 5)

ip = input("Enter a number")
times5(ip)

which is how pytorch and fastai use the variable input.

The following is not:

input = input("Enter a number")
input
input("Enter another number") #will fail

I just run things a few times to get a sense of how stable they are. Generally it’s pretty obvious when something breaks. If I have a fixed seed I find it hard to know if it’s working since I might have got lucky with that one seed.

Yes that’s better. In an earlier version I kept a counter called self.batch and did self.batch += bs. Then you could check self.batch<200 or similar.
