HELP understanding fastai callbacks

Can someone explain why, in the code below, `loss.backward()` isn't called on the validation data? `loss.backward()` isn't inside any conditional statement, so how come it doesn't run during validation?

def one_batch(xb, yb, cb):
    """Run one batch through the model and, if the callbacks allow it, train on it.

    Every stage is gated by a callback hook on ``cb``. In particular, the
    backward pass and the optimizer step are skipped whenever
    ``cb.after_loss(loss)`` returns a falsy value. In the fastai course
    notebook (04_callbacks), ``CallbackHandler.after_loss`` returns
    ``self.in_train``, which ``begin_validate`` sets to ``False``. That early
    return, not ``torch.no_grad()``, is what keeps ``loss.backward()`` from
    running on validation batches.

    Args:
        xb: input batch, passed to ``cb.learn.model``.
        yb: target batch, passed to ``cb.learn.loss_func``.
        cb: callback handler exposing ``learn`` and the hook methods
            ``begin_batch``, ``after_loss``, ``after_backward`` and ``after_step``.
    """
    if not cb.begin_batch(xb, yb):
        return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    # Validation (or any callback vetoing training) stops here, before backward().
    if not cb.after_loss(loss):
        return
    loss.backward()
    if cb.after_backward():
        cb.learn.opt.step()
    if cb.after_step():
        cb.learn.opt.zero_grad()

def all_batches(dl, cb):
    """Feed every ``(xb, yb)`` pair yielded by ``dl`` through ``one_batch``.

    Stops early as soon as ``cb.do_stop()`` returns a truthy value after a batch.

    Args:
        dl: iterable of ``(xb, yb)`` batches.
        cb: callback handler forwarded to ``one_batch``.
    """
    for xb, yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop():
            return

def fit(epochs, learn, cb):
    """Train for ``epochs`` epochs, running a validation pass after each one.

    Each stage is gated by the callback handler ``cb``: ``begin_fit`` can
    cancel the whole run, ``begin_epoch`` can skip an epoch, ``begin_validate``
    can skip validation, and ``do_stop``/``after_epoch`` can end training early.

    Args:
        epochs: number of epochs to run.
        learn: learner passed to ``cb.begin_fit``; also the source of the
            training and validation data loaders.
        cb: callback handler.
    """
    if not cb.begin_fit(learn):
        return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch):
            continue
        # NOTE(review): the data-loader arguments were missing from this
        # snippet; learn.data.train_dl / learn.data.valid_dl follow the fastai
        # course notebook -- confirm against your Learner.
        all_batches(learn.data.train_dl, cb)

        if cb.begin_validate():
            # no_grad() only stops autograd from recording the graph; it does
            # not turn loss.backward() into a no-op (backward on such a loss
            # raises RuntimeError), so the callbacks must also skip backward.
            with torch.no_grad():
                all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch():
            break

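Because the validation batches are run inside `with torch.no_grad():`, `no_grad()` doesn't allow any gradients to be computed, even if `backward()` is called.

(Note: as the follow-up below shows, calling `loss.backward()` on a loss computed under `no_grad()` actually raises a `RuntimeError` rather than silently doing nothing. In the fastai course notebook (`04_callbacks`), the backward pass is skipped for a different reason. `cb.after_loss(loss)` returns `False` during validation, because `CallbackHandler.after_loss` starts from `self.in_train` and `begin_validate` sets that to `False`. So `one_batch` returns before it ever reaches `loss.backward()`.)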

1 Like

@bluesky314 But take a look at this simpler version of `fit`:

def one_batch(xb, yb, learn):
    """Compute the loss on one batch and, when autograd is enabled, take an optimizer step.

    ``fit`` runs validation inside ``torch.no_grad()``. A loss computed there
    has no ``grad_fn``, so an unconditional ``loss.backward()`` raises
    ``RuntimeError: element 0 of tensors does not require grad and does not
    have a grad_fn``. ``no_grad`` does not skip ``backward()`` for you, so the
    training step is guarded explicitly here.

    Args:
        xb: input batch, passed to ``learn.model``.
        yb: target batch, passed to ``learn.loss_func``.
        learn: object exposing ``model``, ``loss_func`` and ``opt``.
    """
    loss = learn.loss_func(learn.model(xb), yb)
    if torch.is_grad_enabled():
        loss.backward()
        learn.opt.step()
        learn.opt.zero_grad()


def all_batch(dl, learn):
    """Run ``one_batch`` on every ``(xb, yb)`` pair yielded by ``dl``.

    Args:
        dl: iterable of ``(xb, yb)`` batches.
        learn: learner forwarded to ``one_batch``.
    """
    for xb, yb in dl:
        one_batch(xb, yb, learn)

def fit(epoch, learn):
    """Train for ``epoch`` epochs, running a validation pass after each one.

    Args:
        epoch: number of epochs to run (the name is kept for caller compatibility).
        learn: object exposing ``model``, ``loss_func``, ``opt`` and the data
            loaders. NOTE(review): assumes ``learn.model`` is a
            ``torch.nn.Module`` and the loaders live at ``learn.data.train_dl`` /
            ``learn.data.valid_dl`` as in the fastai course notebook -- confirm.
    """
    for _ in range(epoch):
        # Call train()/eval(). Assigning ``learn.model.train = True`` replaces
        # the nn.Module method with a bool and never changes ``model.training``.
        learn.model.train()
        all_batch(learn.data.train_dl, learn)

        learn.model.eval()
        # no_grad() stops graph recording; one_batch must itself avoid calling
        # backward() here, or it raises RuntimeError.
        with torch.no_grad():
            all_batch(learn.data.valid_dl, learn)

When I tried to fit this, I got the following error:

RuntimeError Traceback (most recent call last)
----> 1 fit(1, learn)

in fit(epoch, learn)
16 learn.model.train = False
---> 17 with torch.no_grad(): all_batch(,learn)

in all_batch(dl, learn)
8 def all_batch(dl, learn):
9 for xb,yb in dl:
---> 10 one_batch(xb,yb, learn)
12 def fit(epoch, learn):

in one_batch(xb, yb, learn)
2 loss = learn.loss_func(learn.model(xb), yb)
----> 4 loss.backward()
5 learn.opt.step()
6 learn.opt.zero_grad()

D:\Program_files\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
105 products. Defaults to False.
106 """
--> 107 torch.autograd.backward(self, gradient, retain_graph, create_graph)
109 def register_hook(self, hook):

D:\Program_files\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
91 Variable._execution_engine.run_backward(
92 tensors, grad_tensors, retain_graph, create_graph,
---> 93 allow_unreachable=True) # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I'm using `torch.no_grad()` as well, so how come it's still calling `loss.backward()`?