Callback discussion from lesson 9

If I recall correctly, the base class had _order defined, so it should just default to that value.

1 Like

Ah. Got it. Thanks!

Another doubt I had is w.r.t. __getattr__. I am guessing it is required because the callbacks will be accessing and modifying variables in the runner. But would it not affect the getattr code in the __call__ function below? Would it not look for the attribute in the runner instead of in the callback?

import re

class Callback():
  _order = 0
  def set_runner(self, run): self.run = run
  # fallback: delegate unknown attribute lookups to the runner
  def __getattr__(self, k): return getattr(self.run, k)

  @property
  def name(self):
    # AvgStatsCallback -> avg_stats (camel2snake is a helper defined in the notebook)
    name = re.sub(r'Callback$', '', self.__class__.__name__)
    return camel2snake(name or 'callback')

  def __call__(self, cb_name):
    # look up cb_name on this callback first; run it if present
    f = getattr(self, cb_name, None)
    if f and f():
      return True
    return False

__getattr__ is only called by Python when the normal attribute lookup fails; it is not part of the default behaviour of getattr(). So getattr() will still find attributes defined on the callback itself before ever falling through to the runner.
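A minimal self-contained example of that lookup order (stand-in classes with hypothetical names, not the notebook's):

class Runner():
    epoch = 3                     # something only the runner knows about

class Callback():
    _order = 0
    def __getattr__(self, k): return getattr(self.run, k)

cb = Callback()
cb.run = Runner()
print(getattr(cb, '_order'))  # 0 -- found on Callback itself, so __getattr__ never fires
print(getattr(cb, 'epoch'))   # 3 -- not found there, so __getattr__ delegates to the runner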

1 Like

The idea of callbacks is relatively easy to understand. What is hard to understand is all of the complexity of implementing it in the fastai notebook. I am combing through the lesson code and it doesn’t look very easy to follow, especially considering the lack of comments about what the code is doing.

2 Likes

I find it easiest to understand something when I try to implement it. For example, I wrote this Callback to assist in implementing the Stanford ML Group’s MRNet (see this thread) via the fastai library.

TL;DR – MRNet is implemented by stacking a series of images in a single batch, then squeezing the “batch” to conform to the expected shape for input into the neural net.

So, I decided to implement a Callback that does this in on_batch_begin. It took digging through the fastai library – reviewing the code for the Callback-related classes and looking at examples – to understand how a Callback is wired up via CallbackHandler.

In fastai==1.0.50, the CallbackHandler maintains a state_dict that facilitates manipulation of variables within the training loop. I found it helpful to instantiate CallbackHandler in my notebook to see how state_dict evolved after a call to cbh.on_batch_begin() in order to understand how to implement the change I wanted to make.

The end result was this:

import torch
from fastai.callback import Callback  # fastai v1 import path

class MRNetCallback(Callback):
    def on_batch_begin(self, last_input, **kwargs):
        # squeeze out the fake leading batch dim so the stacked slices become the batch
        x = torch.squeeze(last_input, dim=0)
        return dict(last_input=x)
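For anyone who wants to poke at it the same way, here is roughly what that inspection looks like - a sketch from memory against fastai 1.0.50, with made-up tensor shapes, so the exact keys and signatures may differ slightly:

import torch
from fastai.callback import CallbackHandler

xb = torch.randn(1, 8, 3, 256, 256)   # a stacked "batch" with a fake outer dim
yb = torch.tensor([1.])

cbh = CallbackHandler([MRNetCallback()])
xb2, yb2 = cbh.on_batch_begin(xb, yb)      # runs each callback's on_batch_begin
print(cbh.state_dict['last_input'].shape)  # torch.Size([8, 3, 256, 256])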

I haven’t yet tried to implement this via the new Callbacks system covered in recent lessons, but I imagine things will change slightly and perhaps be even easier to understand/implement.

1 Like

Below is my take on understanding callbacks as presented in the lesson (notebook 04_callbacks.ipynb). The idea of passing a function as an argument wasn’t new to me, so the basic explanation of callbacks was easy to grasp. What wasn’t easy was untangling the complexity of the implementation. After some time, I wrote the notes below for myself, on the theory that writing things down would help me understand them better.

Refactoring fit(), all_batches() and one_batch()

First I refactored these functions by replacing cb with cb_handler. The original naming confused me because cb was used in different places to mean both a callback and the callback handler. I kept cb to mean a callback and cb_handler to mean the handler. After this the code became much clearer to me.

Explaining the Callback class

All the methods in the class do some work (side effects) and return a boolean, which signals whether to execute the corresponding step or skip it. By default, every method of this class returns True (True means continue executing the corresponding step of the training loop; False propagates to the handler and stops that step from executing). That means that, by default, all the steps of the training loop are executed.

Callback instances aren’t called by the fit method (loop) directly. Instead, they are all invoked by the corresponding methods of the CallbackHandler class. You are supposed to subclass Callback to make your own custom callbacks and let the callback handler handle them.

The main way a Callback can do things is through the self.learn attribute, which keeps a reference to the learner object. That allows it to access and modify the state of the model, optimizer, datasets and dataloader at different stages of the training loop. Every Callback is registered with learn when the fit() function calls begin_fit(learn) on the callback handler, which passes the reference to the learn object on to every callback.

In addition, a Callback gets access to the epoch number, as well as the batches and the loss (each only after the corresponding step). All of these are stored on self and accessible later. The methods that store state are annotated in the list below; a minimal sketch of the whole base class follows the list.

  • begin_fit(self, learn): stores the learner in self.learn - this gives the callback access to all the underlying training components: model, optimizer, dataloader and datasets
  • after_fit(self)
  • begin_epoch(self, epoch): since it’s the beginning of an epoch, it only gets access to the epoch number and stores it internally
  • begin_validate(self)
  • after_epoch(self)
  • begin_batch(self, xb, yb): stores the batch internally as self.xb and self.yb
  • after_loss(self, loss): stores the loss in self.loss
  • after_backward(self)
  • after_step(self)
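Putting that together, here is a minimal sketch of the base class, paraphrased from notebook 04 (the actual notebook may differ in small details):

class Callback():
    def begin_fit(self, learn):
        self.learn = learn          # access to model, optimizer, data
        return True                 # True = carry on with this step
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_fit(self): return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def after_backward(self): return True
    def after_step(self): return True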

Explaining the CallbackHandler class

The purpose of this class is to store multiple callbacks and call them at the appropriate moments during the training loop. The methods of a CallbackHandler instance are called directly from the fit function. Each method of CallbackHandler is mirrored by the corresponding method of the Callback class (with the exception of the __init__ and do_stop methods).

CallbackHandler is initialized with a list of zero or more callback objects. Note: if the list is empty, all the methods of the handler return True, and therefore the fit function goes through all the steps without any interruption.

To reiterate, the goal of the callback handler is two-fold:

  1. at each step, to go through all the callbacks and make them do their job (accessing self.learn as needed)
  2. to propagate a stopping signal if any of the callbacks demands it

The first part is done by going through all the callbacks at each step and calling the corresponding method on each. Part two is done by chaining the boolean returns with and, starting from an initial value of True (see the sketch after this list). If any of the callbacks returns False then the res variable also becomes False, and that eventually propagates up to the fit function, which will either:

  1. return - for begin_batch(), after_loss(), do_stop() and begin_fit() methods
  2. continue - for begin_epoch() method
  3. break - for do_stop() and after_epoch() methods
  4. do a specific step - for after_backward(), after_step() and begin_validate() methods
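Here is a minimal sketch of that dispatch-and-chain pattern, paraphrased from notebook 04 (the real class has one such method for every step of the loop):

class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []   # empty list -> every method simply returns True

    def begin_fit(self, learn):
        self.learn = learn
        res = True
        # note: `and` short-circuits, so once one callback returns False the rest are skipped
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res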
8 Likes

Note that in a later lesson we refactored this quite a bit - nb 09b IIRC - and the new version I think is quite a bit easier to understand. It uses exceptions instead of conditionals, amongst other things.

I kept that in mind while trying to understand the initial callback system. I did it anyway because I think it’s important for learning purposes to understand how that system works and then move to the refactored version to appreciate and understand the change better.

4 Likes

Mostly, patterns were designed to facilitate EJB development in the late 1990s and early 2000s. Here is a link to some further pattern info:
M Fowler Patterns and Software Design

@RogerS49 What is EJB development?

Not really a topic for these forums. It’s a technology from the Unix/Linux world that allows web applications to talk to data back ends on servers; Windows had its own proprietary equivalent.

Enterprise JavaBeans (EJB) is a Java technology for enterprise applications. While a user may be a single entity, there may be many hundreds of users all trying to access the application, hosted on multiple servers, at the same time. To enable this, the application must allow multiple servers to exist, be resilient (fail over without users noticing) and maintain integrity (not lose or corrupt data).

I’m not sure how this is achieved today; it came into use around the year 2000, instigated by Sun, and is now part of Oracle’s technology portfolio.

To place it in perspective with fastai: it is not something you would use in building models, but EJB or a similar technology is something you would need in production if you built a user-facing application around your model.

In EJB technologies there are many places to plug in alternative behaviours, in a manner similar to callbacks, though perhaps not all achieved by the same mechanisms. I gave the links above because:

  • Martin Fowler is renowned for his analysis of software
  • he expands on the original patterns
  • patterns are many times combined to solve problems
  • they provide a fuller view of how patterns help accomplish software tasks

ERRATA

Should there be two begin_epoch entries in ALL_CBS in nb_09b.py?

ALL_CBS = {'begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step',
    'after_cancel_batch', 'after_batch', 'after_cancel_epoch', 'begin_fit',
    'begin_epoch', 'begin_epoch', 'begin_validate', 'after_epoch',
    'after_cancel_train', 'after_fit'}
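A side note: since ALL_CBS is a set literal, Python dedupes the repeated entry, so the duplicate is harmless, though presumably unintended:

print({'begin_epoch', 'begin_epoch'} == {'begin_epoch'})  # True: sets drop duplicates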

Reading the callback documentation page, I ran into trouble early on:

The first code box is prefaced by a remark to the effect of how “easy” callbacks are.

But “easy” is not a word I’d use to describe this code:

@dataclass
class GradientClipping(LearnerCallback):
    clip:float
    def on_backward_end(self, **kwargs):
        if self.clip:
            nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)

It would be great if there were clear, concise documentation unpacking this code.

The next code box is also a puzzle:

@dataclass
class MyCallback(Callback):
    learn:Learner

Can anyone provide insight into what the above code snippet does?

(1) I’m unfamiliar with this usage of the colon : in Python, and have not found an explanation online. What does learn:Learner do?
Subsequent note: (1) is answered by this documentation of the usage of a colon to annotate a type in Python. So I now understand that learn:Learner declares learn as an attribute annotated with the type Learner.
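A tiny stand-alone illustration of what such an annotation does (with a dummy Learner class):

class Learner: ...

learn: Learner                    # annotation only: no value is assigned
print('learn' in globals())       # False -- the bare annotation binds nothing
print(__annotations__['learn'])   # <class '__main__.Learner'>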

(2) I don’t understand the use of the decorator @dataclass; a search for dataclass in the fastai docs does not turn up a definition

Comment: the Greek alphabet presents a trivial barrier to understanding compared to CCWD (Cryptic Code Without Documentation) :wink:

@jcatanza, regarding (2): the @ symbol introduces a decorator, and @dataclass is one such decorator.

(Disregard the next bit; it’s a little off topic, but it was some Python I didn’t quite know, and I misread the question at first. So if you already know what a decorator is, skip to the ACTUAL ANSWER bit at the bottom. I did not know this beforehand!)

See here, but essentially my take is that a decorator is used for anything that modifies a function or class.

For example, let’s look at TabularList’s source code:

class TabularList(ItemList):
    "Basic `ItemList` for tabular data."
    _item_cls=TabularLine
    _processor=TabularProcessor
    _bunch=TabularDataBunch
    def __init__(self, items:Iterator, cat_names:OptStrList=None, cont_names:OptStrList=None,
                 procs=None, **kwargs)->'TabularList':
        super().__init__(range_of(items), **kwargs)
        #dataframe is in inner_df, items is just a range of index
        if cat_names is None:  cat_names = []
        if cont_names is None: cont_names = []
        self.cat_names,self.cont_names,self.procs = cat_names,cont_names,procs
        self.copy_new += ['cat_names', 'cont_names', 'procs']
        self.preprocessed = False

    @classmethod
    def from_df(cls, df:DataFrame, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':
        "Get the list of inputs in the `col` of `path/csv_name`."
        return cls(items=range(len(df)), cat_names=cat_names, cont_names=cont_names, procs=procs, inner_df=df.copy(), **kwargs)

    def get(self, o):
        if not self.preprocessed: return self.inner_df.iloc[o] if hasattr(self, 'inner_df') else self.items[o]
        codes = [] if self.codes is None else self.codes[o]
        conts = [] if self.conts is None else self.conts[o]
        return self._item_cls(codes, conts, self.classes, self.col_names)

When we create a TabularList, we call from_df, and since it is a class method, it receives the class itself as cls and calls the constructor for us with the items filled in. The same pattern shows up on other classes; take ImageDataBunch, where we have factory methods like create_from_ll, from_folder, from_df, etc.

Whereas single_from_classes uses @staticmethod, which gets no special first parameter at all.

Another useful link on class methods vs static methods: geeksforgeeks.

Difference (verbatim)
Class method vs Static Method

  • A class method takes cls as first parameter while a static method needs no specific parameters.
  • A class method can access or modify class state while a static method can’t access or modify it.
  • In general, static methods know nothing about class state. They are utility type methods that take some parameters and work upon those parameters. On the other hand class methods must have class as parameter.
  • We use @classmethod decorator in python to create a class method and we use @staticmethod decorator to create a static method in python.

When to use what?

  • We generally use class method to create factory methods. Factory methods return class object ( similar to a constructor ) for different use cases.
  • We generally use static methods to create utility functions.
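To make that concrete, here is a small self-contained sketch (a hypothetical Dataset class, not fastai’s):

class Dataset():
    def __init__(self, items): self.items = items

    @classmethod
    def from_folder(cls, path):
        # factory method: because it receives cls, subclasses get instances of themselves
        items = [path + '/a.png', path + '/b.png']   # pretend we scanned the folder
        return cls(items)

    @staticmethod
    def mean(xs):
        # utility method: no access to class or instance state
        return sum(xs) / len(xs)

class ImageDataset(Dataset): pass

print(type(ImageDataset.from_folder('./data')))  # <class '__main__.ImageDataset'>
print(Dataset.mean([1, 2, 3]))                   # 2.0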

This was mostly to help me understand those decorators, as I didn’t know what they were!

ACTUAL ANSWER
Now, on to @dataclass: the Python documentation has it here.

Data classes are geared more towards storing a particular state than implementing logic, with less boilerplate pain. Another explanation is here.
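For example, here is a minimal sketch of what @dataclass buys you (with a stand-in Learner class, not fastai’s):

from dataclasses import dataclass

class Learner: ...                 # stand-in for fastai's Learner

@dataclass
class MyCallback():
    learn: Learner                 # @dataclass turns this annotation into an __init__ argument

# roughly equivalent to writing by hand:
#   class MyCallback():
#       def __init__(self, learn): self.learn = learn

cb = MyCallback(learn=Learner())
print(cb.learn)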

Hope that helps, and if you see anything wrong with that explanation upon further inspection, let me know; I’m learning Python as I go through fastai :slight_smile:

4 Likes

Thanks for the detailed answer @muellerzr!

1 Like

Going back to the LRFinder callback from the beginning of the thread, one thing that I did not get is the use of “smooth_loss”. I did not understand what exactly it is and who passes it. Apart from the callback itself and calls to it, I could not find it anywhere else in the fastai source code.

Can someone help me understand what smooth_loss is? Are there “smooth_grads” as well? Where in the documentation can I find that?

I’m also slightly confused: there seem to be two ways of passing callbacks to a Runner (or, later in notebook 09b, a Learner): with cbs and with cb_funcs. For example, passing either of these two lines seems to give the same result:

cbs = [AvgStatsCallback(accuracy)]
cbfs = [partial(AvgStatsCallback, accuracy)]

The cbs version simply takes ready-made instances, whereas the cb_funcs version takes functions that build those instances for you later.

Is one of the two ever preferable, and why is there not just one way?
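For what it’s worth, here is a minimal self-contained sketch of the difference (AvgStatsCallback and accuracy are stand-ins, not the notebook’s real definitions):

from functools import partial

class AvgStatsCallback():
    def __init__(self, metric): self.metric = metric

def accuracy(preds, targs): ...    # stand-in metric

cbs  = [AvgStatsCallback(accuracy)]            # ready-made instance
cbfs = [partial(AvgStatsCallback, accuracy)]   # factory: called later

# what the Runner does with cb_funcs (paraphrased): build fresh instances itself
cbs_built = [cbf() for cbf in cbfs]
print(cbs_built[0].metric is accuracy)         # True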

4 Likes

That’s a great explanation.
I am still trying to figure out the last bit in detail. For each function, why is there a different action (return, continue, etc.)? Also, about the way they are presented in the notebook: for some functions the condition check is negated, like if not cb_handler.begin_batch(xb, yb), while the ones that do a specific step don’t have that.

This is an excellent explanation of callbacks, highly recommended.