What does the LayerOptimizer function do?


I've been reading the code and trying to make sense of it, and I have a few questions.

When calling the fit function:

def fit(self, lrs, n_cycle, wds=None, **kwargs):
        self.sched = None
        layer_opt = self.get_layer_opt(lrs, wds)
        self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)

get_layer_opt returns this:

LayerOptimizer(self.opt_fn, self.get_layer_groups(), lrs, wds)

Am I correct in assuming that get_layer_groups(self.precompute) returns the weights and outputs of all the layers (as groups), except the last layer, for resnet34? I assume these were calculated in the previous step by the ConvLearner.pretrained function.

My main question is: what does LayerOptimizer return? Is it calculating the learning rate?

This is the code for LayerOptimizer:

Init signature: LayerOptimizer(opt_fn, layer_groups, lrs, wds=None)
class LayerOptimizer():
    def __init__(self, opt_fn, layer_groups, lrs, wds=None):
        if not isinstance(layer_groups, (list,tuple)): layer_groups=[layer_groups]
        if not isinstance(lrs, Iterable): lrs=[lrs]
        if len(lrs)==1: lrs=lrs*len(layer_groups)
        if wds is None: wds=0.
        if not isinstance(wds, Iterable): wds=[wds]
        if len(wds)==1: wds=wds*len(layer_groups)
        self.layer_groups,self.lrs,self.wds = layer_groups,lrs,wds
        self.opt = opt_fn(self.opt_params())

    def opt_params(self):
        params = list(zip(self.layer_groups,self.lrs,self.wds))
        return [opt_params(*p) for p in params]

    def lr(self): return self.lrs[-1]

    def set_lrs(self, lrs):
        set_lrs(self.opt, lrs)
File:           ~/fastai/courses/dl1/fastai/layer_optimizer.py
Type:           type


No, it doesn’t. It simply returns the parameters (weights) of the model, not any outputs. Try calling it (or any other function you’re interested in) from the notebook to see what it returns.

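LayerOptimizer creates the actual torch.optim optimizer instance we use to optimize. It’s in the opt attribute. It does not calculate the learning rates; it attaches the learning rates and weight decays you pass in to the layer groups.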
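
To make the constructor logic above concrete, here is a minimal pure-Python sketch of the broadcasting it performs. It is not the library code: `broadcast` and `build_param_groups` are hypothetical names, the strings stand in for real weight tensors, and the dict keys assume that the module-level `opt_params` helper (not shown in the thread) builds the per-group dicts that torch.optim optimizers accept.

```python
from collections.abc import Iterable


def broadcast(values, n):
    """Repeat a scalar or single-element sequence so there is one value per group."""
    if not isinstance(values, Iterable):
        values = [values]
    values = list(values)
    if len(values) == 1:
        values = values * n
    return values


def build_param_groups(layer_groups, lrs, wds=None):
    """Pair each layer group with its own learning rate and weight decay."""
    if wds is None:
        wds = 0.0
    n = len(layer_groups)
    lrs, wds = broadcast(lrs, n), broadcast(wds, n)
    # Assumed shape of each entry: the per-group dict format used by torch.optim.
    return [
        {"params": group, "lr": lr, "weight_decay": wd}
        for group, lr, wd in zip(layer_groups, lrs, wds)
    ]


# Three hypothetical layer groups; strings stand in for weight tensors.
layer_groups = [["conv1.w", "conv2.w"], ["conv3.w"], ["fc.w"]]

# One learning rate per group (differential learning rates).
groups = build_param_groups(layer_groups, lrs=[1e-4, 1e-3, 1e-2])
for g in groups:
    print(g["params"], g["lr"], g["weight_decay"])

# A single learning rate is repeated for every group.
single = build_param_groups(layer_groups, lrs=0.01)
```

In the real class, a list like this is what gets handed to `opt_fn`, and the resulting optimizer is stored in `self.opt`.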

Hi Jeremy, I was trying to make sense of the set_lrs function. I looked at the source code and got more confused by the use of listify and zip_strict. Would you please explain how set_lrs works?
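
A rough sketch of what such a function plausibly does, under stated assumptions rather than as the library source: it assumes listify expands a scalar or single-element list to one value per parameter group, and that zip_strict is zip with a length check. The `_sketch` names and `FakeOptimizer` are hypothetical stand-ins; the only real-API detail relied on is that torch.optim optimizers expose `param_groups` as a list of dicts with an `lr` key.

```python
from collections.abc import Iterable


def listify_sketch(x, y):
    """Return x as a list with one entry per item in y (or y entries if y is an int)."""
    if not isinstance(x, Iterable):
        x = [x]
    x = list(x)
    n = y if isinstance(y, int) else len(y)
    if len(x) == 1:
        x = x * n
    return x


def zip_strict_sketch(left, right):
    """zip, but fail loudly instead of silently truncating on a length mismatch."""
    assert len(left) == len(right), "lengths differ"
    return zip(left, right)


def set_lrs_sketch(opt, lrs):
    """Write one learning rate into each of the optimizer's parameter groups."""
    lrs = listify_sketch(lrs, opt.param_groups)
    for pg, lr in zip_strict_sketch(opt.param_groups, lrs):
        pg["lr"] = lr


class FakeOptimizer:
    """Hypothetical stand-in that only mimics the param_groups attribute."""

    def __init__(self, n_groups):
        self.param_groups = [{"lr": 0.0} for _ in range(n_groups)]


opt = FakeOptimizer(3)

set_lrs_sketch(opt, 0.01)  # a scalar is repeated for all groups
uniform = [pg["lr"] for pg in opt.param_groups]

set_lrs_sketch(opt, [1e-4, 1e-3, 1e-2])  # one value per group
per_group = [pg["lr"] for pg in opt.param_groups]
```

Under these assumptions, passing two learning rates for three groups would trip the length check instead of silently leaving the last group unchanged.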