Adding Weight Decay To BasicOptimizer

I just finished 08_collab, and I’m trying to take the ideas presented in FastAI and re-implement them in plain PyTorch to understand them better.

Right now I’m adding in weight decay.

From the notebook:

    parameters.grad += wd * 2 * parameters

So I think that translates to this in the code below:

    def step(self, *args, **kwargs):
        for p in self.params:
            p.grad.data += self.wd * p.data   # add the weight decay term to the gradient
            p.data -= p.grad.data * self.lr   # then take the usual SGD step

where the “2” in “wd * 2” is just folded into the constant, i.e. my wd plays the role of the notebook’s 2 * wd.
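To convince myself the 2 really is just a scaling of the constant: the wd * 2 * parameters term is the gradient of an L2 penalty wd * (parameters ** 2).sum() added to the loss, so a quick autograd check (my own throwaway snippet, not from the notebook) looks like:

    import torch

    # check that d/dp [ wd * sum(p**2) ] == 2 * wd * p, i.e. the "2" is just
    # a constant factor that can be absorbed into wd
    p = torch.randn(5, requires_grad=True)
    wd = 0.1

    (wd * (p ** 2).sum()).backward()
    print(torch.allclose(p.grad, 2 * wd * p.detach()))  # True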

Can someone spot-check this and confirm it’s correct? I mean… it runs, but it’s hard to tell whether it’s working as intended.

    import torch

    class BasicOptim(torch.optim.Optimizer):
        def __init__(self, params, lr=0.001, wd=0.1):
            # materialize params first: model.parameters() is a generator, and the
            # base-class __init__ would otherwise consume it before list(params) runs
            params = list(params)
            defaults = dict(lr=lr, momentum=0, dampening=0, weight_decay=wd)
            super().__init__(params, defaults)

            self.params = params
            self.lr = lr
            self.wd = wd

        def step(self, *args, **kwargs):
            for p in self.params:
                p.grad.data += self.wd * p.data   # weight decay: add wd * p to the gradient
                p.data -= p.grad.data * self.lr   # plain SGD step with the decayed gradient

        def zero_grad(self, *args, **kwargs):
            for p in self.params: p.grad = None   # dropping the grads lets autograd re-create them