LAMB Optimizer

I wanted to start a topic to discuss the LAMB Optimizer we discussed and implemented in Lesson 11.

I’m interested in trying it out on a large distributed language model I’m working on, so I’m going to try porting the code to work with fastai v1.

Planning on tracking my progress and results in this thread.

The first thing I’m going to try is just using the exact code from the 09_optimizers.ipynb notebook and passing it to get_learner similar to how does it… but I’m pretty sure the more fastai way would be to use and/or subclass GeneralOptimizer so as not to duplicate a bunch of code.


No need to subclass - look carefully at the notebook and you’ll see that we’ve defined partial functions for each optimizer. That’s the thing you’d want to use.
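To make the partial-function idea concrete, here is a toy sketch (not the notebook's code; `ToyOptimizer` and `sgd_step` are hypothetical stand-ins). `functools.partial` pre-binds the steppers and hyperparameters, leaving a callable that only needs the parameters. That is the shape fastai v1 expects of an `opt_func`: in the traceback below, `OptimWrapper.create` calls `opt_func(...)` with nothing but the parameter groups.

```python
from functools import partial


def sgd_step(p, lr, **kwargs):
    # Hypothetical stepper: plain SGD on a dict-based "parameter".
    p["value"] -= lr * p["grad"]
    return p


class ToyOptimizer:
    # Hypothetical minimal optimizer: applies each stepper to each parameter.
    def __init__(self, params, steppers, **defaults):
        self.params = list(params)
        self.steppers = steppers
        self.hypers = dict(defaults)

    def step(self):
        for p in self.params:
            for stepper in self.steppers:
                p = stepper(p, **self.hypers)


# The partial bundles everything except the parameters...
sgd_opt = partial(ToyOptimizer, steppers=[sgd_step], lr=0.1)

# ...so the training code only has to supply them.
params = [{"value": 1.0, "grad": 0.5}]
opt = sgd_opt(params)
opt.step()
print(params[0]["value"])
```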


I was looking at GeneralOptimizer and it didn’t look like it had the state variable we added in StatefulOptimizer.

I think PyTorch’s Optimizer, which GeneralOptimizer inherits from, might already do that, though.

Still trying to work my way through understanding how the things we implemented in class map over to fastai v1.

The Optimizer in the course is a bit different from the one in PyTorch. I’d suggest sticking with it (and with StatefulOptimizer) unless you’re happy doing quite a bit of rewriting. (I think our version is easier to use and understand too.)


I haven’t been able to get it working quite yet; I ran into trouble passing the opt_func to language_model_learner:

AttributeError Traceback (most recent call last)
<ipython-input-17-2b09e9261ef7> in <module>
----> 1 learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.25, opt_func=lamb_opt)

/opt/conda/lib/python3.7/site-packages/fastai/text/ in language_model_learner(data, arch, config, drop_mult, pretrained, pretrained_fnames, **learn_kwargs)
    212         fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    213         learn.load_pretrained(*fnames)
--> 214         learn.freeze()
    215     if pretrained_fnames is not None:
    216         fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]

/opt/conda/lib/python3.7/site-packages/fastai/ in freeze(self)
    217         "Freeze up to last layer group."
    218         assert(len(self.layer_groups)>1)
--> 219         self.freeze_to(-1)
    220         self.create_opt(

/opt/conda/lib/python3.7/site-packages/fastai/ in freeze_to(self, n)
    212                 if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)
    213         for g in self.layer_groups[n:]: requires_grad(g, True)
--> 214         self.create_opt(
    216     def freeze(self)->None:

/opt/conda/lib/python3.7/site-packages/fastai/ in create_opt(self, lr, wd)
    198     def create_opt(self, lr:Floats, wd:Floats=0.)->None:
    199         "Create optimizer with `lr` learning rate and `wd` weight decay."
--> 200         self.opt = OptimWrapper.create(self.opt_func, lr, self.layer_groups, wd=wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
    202     def split(self, split_on:SplitFuncOrIdxList)->None:

/opt/conda/lib/python3.7/site-packages/fastai/ in create(cls, opt_func, lr, layer_groups, wd, true_wd, bn_wd)
     22         split_params = split_no_wd_params(layer_groups)
     23         opt = opt_func([{'params': p, 'lr':0} for p in split_params])
---> 24         opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
     25,opt.opt_func = listify(lr, layer_groups),opt_func
     26         return opt

/opt/conda/lib/python3.7/site-packages/fastai/ in __init__(self, opt, wd, true_wd, bn_wd)
     11     def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
     12         self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
---> 13         self.opt_keys = list(self.opt.param_groups[0].keys())
     14         self.opt_keys.remove('params')
     15         self.read_defaults()

AttributeError: 'list' object has no attribute 'keys'

Tomorrow I’m going to trace through it to find where the expected format doesn’t match.

Ah yes if you want to use it with fastai v1 code you will indeed need to convert it to use the optim.Optimizer base class.
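A minimal sketch of that conversion, assuming a plain `torch.optim.Optimizer` subclass (the class name `Lamb`, its keyword names, and its defaults are my own choices, not fastai's API). As far as I can tell, the error above comes from the course `Optimizer` storing `param_groups` as a list of lists, while `OptimWrapper` expects PyTorch-style dicts it can call `.keys()` on; keeping the hyperparameters inside each `param_groups` dict avoids that. The math follows the notebook's `lamb_step`: a bias-corrected Adam direction plus `wd * p`, scaled by the RMS-based trust ratio clipped at 10, with a fallback ratio of 1 when either norm is zero (as suggested later in this thread).

```python
import torch
from torch.optim import Optimizer


class Lamb(Optimizer):
    """Sketch of LAMB as a standard torch.optim.Optimizer (hypothetical class).

    Hyperparameters live in the param_groups dicts, so code that expects
    PyTorch-style optimizers (e.g. reading param_groups[0].keys()) works.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["grad_avg"] = torch.zeros_like(p)
                    state["sqr_avg"] = torch.zeros_like(p)
                state["step"] += 1
                t = state["step"]
                grad_avg, sqr_avg = state["grad_avg"], state["sqr_avg"]
                # Exponential moving averages of the gradient and its square.
                grad_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                sqr_avg.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
                # Bias-corrected Adam direction plus (coupled) weight decay.
                debias1 = 1 - beta1 ** t
                debias2 = 1 - beta2 ** t
                update = (grad_avg / debias1) / ((sqr_avg / debias2).sqrt() + group["eps"])
                update = update + group["weight_decay"] * p
                # Layer-wise trust ratio using RMS norms, like the notebook.
                r1 = p.pow(2).mean().sqrt().item()
                r2 = update.pow(2).mean().sqrt().item()
                ratio = 1.0 if r1 == 0 or r2 == 0 else min(r1 / r2, 10.0)
                # p -= lr * ratio * update
                p.add_(update, alpha=-group["lr"] * ratio)
        return loss
```

In fastai v1 this would presumably be passed as something like `opt_func=partial(Lamb, betas=(0.9, 0.99))`. I haven't verified how `OptimWrapper` treats `weight_decay` when `true_wd=True`; it may apply the decay itself rather than letting it flow into the trust ratio.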


Thanks for starting the thread. I am interested in implementing it in StyleGAN to see if it reduces training time (good or bad idea?). I am trying to implement it in the same vein as TensorFlow’s Adam implementation, but the code is just so freakin’ hard to follow.

Here is the TensorFlow implementation of Adam:

I am trying to pin down exactly where the gradients are getting updated… I am guessing it is the _apply_dense and _resource_apply_dense functions, but they call a mysterious method, training_ops.apply_adam/training_ops.resource_apply_adam, that is nowhere to be found (training_ops imports gen_training_ops, which is also nowhere to be found).

The only ‘math’ I could find is happening in the _apply_sparse_shared method but I am not sure when this would be called and why it is treated differently.

Appreciate any help in pointing me in the right direction.
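For reference, the math lives in C++: as far as I know, `gen_training_ops` is a Python wrapper generated at build time from the op registrations, which is why it isn't in the source tree, and TensorFlow's documentation for the `ApplyAdam` op spells out the update it performs. Below is that documented update as a plain-Python sketch on a single scalar (the function name `adam_update` and the toy objective are mine). The sparse path (`_apply_sparse_shared`) is, I believe, the one taken when the gradient arrives as `tf.IndexedSlices`, for example from an embedding lookup.

```python
import math


def adam_update(var, m, v, g, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam step in the form ApplyAdam documents (t starts at 1)."""
    # Bias correction folded into the step size.
    lr_t = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    # Moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Epsilon is added outside the square root here.
    var = var - lr_t * m / (math.sqrt(v) + eps)
    return var, m, v


# Minimise f(x) = x**2 for a few steps, starting from x = 1.
var, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    grad = 2 * var
    var, m, v = adam_update(var, m, v, grad, t, lr=0.1)
print(var)
```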

Has anyone tried using the lamb() defined in notebook 09 (the recent version) with success? I am getting dramatically worse performance than Adam when I try it on the PETS or FOOD datasets with out-of-the-box hyperparameters (LR and momentum scheduled with one-cycle, max_lr chosen by lr_find, same setup as for Adam). Any thoughts? I will read the paper to see if I am missing something obvious and check back in later. Thanks.

I do see weight decay (wd) is 0 in the notebook: lamb_step._defaults = dict(eps=1e-6, wd=0.)

while the recommended value in the paper is 0.01. Not sure if this is causing the discrepancy or if it’s been adjusted elsewhere in the code.

Also confused about this step: p.data.add_(-lr * min(r1/r2,10), step)

Isn’t this essentially p = previous_p + step - lr * min(r1/r2,10)?

Reading the formula (step 10), shouldn’t it be p = previous_p - lr * min(r1/r2,10) * step?

So multiply by step instead of subtracting it?

I haven’t had a chance to get back to this yet but it’s still near the top of my list to experiment with. I’ll update with my results after.


OK… I am probably missing something here, but I have even more questions about this as I dig deeper…

I am seeing 3 issues:

  1. The weight decay is set to 0, while it is 0.01 in the paper. With wd=0, the wd * p term (the previous-weight part of the update) drops out of the implementation entirely.

  2. The L2 norm is the square root of the sum of the squares, e.g. sqrt(a^2 + b^2 + c^2 + …). However, the implementation takes the mean of the squares before taking the square root, which divides the whole thing by an extra sqrt(n), where n is the number of elements: sqrt((a^2 + b^2 + c^2 + …) / n). I suppose this is a moot point, as this factor cancels out when we take the r1/r2 ratio, but I wanted to mention it.

  3. We are subtracting step in the last part rather than multiplying (see my previous comment).
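A quick check of point 2 in plain Python, with made-up random data standing in for the weights and the update (the helper names are mine): because the update has the same shape as the weights, both norms are divided by the same sqrt(n), so the trust ratio comes out the same either way.

```python
import math
import random


def l2_norm(xs):
    # True L2 norm: square root of the sum of squares.
    return math.sqrt(sum(x * x for x in xs))


def rms(xs):
    # What the notebook computes: sqrt of the mean of squares = l2_norm / sqrt(n).
    return math.sqrt(sum(x * x for x in xs) / len(xs))


random.seed(0)
weights = [random.gauss(0, 1) for _ in range(1000)]
update = [random.gauss(0, 0.01) for _ in range(1000)]

ratio_sum = l2_norm(weights) / l2_norm(update)
ratio_mean = rms(weights) / rms(update)
# The two ratios agree up to floating-point rounding.
print(ratio_sum, ratio_mean)
```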

So…shouldn’t it be like this?

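def lamb_step(p, lr, mom, mom_damp, step, sqr_mom, sqr_damp, grad_avg, sqr_avg, eps, wd, **kwargs):
    # debias() is the helper defined earlier in 09_optimizers.ipynb
    debias1 = debias(mom,     mom_damp, step)
    debias2 = debias(sqr_mom, sqr_damp, step)
    # True L2 norms (sum of squares), as in the paper
    r1 = p.data.pow(2).sum().sqrt()
    step = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps) + wd*p.data
    r2 = step.pow(2).sum().sqrt()
    # Multiply the update by lr and the clipped trust ratio, then subtract it
    p.data.sub_(lr * min(r1/r2,10) * step)
    return p

lamb_step._defaults = dict(eps=1e-6, wd=0.01)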

Also… funny observation, but it looks like epsilon is inside the square root, at least according to the formulas presented (see step 7 in the notebook). :smiley: Weirdly enough, though, I see it used outside the root in implementations.
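The placement does matter numerically for a given eps. A small plain-Python comparison (function names are mine and the m_hat/v_hat values are arbitrary): with eps under the root, a zero v_hat gives a denominator of sqrt(eps) rather than eps, so the two forms only agree when v_hat is much larger than eps.

```python
import math


def eps_outside(m_hat, v_hat, eps):
    # Common implementation form: m_hat / (sqrt(v_hat) + eps)
    return m_hat / (math.sqrt(v_hat) + eps)


def eps_inside(m_hat, v_hat, eps):
    # Form with eps under the root: m_hat / sqrt(v_hat + eps)
    return m_hat / math.sqrt(v_hat + eps)


eps = 1e-6
# (m_hat, v_hat) pairs: large second moment, one comparable to eps, and zero.
for m_hat, v_hat in [(0.1, 1e-2), (1e-3, 1e-6), (1e-4, 0.0)]:
    print(v_hat, eps_outside(m_hat, v_hat, eps), eps_inside(m_hat, v_hat, eps))
```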


For 1. anyone can have their own default. In every PyTorch optimizer, wd defaults to 0.0 so I followed that logic here.
2. is irrelevant in any case. Mathematically the L2 norm is the square root of the sum of squares, yes, but using the mean is more numerically stable since we deal with fewer large numbers this way.
3. This is the same. In PyTorch, x.add_(l, b) does x += l*b.
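A quick check of point 3, assuming a recent PyTorch: the two-argument `x.add_(l, b)` form used in the notebook is deprecated there, and the same operation is spelled `x.add_(b, alpha=l)`. Both mean `x += l * b`, so the notebook's line is exactly the multiply-then-subtract update from the question (the tensors and the `lr`/`ratio` values below are arbitrary).

```python
import torch

torch.manual_seed(0)
p = torch.randn(5)
step = torch.randn(5)
lr, ratio = 0.01, 2.5

# Newer spelling of add_(-lr * ratio, step): alpha scales the tensor before adding.
p_inplace = p.clone().add_(step, alpha=-lr * ratio)

# The explicit form from the question: multiply, then subtract.
p_explicit = p - lr * ratio * step

print(torch.allclose(p_inplace, p_explicit))
```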


Thanks for replying! OK, that makes sense. I think the add function threw me off: given a scalar and a tensor, it multiplies the tensor by the scalar before adding the result to the input, which is how it’s set up here.

I have implemented LAMB in TensorFlow. I’d appreciate any feedback. :slight_smile:

This is the same as p -= lr * min(r1/r2,10) * step.

I have tried using r1/(r2+eps) as the trust ratio instead (p -= lr * r1/(r2+eps) * step), and it worked as well on MNIST, if not better.

Yes that should be better :slight_smile:

In the LAMB paper section 3.1 they talk about handling 0s in the trust ratio, but I don’t see that in the course notebook. Is something missing?

I found it quite hard to parse that section, but I found another implementation that says if r1 or r2 is 0, then r=1.
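A sketch of that rule as a helper (the name `trust_ratio` is hypothetical, and the clip at 10 mirrors the notebook's `min(r1/r2,10)`). Without the guard, a layer whose weights are all zero (e.g. a zero-initialised bias) gets a ratio of 0 and never moves, and a zero update norm divides by zero.

```python
def trust_ratio(r1, r2, max_ratio=10.0):
    """Layer-wise LAMB trust ratio with a fallback for zero norms.

    r1: norm of the layer's weights; r2: norm of its Adam-style update.
    """
    if r1 == 0.0 or r2 == 0.0:
        # Fall back to 1, i.e. take the unscaled update.
        return 1.0
    return min(r1 / r2, max_ratio)


print(trust_ratio(0.0, 5.0), trust_ratio(2.0, 4.0), trust_ratio(100.0, 1.0))
```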

After much hacking I got StatefulOptimizer and lamb_func working with fastai v1. It’s really ugly right now but the results look promising.

I spent all day yesterday tweaking hyper-parameters on my language model and on the first run with LAMB I’m seeing a ~0.3% accuracy & validation loss improvement over my best results from yesterday (on 3 epochs, frozen).

Unfortunately, it looks like I have a bug somewhere (probably in my param_groups hacks), because once I unfroze, the losses and accuracy stopped improving much, whereas with fastai v1’s Adam they kept getting better.

The main things I had to do to get it working with fastai were:

  • Modified OptimWrapper to not wrap these new-style optimizers (I did this in a hacky way by detecting the presence of hypers on the optimizer)
  • Modified the constructor of Optimizer to account for OptimWrapper passing parameters in a different format to its initializer
  • Changed grad_params to not error out due to the different format of the parameters
  • Hacked into Optimizer.__setattr__ to propagate OneCycle's updates to the lr and mom properties through to hypers so they get passed to the stepper functions
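For anyone attempting the last bullet, here is a toy sketch of the idea (the class `HyperSyncOptimizer` is hypothetical, not the code described above): overriding `__setattr__` so that whenever a scheduler assigns `lr` or `mom`, the value is also written into each parameter group's `hypers` dict, which is where the stepper functions read their hyperparameters from.

```python
class HyperSyncOptimizer:
    """Hypothetical toy: mirror attribute writes for lr/mom into every hypers dict."""

    synced = ("lr", "mom")

    def __init__(self, n_groups, lr, mom):
        # Bypass our own __setattr__ while creating the per-group hypers.
        object.__setattr__(self, "hypers", [dict(lr=lr, mom=mom) for _ in range(n_groups)])

    def __setattr__(self, name, value):
        if name in self.synced:
            # Accept a single value or one value per parameter group.
            values = value if isinstance(value, (list, tuple)) else [value] * len(self.hypers)
            for h, v in zip(self.hypers, values):
                h[name] = v
        object.__setattr__(self, name, value)


opt = HyperSyncOptimizer(n_groups=2, lr=1e-3, mom=0.9)
opt.lr = [1e-4, 1e-3]  # e.g. what a one-cycle scheduler might assign
opt.mom = 0.85
print(opt.hypers)
```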

The results are promising so I’m going to keep at it. I’m excited to see how things will look once I’m running bigger batch sizes on multiple GPUs since that’s where it’s really supposed to shine.

Edit: wd=0 worked even better than wd=0.01; picked up another 0.3% accuracy on the 3 frozen epochs (although with a bit higher validation loss).


I’m curious if LAMB made it into the fastai library. Did you submit a PR with your solution?