Deepen understanding of callbacks by reimplementing the Learning Rate Finder

We learned how to use dynamic learning rates in lesson 9, so now we can reproduce LR_Find (the learning-rate finder) to deepen our understanding of callbacks.

The following implementation is based on the fastai v1 source:

# Path, torch and partial are standard imports; Callback, Runner, sched_exp and friends
# come from the course notebooks and are assumed to already be in scope.
from pathlib import Path
from functools import partial
import torch

class LRfinder(Callback):
    def __init__(self, start_lr=1e-5, end_lr=10):
        self.path = Path('models')
        self.path.mkdir(parents=True, exist_ok=True)
        self.sched_func = sched_exp(start_lr, end_lr)  # exponential learning-rate schedule

    def begin_fit(self):
        self.save('temp')    # save the initial weights so they can be restored afterwards
        self.stop = False
        self.best_loss = 0   # overwritten on the first training batch (see after_batch)

    def begin_batch(self):
        if self.stop: return True            # returning True skips the batch in the Runner
        if self.in_train: self.set_param()   # change the lr at the start of every training batch

    def after_batch(self):
        if not self.in_train: return         # only track the training loss
        if self.n_iter == 1 or self.loss < self.best_loss: self.best_loss = self.loss
        if self.loss > 4 * self.best_loss: self.stop = True  # stop once the loss blows up

    def set_param(self):
        for pg in self.opt.param_groups:
            pg['lr'] = self.sched_func(self.n_epochs / self.epochs)  # position in training, 0..1

    def after_fit(self):
        self.load('temp')    # restore the original weights

    def save(self, name):
        torch.save(self.model.state_dict(), self.path/f"{name}.pth")

    def load(self, name):
        self.model.load_state_dict(torch.load(self.path/f"{name}.pth"))
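
Note that `sched_exp` above is taken from the course's annealing notebook and is not shown here; a minimal sketch of an exponential scheduler between `start_lr` and `end_lr` would look roughly like this (names only for illustration):

# Minimal sketch of the exponential scheduler assumed above:
# sched_exp(start_lr, end_lr) returns a function of the position pos in [0, 1].
def sched_exp(start_lr, end_lr):
    def _inner(pos): return start_lr * (end_lr / start_lr) ** pos
    return _inner

The callback is then wrapped in a small helper that builds the Runner with the usual callbacks:
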
def lr_find(learn, start_lr=1e-3, end_lr=1.):
    cbfs = [Recorder, partial(AvgStatsCallback, accuracy), CudaCallback,
            partial(BatchTransformXCallback, mnist_view),
            partial(LRfinder, start_lr=start_lr, end_lr=end_lr)]
    run = Runner(cb_funcs=cbfs)
    run.fit(3, learn)
    run.recorder.plot()  # requires Recorder to plot the losses against the lrs (see the sketch below)
# TEST -- the Runner built inside lr_find supplies the callbacks
learn, run = get_learn_run(nfs, data, 1e-1, conv_layer, uniform=True)
lr_find(learn)
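
For completeness, `run.recorder.plot()` assumes the course's `Recorder` callback records the learning rates and losses and can plot one against the other; a minimal sketch of the relevant pieces (close to the lesson notebooks, not necessarily line for line):

# Sketch of the Recorder assumed by lr_find: record lrs and losses, plot losses vs. lrs.
import matplotlib.pyplot as plt

class Recorder(Callback):
    def begin_fit(self): self.lrs, self.losses = [], []

    def after_batch(self):
        if not self.in_train: return
        self.lrs.append(self.opt.param_groups[-1]['lr'])   # lr of the last param group
        self.losses.append(self.loss.detach().cpu())

    def plot(self, skip_last=0):
        # plot the loss against the learning rate on a log scale
        losses = [o.item() for o in self.losses]
        n = len(losses) - skip_last
        plt.xscale('log')
        plt.plot(self.lrs[:n], losses[:n])
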

The result: (plot of the loss against the learning rate)
Corrections of the shortcomings and suggestions for a better implementation are welcome. :kissing_heart:


In fastai, we use a moving average of the loss to denoise it a bit. That’s why you have those bumps :wink:
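
For reference, that smoothing is an exponentially weighted moving average of the loss with a debiasing correction; a rough sketch of the idea (the function name and `beta` value are only illustrative, not fastai's actual API):

# Hypothetical smoothing of the recorded losses with an exponentially weighted moving average.
def smooth_losses(losses, beta=0.98):
    avg, smoothed = 0., []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        smoothed.append(avg / (1 - beta ** i))  # debias the running average
    return smoothed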
