I am getting a bit confused about how the classes work when dealing with the CallbackHandler.
The code of the handler and a sample callback are:
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        self.learn, self.in_train = learn, True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res
and
class TimeCheck(Callback):
    def begin_fit(self, learn):
        self.learn = learn
        self.epoch_counter = 1
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        print(f'Epoch {self.epoch_counter} started at {time.strftime("%H:%M:%S", time.gmtime())}')
        self.epoch_counter += 1
        return True
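To see how the handler drives a callback like TimeCheck, here is a toy, self-contained driver. The stripped-down `Handler` class and the two-epoch loop are illustrative assumptions for demonstration, not the real fastai fit:

```python
import time

class TimeCheck():
    def begin_fit(self, learn):
        self.learn = learn
        self.epoch_counter = 1
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        print(f'Epoch {self.epoch_counter} started at '
              f'{time.strftime("%H:%M:%S", time.gmtime())}')
        self.epoch_counter += 1
        return True

class Handler():
    # minimal stand-in for CallbackHandler, just enough to drive TimeCheck
    def __init__(self, cbs): self.cbs = cbs
    def begin_fit(self, learn):
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res
    def begin_epoch(self, epoch):
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

cb = Handler([TimeCheck()])
cb.begin_fit(learn=None)          # learn=None: no real model needed here
for epoch in range(2):
    cb.begin_epoch(epoch)         # prints "Epoch 1 started at ...", then "Epoch 2 ..."
```

The handler never does the timing itself; it only forwards each event to every callback and ANDs their return values.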
What is "res" doing in:

res = True
for cb in self.cbs: res = res and cb.begin_fit(learn)

Or in:

res = not self.in_train
for cb in self.cbs: res = res and cb.after_fit()

I am unable to understand its function. "res" is used in all the CallbackHandler methods and I just don't get what it's for. Can someone clarify? It is not used in any individual callback, just the handler.
Let's take begin_fit as an example, but they are all the same. Let's assume there are 3 callbacks in cbs. The first one will return True, but the second one will return False. The third one doesn't really matter.
For the first callback, res = res and cb.begin_fit(learn) will be res = True and True = True.
For the second, res = res and cb.begin_fit(learn) will be res = True and False = False.
For the third one, res = res and cb.begin_fit(learn) will be res = False and ?. Since Python short-circuits boolean expressions, and False and anything is always False, the second operand is not even evaluated, so the third callback doesn't even end up running.
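The three-callback scenario above can be reproduced directly. The callback names (Go, Stop, NeverRuns) are made up for illustration:

```python
calls = []  # record which callbacks actually got invoked

class Go:
    def begin_fit(self, learn):
        calls.append('Go'); return True

class Stop:
    def begin_fit(self, learn):
        calls.append('Stop'); return False

class NeverRuns:
    def begin_fit(self, learn):
        calls.append('NeverRuns'); return True

res = True
for cb in [Go(), Stop(), NeverRuns()]:
    res = res and cb.begin_fit(None)

print(res)    # False
print(calls)  # ['Go', 'Stop'] -- NeverRuns was short-circuited away
```

Once res is False, `res and cb.begin_fit(None)` never calls begin_fit at all, which is exactly the short-circuit behavior described above.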
In the above example we return xb, yb in the on_batch_begin handler. But according to the above logic we only return True/False. So where are the batches being returned?
res is supposed to be a flag here that indicates when to stop: when res = False it means STOP.
res is returned by the methods in CallbackHandler, and it is used in fit/train to determine whether to continue running or stop.
In fit we have if not cb.begin_fit(learn): return, so if res = False then we "return", i.e. STOP.
res = True
for cb in self.cbs: res = res and cb.begin_fit(learn)
return res
so you start with res=True (continue running, do NOT STOP)
then you go through all the callbacks you have and check what begin_fit returns for each callback.
Now since we are and-ing these returned values, if any returned value is False then the whole thing evaluates to False (and thanks to short-circuiting, the callbacks after the first False are not even called).
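A minimal sketch of the fit loop described above, with a fake handler so it runs on its own; the `FakeHandler` class is an illustrative stand-in, and the real fastai fit would also run batches, loss, backward, and step inside the epoch:

```python
class FakeHandler:
    # illustrative stand-in for CallbackHandler
    def __init__(self, go=True):
        self.go, self.epochs_run = go, 0
    def begin_fit(self, learn): return self.go
    def begin_epoch(self, epoch):
        self.epochs_run += 1
        return True
    def after_fit(self): return True

def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return      # res=False means STOP right here
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        # ... batches, loss, backward, step would go here ...
    cb.after_fit()

h = FakeHandler(go=False)
fit(3, learn=None, cb=h)
print(h.epochs_run)   # 0: begin_fit said stop, so no epoch ever ran
```

With go=True the same call runs all three epochs; the only thing the flag controls is whether fit bails out early.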
Short circuit:
taking @juvian 's example for this line - for cb in self.cbs: res = res and cb.begin_fit(learn)
you get on
iteration 1: res = True and True = True
iteration 2: res = True and False = False
iteration 3: res = False and ? = False. (short circuit, not caring about the second input if the first one can determine the answer)
Basically if any callback says STOP you want to STOP the whole thing.
Why is res turned to False now? This depends on how you set up your callback. (If I find a meaningful example I'll let you know, but for now I just made begin_fit return False and it just stops; admittedly a contrived example.)
This is what I have understood, please correct me if I'm wrong.
Forgot to mention: in the current fastai v1 callback handler, something had to return True to mean keep going, so False meant stop. In Python, if you don't add a return, the function actually returns None, which is falsy. So if you forget to return something, that means stop, whereas keep going should be the default. So if you look at fit in the Runner, the new logic says if rather than if not, i.e. True means STOP.
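The point about a missing return is easy to demonstrate: a Python function with no return statement returns None, which is falsy. The function name below is made up for illustration:

```python
def begin_fit_forgot_return(state):
    state['seen'] = True   # does some work, but never returns anything

state = {}
res = begin_fit_forgot_return(state)
print(res)        # None
print(bool(res))  # False

# Under the old "False means stop" convention, forgetting the return would
# silently stop training; flipping the convention to "True means STOP" makes
# a forgotten return default to "keep going".
```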
I think your first question and the second question (where you have pasted the "train" function) are two different versions of how CallbackHandler was implemented. I'm not sure, I'd need to check this too.
because there is no "on_begin_epoch" in your code snippet; this is a simplified version. Looking at the docs, the answer to your question about xb, yb is here:
def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
    "Handle new batch `xb`,`yb` in `train` or validation."
    self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
                                stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
    self('batch_begin', mets = not self.state_dict['train'])
    return self.state_dict['last_input'], self.state_dict['last_target']
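In the v1 design quoted above, callbacks never hand xb, yb back through return values; they mutate the shared state_dict, and the handler reads last_input/last_target back out at the end. A stripped-down sketch of that mechanism, with made-up class names:

```python
class HalveInput:
    # illustrative callback: replaces the batch with a transformed one
    def on_batch_begin(self, state):
        state['last_input'] = state['last_input'] * 0.5  # pretend transform

class Handler:
    # illustrative stand-in for the v1 CallbackHandler
    def __init__(self, cbs):
        self.cbs, self.state_dict = cbs, {}
    def on_batch_begin(self, xb, yb):
        self.state_dict.update(dict(last_input=xb, last_target=yb))
        for cb in self.cbs:
            cb.on_batch_begin(self.state_dict)  # callbacks edit state in place
        # whatever is in state_dict now flows back to the training loop
        return self.state_dict['last_input'], self.state_dict['last_target']

h = Handler([HalveInput()])
xb, yb = h.on_batch_begin(4.0, 1)
print(xb, yb)   # 2.0 1 -- the callback's modification came back out
```

So the True/False "res" convention and the batch-returning convention are two different designs: in the v1 state_dict design, data travels through shared mutable state rather than through each callback's return value.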
Yes, but that just shows how the callback is returning a dict of objects. How will those then be assigned back to xb, yb if our begin_epoch is something like:
def begin_epoch(self, epoch):
    self.learn.model.train()
    self.in_train = True
    res = True
    for cb in self.cbs: res = res and cb.begin_epoch(epoch)
    return res
Some callbacks at begin_epoch may return True/False and some may return objects. An example of this is in the image I posted: there is loss, skip_backward = callbacks.on_loss_begin(loss).
The loss and skip_backward may be returned from two different callbacks. We need a way to store them during the res loop and return them at the end, right?
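Exactly: one way a handler can collect both a flag and objects from different callbacks is to keep them in local variables while looping, then return them together at the end. This is a hypothetical sketch of that idea, not the fastai source; the class names are invented:

```python
class HalveLoss:
    # touches only the loss, leaves the skip flag alone (None = no opinion)
    def on_loss_begin(self, loss): return loss / 2, None

class SkipBwd:
    # touches only the skip flag, passes the loss through unchanged
    def on_loss_begin(self, loss): return loss, True

class Handler:
    def __init__(self, cbs): self.cbs = cbs
    def on_loss_begin(self, loss):
        skip_backward = False
        for cb in self.cbs:
            new_loss, skip = cb.on_loss_begin(loss)
            loss = new_loss                    # later callbacks see the update
            if skip is not None: skip_backward = skip
        return loss, skip_backward             # both values survive the loop

h = Handler([HalveLoss(), SkipBwd()])
loss, skip_backward = h.on_loss_begin(8.0)
print(loss, skip_backward)   # 4.0 True
```

Each value came from a different callback, and the handler's loop is what accumulates them before the single combined return; the v1 state_dict shown earlier is just a generalization of this, with one shared dict instead of named locals.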