How to create a closure within fastai for an auto-tuning optimizer?

I’m trying to integrate SLS (stochastic line search) with fastai 1.0.5x and hitting an intermittent issue where the model is not being updated… oddly, one time I got it running properly, but when I went to set up an 80-epoch run it reverted back to random results.

Anyway, SLS needs to be able to call the loss function itself and to control when loss.backward() is called.
I made a closure for it in basic_train.py, but it’s clearly not right - can someone advise a better way to hook up a closure in fastai?
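For context, this is the plain-PyTorch closure pattern that optimizers like LBFGS expect - a minimal sketch outside of fastai, with a toy model and batch just for illustration:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)                        # toy model, just for illustration
opt = torch.optim.LBFGS(model.parameters(), lr=0.1)   # any optimizer whose step() takes a closure
xb, yb = torch.randn(16, 10), torch.randn(16, 1)      # one toy batch

def closure():
    # the optimizer may call this several times per step, at different parameter values
    opt.zero_grad()
    loss = F.mse_loss(model(xb), yb)
    loss.backward()
    return loss

opt.step(closure)

Here’s my attempt inside loss_batch: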

def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
               cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = ifnone(cb_handler, CallbackHandler())
    if not is_listy(xb): xb = [xb]
    if not is_listy(yb): yb = [yb]
    out = model(*xb)
    out = cb_handler.on_loss_begin(out)

    #if not loss_func: return to_detach(out), to_detach(yb[0])
    loss = loss_func(out, *yb)

    def closure():
        out = model(*xb)  #? needed
        loss = loss_func(out, *yb)
        print("closure called, loss = ", loss)
        return loss

    if opt is not None:
        loss, skip_bwd = cb_handler.on_backward_begin(loss)
        #if not skip_bwd: loss.backward()  #TODO - SLS calls backward, not cb code
        if not cb_handler.on_backward_end():
            opt.step(closure)

        if not cb_handler.on_step_end():
            opt.zero_grad()

    #recalc for metrics to update (?)
    out = model(*xb)
    loss = loss_func(out, *yb)

    return loss.detach().cpu()

I had to leave some of the callback calls in because fastai would complain about missing items in the dict if I did not.
I also recompute out and loss one last time at the end to give fastai the latest metrics, since SLS may have evaluated the loss and called loss.backward() many times inside its step.

Anyway, any guidance on how to better set up a dynamic closure here would be much appreciated. SLS looks like a big improvement for optimizers, but it (and PAL, Ali-G, etc.) all require access to a closure, so getting this right is key to implementing any of them.

Thanks much!


The core issue seems to be that when SLS calls loss.backward() within its step function, the gradients aren’t actually being propagated… so with no effective loss.backward() call, it just cycles around and around with no progress.

I’ll see if adding an optional loss.backward() call inside the closure itself gives it access… but any feedback on how to construct a true closure would be much appreciated!
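Something like this is what I have in mind - just a sketch, reusing the names from loss_batch above, with a hypothetical do_backward flag (whether SLS actually wants the backward inside the closure is exactly what I’m testing):

def closure(do_backward=False):
    out = model(*xb)
    loss = loss_func(out, *yb)
    if do_backward:
        loss.backward()   # give the optimizer fresh gradients if it doesn't call backward itself
    return loss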


In one of the more bizarre bugs to date, the issue seemed to boil down to the SLS utilities using @contextlib… it looked like things were getting garbage-collected and thus no updates were happening… but if you added print statements it would work, because the print statements kept things from being cleaned up.

Anyway, I removed the contextlib and the closure is working properly now.

I’m testing away with SLS and hope to have some news soon (worked really well on MNIST, testing on ImageWoof now).


Oh, is this the one you are trying to integrate with Ranger?
Good job!

This was really helpful for understanding what the closure is. As I currently understand it, PyTorch wants you to pass the optimizer’s step a closure that returns the loss. Here is how I think it can be done with fastai’s fit system.

def _step_LBFGS(self): self.opt.step(lambda: self.loss)

and then replace the learner’s _step function like so (learn has to already be defined in this situation):

learn._step = partial(_step_LBFGS,learn)

and at this point, it works :)

Hopefully this is helpful to somebody else who, like me, was trying to figure out what LossClosure was and how to implement it. Thanks for the great starting point!

Here was my full implementation for fastai v2:

@delegates(torch.optim.LBFGS)                                                              #1 pass LBFGS kwargs through
def LBFGS_opt(params, **kwargs): return OptimWrapper(torch.optim.LBFGS(params, **kwargs))  #2 wrap LBFGS for fastai
def _step_LBFGS(self): self.opt.step(lambda: self.loss)                                    #3 step with a closure returning the batch loss

model = MySimpleModel()
loss_func = mse  #F.mse_loss
𝜂 = 1
partial_learner = partialler(Learner, dls_normal, model, loss_func, cbs=[OptimizerInsight])
learn = partial_learner(LBFGS_opt, lr=𝜂)
learn._step = partial(_step_LBFGS, learn)                                                  #4 patch the Learner's step
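One note on #3: lambda: self.loss just hands back the loss fastai already computed for the batch, which is fine here, but a line-search optimizer like SLS re-evaluates the loss at new parameter values inside step(), so its closure needs to redo the forward pass. A rough sketch, assuming the Learner’s xb/yb/model/loss_func are populated for the current batch and that the wrapped optimizer’s step() accepts a closure, as above:

def _step_reeval(self):
    def closure():
        pred = self.model(*self.xb)            # recompute at the current (possibly updated) weights
        return self.loss_func(pred, *self.yb)
    self.opt.step(closure)

learn._step = partial(_step_reeval, learn)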