Can someone tell me how, in the code below, loss.backward() isn't called on the validation data? I don't see loss.backward() inside any conditional statement, so how come it doesn't run when validating?
import torch

def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb, yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    # a callback can veto here: if after_loss returns False,
    # one_batch returns before backward() is ever reached
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    for xb, yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)
        # validation runs under no_grad, so no autograd graph is built
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()
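
For what it's worth, my best guess is that the skipping happens inside the callback object rather than in this loop. Below is a minimal sketch of a cb that satisfies the interface the loop calls; the class name and the in_train flag are hypothetical, purely to illustrate how after_loss returning False during validation would make one_batch() return before it ever reaches loss.backward():

class Callback():
    def __init__(self): self.in_train = True
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def begin_epoch(self, epoch):
        self.in_train = True    # each epoch starts in the training phase
        return True
    def begin_validate(self):
        self.in_train = False   # switch to the validation phase
        return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):
        # returning False here makes one_batch() return early,
        # so loss.backward() never runs on validation batches
        return self.in_train
    def after_backward(self): return True
    def after_step(self): return True
    def after_epoch(self): return True
    def after_fit(self): return True
    def do_stop(self): return False

If that's right, backward() is avoided purely by after_loss returning False when not training. Something like this would also be necessary, because the validation batches run under torch.no_grad(): a loss computed there has no autograd graph, so calling loss.backward() on it would raise a RuntimeError. Is that the intended mechanism?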