Challenge for advanced students: Implement AdamW and SGDW

This paper shows how to make Adam much more reliable and accurate: http://arxiv.org/abs/1711.05101

It includes a link to an implementation in Lua Torch. Would anyone like to take up the challenge of porting it to PyTorch? Here are the steps, as I see them:

  1. Diff the original Lua Torch repo they forked from against their new version, to see exactly what they changed
  2. Make a similar change to PyTorch, and try to incorporate it into the fastai lib
  3. Test on the same CIFAR10 dataset and training method as the paper, and replicate their result
  4. Profit!

If you have a try, keep us posted here! :slight_smile:

Hello @jeremy,

I have been chasing this, but so far I have not been able to replicate the results. Here are some observations:

  • The authors used an additional regularizer, shake-shake, along with their approach, so reproducing their exact results in fastai will be difficult (or at least time-consuming).
  • The implementation differs slightly from the paper and from my understanding of it. The authors have been super helpful in answering my queries. Most of the differences come from presenting cleaner code, but there are still some I don’t understand. I am following up with the authors.
  • I have also been experimenting with what we already have. Keeping everything else the same, I wanted to compare SGDR with Adam with restarts. Here I am using PyTorch’s Adam out of the box. Results so far are comparable (better) on CIFAR10. This is a very simple test, and I will run some more elaborate ones (see fig below). It will be interesting to compare these on bigger datasets, though. This will probably be a good way to evaluate AdamWR as well.


(Edit: Ignore the above result. The weight_decay was not set properly.)
*I applied a weight_decay of 0.025 to both, since I wanted to compare them with AdamWR, which is not ready yet.

I have some things coded up for the paper, but I am sure there are some discrepancies, which I am ironing out.

Anyone else working on this and had any luck?

This is very encouraging! You’re right, dealing with shake-shake is rather orthogonal to this - what I really meant when I said we should “replicate” the results, is that we should see if we can check that AdamW and SGDW beat Adam and SGD.

It sounds like you might be there already! Can you show your code? What would be the minimal thing we’d add to fastai to support this?

So glad you looked into this - I was starting to think no-one had taken up the challenge… :slight_smile:

Sure! Will do some cleanup and post in a few hours.

The delta is actually very small (the theory was deep though for me :slight_smile: ). We already have the Restart logic in place, so that makes the change itself very minimal in terms of lines of code.

Also something to deal with weight decay?

Yes @jeremy , the change involves the 3 aspects mentioned in the paper:

  1. Decoupling weight decay from the gradient-based update (file: fastai/adamw.py)
  2. Normalizing the values of weight decay (file: fastai/model.py and fastai/learner.py)
  3. Adam with warm restarts and normalized weight decay (restarts already available in fastai)
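
To make aspect 1 concrete, here is a minimal sketch of a single Adam step on one scalar weight, with the weight decay either folded into the gradient (the usual L2 style) or applied directly to the weight (decoupled, following my reading of the paper's pseudo-code). The function name `adam_step` and its arguments are made up for illustration; this is not the code in fastai/adamw.py, and `eta` stands in for the schedule multiplier, which is assumed to be 1 here.

```python
import math


def adam_step(w, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              wd=0.0, decoupled=False, eta=1.0):
    """One Adam update for a single scalar weight (illustrative only).

    decoupled=False: weight decay is added to the gradient (L2 style),
                     so it passes through Adam's adaptive scaling.
    decoupled=True:  weight decay is applied directly to the weight,
                     outside the adaptive scaling.
    """
    b1, b2 = betas
    # L2 style folds the decay term into the gradient before the moments.
    g = grad if decoupled else grad + wd * w
    m = b1 * m + (1 - b1) * g           # first moment estimate
    v = b2 * v + (1 - b2) * g * g       # second moment estimate
    m_hat = m / (1 - b1 ** t)           # bias correction
    v_hat = v / (1 - b2 ** t)
    step = lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        # Decay uses the weight from before the update, not the gradient.
        step += wd * w
    return w - eta * step, m, v


# First step from w=1.0 with gradient 0.5 and weight decay 0.1.
w_l2, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, wd=0.1)
w_dec, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, wd=0.1, decoupled=True)
print(w_l2, w_dec)
```

On the first step Adam's normalized update has magnitude close to lr regardless of the gradient size, so in the L2 variant the decay term barely changes the result (about 0.999), while the decoupled variant shrinks the weight by the full wd * w (about 0.899).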

I have attached two images below, one ran on cifar10 and the other on cats-and-dogs. Results are positive.

Approach to testing: Everything except the optimizer and initial LR was kept the same between runs. I used a very simple resnet34 model and trained only the FC layers, and this setup was the same for all optimizers. I compared SGDM and Adam to AdamW. All used a weight decay factor of 0.025, as mentioned in the paper. Only AdamW does (1) and (2) as per the paper. Restarts were applied to all.

Files: This is my github repo. This is the diff.

(Please ignore the coding style for now. Will be refactored.)

  • fastai/adamw.py: I copied the adam.py file from PyTorch and placed it here. Renamed it as AdamW. Implemented the first aspect here (decoupling weight decay)
  • fastai/learner.py and model.py: The decaying of the weight takes place here. It kicks in only if you use the AdamW optimizer. Most of the code is about taking the required variables to the Stepper class. The Stepper does the decaying.

Note that adam.py was taken from PyTorch 0.2.0, which is the version I am on. The file has since been refactored in master.
The authors have also implemented SGDMW, which I am adding now.

Also

  • courses/dl1/adam-experiment-cats-dogs.ipynb and adam-experiment-cifar10.ipynb: I ran the experiments here. The notebooks are identical except for the data source. I created two so that I could run them in parallel.

The notebooks can be viewed here and here.

Right now I am:
a) Checking the implementation once more
b) Planning runs with more variations, such as differential learning rates

It would be great if somebody independently reviewed the change, or at least validated these results on their own machine :slight_smile:

In the notebook you just need to change the PATH to point to your data. You can play with different models and architectures.

Cats and Dogs (loss vs iterations, all with restarts):

CIFAR10 (loss vs iterations, all with restarts):

I see you’ve gotten to SGDW now too :slight_smile:

I looked at your code. Very impressed at how you’ve managed to navigate through what is almost certainly the most complex (and, perhaps, least well structured!) part of the library!

As you say, your changes need plenty of refactoring. Let me know if you need help with anything. I wonder if you can implement nearly all of this with callbacks. (As you’ll have noticed, SGDR is currently handled with callbacks.)

If we need to make changes to pytorch’s Adam class to allow us to extend it, we should send them a PR with those changes and then use those extensions, rather than copy and modify their Adam class. I’m sure they’ll be interested in merging AdamW into pytorch anyway, so we should endeavor to do it in a way that doesn’t lead to duplicate code with stuff in pytorch.

Great work! :smiley:

Yes, @jeremy, I tried with SGDW too and the results are as expected (trend matches with the paper) :slight_smile: Loss curves in image below.

I was trying to experiment with larger models to see how it generalizes. I haven’t run long runs yet but the results are as expected (i.e. better).

Yes I noticed you do your magic with callbacks and I was studying how to make the required variables reach the new callback :slight_smile:

Just checked, there are a couple of threads already on SGDR and AdamW/SGDW. So people are already working on it :slight_smile: Anyway I have asked the AdamW folks if they need any assistance. I think I can help them with testing/review to push their updates to master faster.

A few learnings while working on this:

  • PyTorch’s source code is remarkably similar to Torch’s (engineers usually try to put their own spin on things). For example, Torch’s adam.lua and PyTorch’s adam.py have exactly the same sequence of operations. The implementers have kept PyTorch true to its roots.
  • I now have a better appreciation for LR and WDS. Where you place the decimal point and how you decay them has a profound influence on convergence.
  • I started using Mendeley extensively! I came to know about it from your part2 video, and now I am an addict :slight_smile: It should be called Kindle for Researchers.

Thanks for motivating me to work on this! Learnt a lot :slight_smile:

CIFAR10 (loss vs iterations, all with restarts):

You are my hero @anandsaha! :slight_smile:

Is the wds you mention what we use for differential learning rates? I tried to expand the acronym, and I think it could be something like Weight Decay Schedule, or maybe Scaling? But I might be completely lost here :slight_smile:

I just tried :smile:

WDS stands for weight decays. Jeremy used the term in fit():

def fit(self, lrs, n_cycle, wds=None, **kwargs):

It is the regularization strength. It can be a scalar or vector just like LR.

They are passed to the optimizer in use through the weight_decay argument, just like lr. For example:

class torch.optim.SGD(params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
class torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

This paper innovates on how this parameter is applied and decayed with time.
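
As a rough illustration of "applied and decayed with time": as I read the paper, the weight decay is first normalized by the batch size, dataset size and number of epochs, and then scaled by the same cosine-annealing multiplier that drives the learning rate between restarts. The sketch below assumes that reading; the function names are hypothetical and the CIFAR10-like numbers are only an example, not the settings of any run in this thread.

```python
import math


def normalized_weight_decay(wd_norm, batch_size, n_train, n_epochs):
    """Scale a normalized weight decay by sqrt(b / (B * T)).

    b = batch size, B = number of training points per epoch,
    T = number of epochs (in the current restart cycle, if restarts are used).
    """
    return wd_norm * math.sqrt(batch_size / (n_train * n_epochs))


def cosine_multiplier(t_cur, t_i):
    """Cosine-annealing multiplier in [0, 1] within one restart cycle.

    t_cur = progress since the last restart, t_i = length of the cycle.
    """
    return 0.5 * (1.0 + math.cos(math.pi * t_cur / t_i))


# Example: wd_norm = 0.025, batch size 128, 50,000 training images, 100 epochs.
wd = normalized_weight_decay(0.025, batch_size=128, n_train=50_000, n_epochs=100)

# Effective decay at the start, middle and end of a 10-epoch cycle.
for t_cur in (0, 5, 10):
    print(t_cur, wd * cosine_multiplier(t_cur, 10))
```

The effective decay starts at its full normalized value right after a restart and falls to zero by the end of the cycle, in step with the learning rate.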

@jeremy I implemented the weight decay callback, here is the diff.

The user needs to pass the flag use_wd_schedule=True to fit(), along with the AdamW or SGDW optimizer.

As an afterthought, we could instead add the callback to the callbacks list automatically whenever the optimizer is AdamW or SGDW. I think that is more intuitive, since these optimizers cannot work any other way.

I also added a directory called optim under fastai and kept the altered adamw.py and sgdw.py there.

How do we proceed from here? I did the above mainly to understand the fastai code. We can wait for these features to come through PyTorch, or we can take these changes into fastai after your review.

@anandsaha that’s so cool!

Have a look at how the regularizer we use in the language model is implemented. I’m wondering if we could move the regularizers (including weight decay) out of the optimizers entirely, and instead update the weights directly in the training loop, i.e. more like the pseudo-code in the paper, rather than adding them to the loss at all. That would mean we wouldn’t have to change Adam or SGD to use the new approach! (i.e. we wouldn’t actually use the weight decay mechanisms built into pytorch at all, but instead would do the updates in the training loop in fastai’s fit().)

Does that make sense? This is getting more and more advanced - but also more and more awesome… :slight_smile:

Hi @jeremy, I tried your suggestion. It worked, and the code is leaner now without the extra adamw.py and sgdw.py :slight_smile: The idea is good because it now works with any optimizer out of the box (the paper addresses only SGD and Adam, but there is no harm in trying others if we get good results).

Now the user can pass the flag use_wd_schedule=True to fit(), and we will start decaying the weights as per the paper.

In on_batch_begin() I cache the old weights and compute the weight decay to apply.
In on_batch_end() I apply it (i.e. -weight_decay * old_weights) to the new weights for regularization.
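
For anyone following along, the two hooks can be sketched roughly as below. This is a toy version that uses a plain Python list in place of model parameters and a hand-written SGD step in place of the optimizer; the class name `DecoupledWeightDecay` and the helper `sgd_step` are made up for illustration and are not the actual fastai callback in the diff.

```python
class DecoupledWeightDecay:
    """Toy callback: decay weights outside the optimizer (illustrative only)."""

    def __init__(self, params, wd):
        self.params = params   # mutable list of floats standing in for tensors
        self.wd = wd           # weight decay to apply on each batch
        self.old = None

    def on_batch_begin(self):
        # Cache the weights as they were before the optimizer step.
        self.old = list(self.params)

    def on_batch_end(self):
        # Apply -wd * old_weights on top of whatever the optimizer did.
        for i, w_old in enumerate(self.old):
            self.params[i] -= self.wd * w_old


def sgd_step(params, grads, lr):
    """Plain SGD update with no weight decay of its own."""
    for i, g in enumerate(grads):
        params[i] -= lr * g


params = [1.0, -2.0]
cb = DecoupledWeightDecay(params, wd=0.01)

cb.on_batch_begin()
sgd_step(params, grads=[0.5, 0.5], lr=0.1)   # any optimizer could go here
cb.on_batch_end()
print(params)
```

Because the decay lives entirely in the callback, the optimizer itself is left with weight_decay at 0 and does not need to be modified.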

I am cleaning up the code but the current diff is here.

Thanks!

That’s terrific @anandsaha! You’re really doing an amazing job. Let me know when you think the code is as clean as you can make it for now, and I’ll take a look and see if I have anything to add. I’d really like to try to merge something into fastai this week, so we can show off Adam in next week’s class. I’m hoping to show the AdamW version in class, since I really think that original version is just plain wrong!

Hi @jeremy, I have generated the pull request: https://github.com/fastai/fastai/pull/46/
