Is lr_find doing something similar to a normal training pass?

It seems lr_find would have to do full forward & backward passes to compute the gradients and test the effect of different learning rates on the loss?

I understand its usefulness and how to use its results, but I don’t get what it’s doing that’s different from doing single-epoch passes in a loop while incrementing the lr to find the optimum. It seems like a chicken-and-egg situation?


Hi!

If you take a look at the source code of lr_find, you have:

def lr_find(self, start_lr=1e-5, end_lr=10, wds=None, linear=False, **kwargs):
    self.save('tmp')
    layer_opt = self.get_layer_opt(start_lr, wds)
    self.sched = LR_Finder(layer_opt, len(self.data.trn_dl), end_lr, linear=linear)
    self.fit_gen(self.model, self.data, layer_opt, 1, **kwargs)
    self.load('tmp')

What is important to notice are the save and load lines.

When you run lr_find, the stopping condition is `self.stop_dv and (math.isnan(loss) or loss > self.best*4)`, i.e. it stops when your loss is NaN or more than 4x bigger than your best loss value so far. So basically, it checks for the moment when your learning rate gets too high and your network weights start diverging.
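That condition can be sketched as a tiny standalone check (a hypothetical helper for illustration, not fastai's actual API; `best_loss` stands for the lowest loss seen so far in the sweep):

```python
import math

# Hedged sketch of the divergence check applied after each batch:
# stop when the loss is NaN or more than 4x the best loss seen so far.
def should_stop(loss, best_loss, stop_div=True):
    return stop_div and (math.isnan(loss) or loss > best_loss * 4)
```

With `stop_div=False` the sweep runs through the whole learning-rate range regardless of divergence.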

The problem is that, once this happens, it is almost impossible for your network to repair the damage done to the weights and come back to a converging behavior, which is why we have the save/load in the lr_finder.
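The save/restore idea can be sketched like this (an illustrative helper, not fastai's code; `model` is any object exposing `state_dict()`/`load_state_dict()`, as PyTorch modules do):

```python
import copy

# Hedged sketch: snapshot the weights, run a sweep that may wreck them,
# then restore the snapshot no matter what happened.
def with_weights_restored(model, sweep):
    saved = copy.deepcopy(model.state_dict())   # snapshot before the LR sweep
    try:
        sweep(model)                            # sweep may drive the weights to divergence
    finally:
        model.load_state_dict(saved)            # always restore the original weights
```

The `finally` guarantees the restore runs even if the sweep raises, so training can continue from untouched weights.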

So you certainly do not want to run a learning rate test while you are training, which is why the two need to be separate.

Hope that helps :slight_smile:

Ah, thanks Nathan, that explains it, but I see a different version of the source code for lr_find when I go through doc(lr_find):

def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):
    "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges."
    start_lr = learn.lr_range(start_lr)
    start_lr = np.array(start_lr) if is_listy(start_lr) else start_lr
    end_lr = learn.lr_range(end_lr)
    end_lr = np.array(end_lr) if is_listy(end_lr) else end_lr
    cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
    epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
    learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
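As a rough sketch of how such a sweep steps the learning rate from `start_lr` to `end_lr` over `num_it` iterations (assuming the default exponential schedule; `lr_schedule` is an illustrative helper, not a fastai function):

```python
# Hedged sketch: multiply the learning rate by a constant factor each
# iteration so it covers start_lr..end_lr geometrically in num_it steps.
def lr_schedule(start_lr, end_lr, num_it):
    mult = (end_lr / start_lr) ** (1 / num_it)
    return [start_lr * mult ** i for i in range(num_it + 1)]
```

With the defaults above (1e-7 to 10 over 100 iterations), each batch uses a learning rate roughly 1.2x the previous one.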

Oh yeah, I was probably running an older version. If you go to the LRFinder callback, you will see the save in on_train_begin and the load in on_train_end, so the idea is still the same.
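In callback form, that bracketing looks roughly like this (an illustrative sketch, not fastai's actual LRFinder class; `learn` is assumed to expose `save`/`load` as Learner does):

```python
# Hedged sketch of the callback idea: save weights when training starts,
# restore them when the LR sweep ends.
class LRFinderSketch:
    def __init__(self, learn):
        self.learn = learn

    def on_train_begin(self):
        self.learn.save('tmp')   # snapshot weights before the sweep

    def on_train_end(self):
        self.learn.load('tmp')   # restore them once the sweep is done
```

The sweep itself runs between the two hooks, so the diverged weights never survive past the test.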

Thank you, found it.