Callbacks Super Class

I am getting a bit confused about how the classes work together when dealing with the CallbackHandler.

The code of the handler and a sample callback are given below:

class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        self.learn, self.in_train = learn, True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        # use the learner stored in begin_fit (a bare `learn` is undefined here)
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res

and

import time

class TimeCheck(Callback):
    def begin_fit(self, learn):
        self.learn = learn
        self.epoch_counter = 1
        return True
    
    def begin_epoch(self, epoch):
        self.epoch=epoch
        print(f'Epoch {self.epoch_counter} started at {time.strftime("%H:%M:%S", time.gmtime())}')
        self.epoch_counter += 1
        return True 

What is “res” doing in:
res = True
for cb in self.cbs: res = res and cb.begin_fit(learn)
Or in

res = not self.in_train
for cb in self.cbs: res = res and cb.after_fit()

I am unable to understand its function. ‘res’ appears in every method of the handler and I just don't get what it's for. It is not used in any individual callback, only in the handler. Can someone clarify?


Let's take begin_fit as an example; the other methods all work the same way. Let's assume there are 3 callbacks in cbs. The first one returns True, but the second one returns False. The third one doesn't really matter.

For the first callback, res = res and cb.begin_fit(learn) becomes res = True and True, so res = True.
For the second, it becomes res = True and False, so res = False.
For the third, it becomes res = False and ?. Python short-circuits boolean operators: False and anything is always False, so the right-hand side is not even evaluated and the third callback never runs.
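To make the three-callback walkthrough concrete, here is a minimal sketch. ToyCallback and run_begin_fit are made-up names for illustration only (not part of fastai); the loop body is the same pattern as in the handler from the question. Note that the for loop still visits all three callbacks; it is only the call on the right-hand side of `and` that gets skipped once res is False.

```python
class ToyCallback:
    """Hypothetical stand-in for a real callback; records whether it was called."""
    def __init__(self, name, result):
        self.name, self.result, self.called = name, result, False

    def begin_fit(self, learn):
        self.called = True
        return self.result


def run_begin_fit(cbs, learn=None):
    # Same accumulation pattern as CallbackHandler.begin_fit in the question
    res = True
    for cb in cbs:
        # If res is already False, cb.begin_fit(learn) is never evaluated
        res = res and cb.begin_fit(learn)
    return res


cbs = [ToyCallback("first", True),
       ToyCallback("second", False),
       ToyCallback("third", True)]
result = run_begin_fit(cbs)
called = [cb.name for cb in cbs if cb.called]
print(result, called)  # False ['first', 'second']
```

In a loop over plain booleans (no function call on the right-hand side) there is nothing to skip, so every element is still visited and printed.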

I’m sorry I don’t quite understand this. Why is res turned to False now?

l=[True,True,False,False,True]
res = True
for cb in l: 
  res = res and cb
  print(cb)

This code does not terminate in between and prints all elements of l.

How would we return an object, then? For example:

In the above example we return xb, yb from the on_batch_begin handler. But according to the above logic we only return True/False. So where are the batches being returned?

res is meant to be a flag that indicates when to stop: res=False means STOP.
res is returned by the methods of CallbackHandler and is used in fit/train to decide whether to continue running or stop.
In fit we have if not cb.begin_fit(learn): return, so if res=False we “return”, i.e. STOP.

res = True
for cb in self.cbs: res = res and cb.begin_fit(learn)
return res
  • You start with res=True (continue running, do NOT STOP).
  • Then you go through all the callbacks you have and check what begin_fit returns for each one.
    Since these returned values are combined with and, if any returned value is False the whole expression evaluates to False (this is where short-circuiting comes in).

Short circuit:
Taking @juvian's example for the line for cb in self.cbs: res = res and cb.begin_fit(learn), you get:
iteration 1: res = True and True = True
iteration 2: res = True and False = False
iteration 3: res = False and ? = False (short circuit: the second operand is ignored when the first one already determines the answer)
Basically, if any callback says STOP, you want to STOP the whole thing.
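Here is a rough sketch of how a training loop could consume that flag. This is not the actual fastai fit function; MiniHandler, StopAfter and the body of fit are assumptions made for illustration, and the training step itself is replaced by a list append.

```python
class StopAfter:
    """Hypothetical callback: asks training to stop once `n` epochs have run."""
    def __init__(self, n): self.n = n
    def begin_fit(self, learn): return True
    def begin_epoch(self, epoch): return epoch < self.n  # False means STOP


class MiniHandler:
    """Cut-down handler using the same res pattern as in the question."""
    def __init__(self, cbs=None): self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res

    def begin_epoch(self, epoch):
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res


def fit(epochs, learn, cb):
    completed = []
    if not cb.begin_fit(learn): return completed  # a callback said STOP
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): break       # a callback said STOP
        completed.append(epoch)                   # real training would go here
    return completed


epochs_run = fit(5, learn=None, cb=MiniHandler([StopAfter(2)]))
print(epochs_run)  # [0, 1]
```

With no callbacks at all, res stays True everywhere and all epochs run.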

“Why is res turned to False now?” This depends on how you set up your callback. (If I find a meaningful example I'll let you know; for now I just made begin_fit return False and training simply stops, which is a contrived example.)

This is what I have understood, please correct me if I'm wrong :)


Forgot to mention: in the current fastai v1 callback handler, a callback had to return True to mean “keep going”, so False meant stop. In Python, a function without a return statement actually returns None, which is falsy. So if you forgot to return something, training would stop, even though “keep going” should be the default. That is why, if you look at fit in the Runner, the new logic says if rather than if not, i.e. True means STOP.
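A small sketch of why the missing return matters. The function name and the two “stops” variables are made up for illustration; only the checks if not res (old convention) and if res (new convention) come from the discussion above.

```python
def forgetful_begin_epoch(epoch):
    # No return statement, so Python implicitly returns None
    pass


returned = forgetful_begin_epoch(0)

# Old convention: True means "keep going", so the loop stops `if not res`
old_style_stops = not (True and returned)

# New convention: True means "stop", so the loop stops `if res`
new_style_stops = bool(returned)

print(returned, old_style_stops, new_style_stops)  # None True False
```

Under the old convention a forgotten return silently stops training; under the new one it silently keeps going, which is the safer default.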

I think your first question and your second question (where you pasted the “train” function) refer to two different versions of how CallbackHandler was implemented, because there is no “on_begin_epoch” in your code snippet, which is a simplified version. I'm not sure, so I need to check this too. Looking at the docs, the answer to your question about xb, yb is here:

def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
        "Handle new batch `xb`,`yb` in `train` or validation."
        self.state_dict.update(dict(last_input=xb, last_target=yb, train=train, 
            stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
        self('batch_begin', mets = not self.state_dict['train'])
        return self.state_dict['last_input'], self.state_dict['last_target']

Thanks! I understood the thing with res.

But when we want to return something from the callback as in the cases with:

xb,yb = cb_handler.on_batch_begin(xb,yb)

and
out = cb_handler.on_loss_begin(out)

The above structure of returning res does not let you return an object from any of the callbacks. How does this take place?

Did you see the last part of the answer?

Yes, but that just shows how the callback is returning a dict of objects. How will those then be assigned back to xb, yb if our begin_epoch is something like:

def begin_epoch(self, epoch):
    self.learn.model.train()
    self.in_train=True
    res = True
    for cb in self.cbs: res = res and cb.begin_epoch(epoch)
    return res

Some callbacks at begin_epoch may return True/False and some may return objects. An example of this is in the image I posted, where there is:
loss, skip_backward = callbacks.on_loss_begin(loss)

The loss and skip_backward may be returned from two different callbacks. We need a way to store them during the res loop and return them at the end, right?
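One way this could work, as a minimal sketch: the handler keeps a shared state dict, each callback may return a dict of updates (or None to change nothing), and the handler returns the final values from the state. This is only loosely modelled on the state_dict idea in the on_batch_begin snippet quoted earlier; StateHandler and the three callbacks are hypothetical names and this is not the actual fastai implementation.

```python
class ScaleLoss:
    """Hypothetical callback: replaces the loss with a scaled copy."""
    def on_loss_begin(self, state):
        return {"loss": state["loss"] * 0.5}


class SkipBackward:
    """Hypothetical callback: asks the loop to skip the backward pass."""
    def on_loss_begin(self, state):
        return {"skip_backward": True}


class Logger:
    """Hypothetical callback: changes nothing, so it returns None."""
    def on_loss_begin(self, state):
        return None


class StateHandler:
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []
        self.state = {}

    def on_loss_begin(self, loss):
        # Seed the shared state with defaults for this event
        self.state.update(loss=loss, skip_backward=False)
        for cb in self.cbs:
            updates = cb.on_loss_begin(self.state)
            if updates:  # None or {} leaves the state untouched
                self.state.update(updates)
        # Values set by different callbacks are returned together
        return self.state["loss"], self.state["skip_backward"]


handler = StateHandler([ScaleLoss(), Logger(), SkipBackward()])
loss, skip_backward = handler.on_loss_begin(4.0)
print(loss, skip_backward)  # 2.0 True
```

So the values do not travel through res at all: they live in the handler's state, and each callback only contributes the keys it wants to change.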

Hey? Did you read my above question?