Parameter scheduler from nb_09 with multiple schedules?

Hi there, I am hoping someone can help me clarify the following callback, as well as my understanding of Jeremy’s Optimizer class.

The callback:

```python
# Callback is the callback base class from the course notebooks
class ParamScheduler(Callback):
    def __init__(self, pname, sched_func):
        self.pname,self.sched_func = pname,sched_func

    def set_param(self):
        # every parameter group gets the same scheduled value
        for h in self.opt.hypers:
            h[self.pname] = self.sched_func(self.n_epochs/self.epochs)

    def begin_batch(self):
        if self.in_train: self.set_param()
```
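For context, `sched_func` is called with `self.n_epochs/self.epochs`, which (assuming `n_epochs` is advanced fractionally during training, as in the course Runner) is the fraction of training completed, running from 0 to 1. A minimal sketch of such a schedule function, assuming a half-cosine shape; `make_cos_sched` is a hypothetical helper name, not the notebook's:

```python
import math

def make_cos_sched(start, end):
    """Hypothetical helper: returns a schedule mapping pos in [0, 1]
    to a value that follows a half-cosine from `start` to `end`."""
    def sched(pos):
        # pos=0 -> start, pos=0.5 -> midpoint, pos=1 -> end
        return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2
    return sched

lr_sched = make_cos_sched(1e-3, 1e-5)
print(lr_sched(0.0), lr_sched(0.5), lr_sched(1.0))
```

Passed as `ParamScheduler('lr', lr_sched)`, the callback would write this same value into every dict in `opt.hypers` at each training batch.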

The Optimizer:

```python
# listify, maybe_update, get_defaults and compose are helpers from the course notebooks
class Optimizer():
    def __init__(self, params, steppers, **defaults):
        self.steppers = listify(steppers)
        maybe_update(self.steppers, defaults, get_defaults)
        # might be a generator
        self.param_groups = list(params)
        # ensure params is a list of lists
        if not isinstance(self.param_groups[0], list): self.param_groups = [self.param_groups]
        # one independent dict of hyperparameters per parameter group
        self.hypers = [{**defaults} for p in self.param_groups]

    def grad_params(self):
        return [(p,hyper) for pg,hyper in zip(self.param_groups,self.hypers)
            for p in pg if p.grad is not None]

    def zero_grad(self):
        for p,hyper in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        for p,hyper in self.grad_params(): compose(p, self.steppers, **hyper)
```

My understanding is that Optimizer.hypers is a list of dictionaries, one per parameter group, holding that group's hyperparameter values (beta, learning rate, etc.). ParamScheduler, however, loops over all of these dictionaries and applies the same scheduling function to each one. So the only way I can see to use a different schedule for each parameter group is to change ParamScheduler so that it accepts a list of schedules (the same length as Optimizer.hypers) and applies the i-th schedule to the i-th group. Or am I misunderstanding the code here?
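The change proposed above might look roughly like the sketch below. It rests on assumptions: `MultiParamScheduler` is a hypothetical name, it takes the optimizer directly and exposes a `set_param(pos)` method instead of using the course's `Callback`/`begin_batch` machinery, and a `types.SimpleNamespace` with a `hypers` list stands in for the optimizer so it runs on its own. Passing a single function broadcasts it to every group, mirroring the current behaviour.

```python
import math
from types import SimpleNamespace

def sched_lin(start, end):
    # Linear schedule from start (pos=0) to end (pos=1)
    return lambda pos: start + pos * (end - start)

def sched_cos(start, end):
    # Half-cosine schedule from start (pos=0) to end (pos=1)
    return lambda pos: start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

class MultiParamScheduler:
    """Hypothetical ParamScheduler variant: one schedule per parameter group."""
    def __init__(self, opt, pname, sched_funcs):
        if callable(sched_funcs):
            sched_funcs = [sched_funcs]
        sched_funcs = list(sched_funcs)
        # A single function for every group reproduces the original behaviour
        if len(sched_funcs) == 1:
            sched_funcs = sched_funcs * len(opt.hypers)
        if len(sched_funcs) != len(opt.hypers):
            raise ValueError("need one schedule per parameter group")
        self.opt, self.pname, self.sched_funcs = opt, pname, sched_funcs

    def set_param(self, pos):
        # pos plays the role of n_epochs/epochs: fraction of training done
        for f, h in zip(self.sched_funcs, self.opt.hypers):
            h[self.pname] = f(pos)

opt = SimpleNamespace(hypers=[{"lr": 0.0}, {"lr": 0.0}])
sched = MultiParamScheduler(opt, "lr", [sched_lin(1e-4, 1e-5), sched_cos(1e-2, 1e-3)])
sched.set_param(0.5)
print(opt.hypers)
```

Zipping the schedules with `opt.hypers` keeps the i-th schedule tied to the i-th group, and the length check catches a mismatched list early rather than silently skipping groups.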

Lastly, I was curious why the optimizer keeps its parameters/parameter groups in lists (or lists of lists) rather than in a generator like the one returned by nn.Module.parameters(). I should probably check for myself, but if anyone knows off the top of their head: if I add the line

```python
self.param_groups = [(x for x in pg) for pg in self.param_groups]
```

after the check that param_groups is a list of lists, will it (i) break any of the code that follows, and (ii) save memory by turning the groups into generators? I'd really appreciate any comments/insights, thanks in advance!
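On (i), one thing worth checking is that a generator can only be iterated once, while `grad_params` walks every group each time it is called (once in `zero_grad` and once in `step`). A plain-Python sketch with made-up data (the integers stand in for parameters, and the grad check is omitted) shows what the second pass sees:

```python
def grad_params(param_groups, hypers):
    # Same comprehension shape as Optimizer.grad_params (grad check omitted)
    return [(p, h) for pg, h in zip(param_groups, hypers) for p in pg]

hypers = [{"lr": 0.1}, {"lr": 0.01}]

as_lists = [[1, 2], [3]]
as_gens = [(x for x in pg) for pg in as_lists]  # the proposed change

print(len(grad_params(as_lists, hypers)), len(grad_params(as_lists, hypers)))  # 3 3
print(len(grad_params(as_gens, hypers)), len(grad_params(as_gens, hypers)))    # 3 0
```

After the first pass every generator is exhausted, so later calls would find no parameters and `step` would silently do nothing. As for (ii), a list of parameters stores references to the existing tensors rather than copies, so the memory saved would only be the lists' own pointer storage.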

Well spotted! We just implemented discriminative LR on the weekend, so you’ll find that’s exactly what we now have! 🙂