I’m trying to integrate SLS (stochastic line search) with FastAI 1.05x and hitting a random issue where the model is not being updated…oddly, one time I got it running properly, then went to set up an 80 epoch run and it reverted back to random results.
Anyway, SLS requires the ability to re-evaluate the loss function and to control when loss.backward() is called.
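For context, this is the same closure pattern that PyTorch’s built-in LBFGS optimizer uses: `step()` receives a callable it can invoke repeatedly to re-evaluate the loss during the line search. A minimal, self-contained sketch with a toy model (the model/data names here are just illustrative, not from fastai):

```python
import torch
from torch import nn

torch.manual_seed(0)

# toy data and model, purely to illustrate the closure pattern
model = nn.Linear(3, 1)
loss_func = nn.MSELoss()
xb, yb = torch.randn(8, 3), torch.randn(8, 1)

# LBFGS, like SLS, needs a closure it can call several times per step
opt = torch.optim.LBFGS(model.parameters(), max_iter=5)

def closure():
    opt.zero_grad()
    loss = loss_func(model(xb), yb)
    loss.backward()   # LBFGS expects the closure itself to call backward
    return loss

loss = opt.step(closure)  # step() may invoke closure() multiple times
print(loss.item())
```

One difference worth noting: as I understand the SLS code, its `step()` calls `loss.backward()` internally, so an SLS closure only needs to run the forward pass and return the loss, whereas LBFGS expects the closure to do the backward pass itself.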
I made a closure for it in basic_train.py, but it’s clearly not right - can someone advise a better way to hook up a closure in FastAI?
```python
def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
               cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = ifnone(cb_handler, CallbackHandler())
    if not is_listy(xb): xb = [xb]
    if not is_listy(yb): yb = [yb]
    out = model(*xb)
    out = cb_handler.on_loss_begin(out)
    #if not loss_func: return to_detach(out), to_detach(yb[0])
    loss = loss_func(out, *yb)

    def closure():
        # the forward pass must be re-run here: SLS evaluates the loss at
        # trial parameter values during the line search
        out = model(*xb)
        loss = loss_func(out, *yb)
        print("closure called, loss = ", loss)
        return loss

    if opt is not None:
        loss, skip_bwd = cb_handler.on_backward_begin(loss)
        #if not skip_bwd: loss.backward() #TODO - SLS calls backward, not cb code
        if not cb_handler.on_backward_end():
            opt.step(closure)
        if not cb_handler.on_step_end():
            opt.zero_grad()
        # recompute so the metrics reflect the parameters SLS settled on;
        # no_grad avoids building a graph we never backprop through
        with torch.no_grad():
            out = model(*xb)
            loss = loss_func(out, *yb)
    return loss.detach().cpu()
```
I had to leave some of the callback calls in b/c FastAI would complain about missing items in the dict if I did not.
I also recomputed out and loss one last time at the end in order to provide FastAI with the latest metrics, since SLS could have evaluated the loss (and called loss.backward()) many times within opt.step().
Anyway, if you have guidance on how to better set up a dynamic closure here, that would be much appreciated. SLS looks like a huge improvement for optimizers, but it (and PAL, Ali-G, etc.) all require access to a closure, so having this would be key to implementing any of these.
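As a general shape, the per-batch closure could come from a small factory, so any closure-based optimizer gets a fresh closure bound to the current batch. A hedged sketch (make_closure is my own name, not a fastai or SLS API):

```python
import torch
from torch import nn

def make_closure(model, loss_func, xb, yb, backward=False):
    """Build a fresh closure bound to the current batch.

    backward=False suits optimizers like SLS that call loss.backward()
    themselves inside step(); backward=True suits e.g. LBFGS.
    """
    def closure():
        # re-run the forward pass: the optimizer mutates the parameters
        # between closure calls during its line search
        loss = loss_func(model(xb), yb)
        if backward:
            loss.backward()
        return loss
    return closure

# toy usage; plain SGD stands in for a closure-based optimizer, since every
# stock PyTorch optimizer's step() accepts an optional closure
torch.manual_seed(0)
model = nn.Linear(3, 1)
xb, yb = torch.randn(8, 3), torch.randn(8, 1)
closure = make_closure(model, nn.MSELoss(), xb, yb, backward=True)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt.zero_grad()
loss = opt.step(closure)
print(loss.item())
```

The point of the factory is that loss_batch only has to build and hand over the closure; whether backward happens inside the closure or inside the optimizer’s step() becomes a per-optimizer flag rather than a change to loss_batch itself.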
Thanks much!