Update Model attribute via Callback

I am trying to implement the image captioning method described in the ‘Show, Attend and Tell’ paper.

For training, I have incorporated teacher forcing in the decoder (attribute name: teacher_forcing_ratio). For validation, however, I want to skip teacher forcing, i.e. use the decoder's output from the previous step to predict the current step.
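
For reference, here is a minimal, framework-free sketch of what a teacher_forcing_ratio usually controls inside a decoder loop. The one-step decoder `toy_step` and the integer tokens are hypothetical stand-ins, not the model from this thread: with probability teacher_forcing_ratio the ground-truth token is fed to the next step, otherwise the decoder's own prediction is. A ratio of 0 therefore gives exactly the behavior wanted for validation.

```python
import random
from typing import Callable, List, Sequence

def decode(step: Callable[[int], int], targets: Sequence[int], start_token: int,
           teacher_forcing_ratio: float, rng: random.Random) -> List[int]:
    """Return the token fed into the decoder at each step.

    step: hypothetical one-step decoder (previous token -> predicted token).
    targets: ground-truth caption tokens, one per decoding step.
    """
    fed = []
    prev = start_token
    for target in targets:
        fed.append(prev)
        pred = step(prev)
        # Teacher forcing: with probability teacher_forcing_ratio the next step
        # gets the ground-truth token, otherwise the model's own prediction.
        prev = target if rng.random() < teacher_forcing_ratio else pred
    return fed

def toy_step(tok: int) -> int:
    # Hypothetical decoder step, only to make the data flow visible.
    return tok + 100

rng = random.Random(0)
print(decode(toy_step, [1, 2, 3, 4], 0, 1.0, rng))  # [0, 1, 2, 3]
print(decode(toy_step, [1, 2, 3, 4], 0, 0.0, rng))  # [0, 100, 200, 300]
```

In a real captioning decoder the tokens would be tensors (usually embedded before the next step), but the switch works the same way.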

I am not very comfortable with the concept of callbacks yet. Would the following class work?

    from fastai.callback import Callback

    class TF_turnoff(Callback):
        def __init__(self, learn):
            self.learn = learn

        def on_valid_begin(self):
            self.learn.model.decoder.teacher_forcing_ratio = 0

Also, once validation is done, does the ratio switch back to the default (user-provided) value, or do I have to reset it explicitly?


Are you using fastai1 or fastai2?

@riven314 fastai v1 (1.0.6)

I'll try to answer to the best of my knowledge of fastai v1.

Below is the relevant snippet from fit() (which is what runs when you train):

        for epoch in pbar:
            for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
                xb, yb = cb_handler.on_batch_begin(xb, yb)
                loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
                if cb_handler.on_batch_end(loss): break

            if not cb_handler.skip_validate and not learn.data.empty_val:
                val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
                                       cb_handler=cb_handler, pbar=pbar)
            else: val_loss=None
            if cb_handler.on_epoch_end(val_loss): break

Two things to notice:

  1. Callback events are dispatched by CallbackHandler (i.e. cb_handler). As you can see, cb_handler has no on_valid_begin event, so an on_valid_begin method on your callback would simply never be called.
  2. Validation runs after the training loop (i.e. after the last training batch's on_batch_end) and is handled by a predefined function called validate, which you can't customize.

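So what you can do is set decoder.teacher_forcing_ratio at the start of every batch, depending on whether it is a training or a validation batch. Note that validate still calls cb_handler.on_batch_begin (with train=False) and on_batch_end for each validation batch, so switching the ratio off in on_batch_end and back on in on_batch_begin would re-enable teacher forcing right before every validation batch. Checking the train flag avoids that. Something like this:

    from fastai.basic_train import LearnerCallback

    class TF_turnoff(LearnerCallback):
        def __init__(self, learn, tf_ratio=0.5):
            super().__init__(learn)   # LearnerCallback stores learn as self.learn
            self.tf_ratio = tf_ratio  # input your desired training ratio here

        def on_batch_begin(self, train, **kwargs):
            # train is True in fit()'s training loop and False inside validate()
            self.learn.model.decoder.teacher_forcing_ratio = self.tf_ratio if train else 0.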
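
A quick way to check which ratio each batch actually sees is a framework-free simulation of the per-batch event stream. It assumes, as in fastai v1's validate(), that every validation batch also fires on_batch_begin and on_batch_end with train=False, and it omits the loss/backward and epoch-level events. naive_rule turns the ratio on in on_batch_begin and off in on_batch_end, while train_flag_rule sets it from the train flag in on_batch_begin. This is not fastai code, and 0.5 is only an example ratio.

```python
from typing import Callable, List, Tuple

Event = Tuple[str, bool]  # (event name, train flag)

def batch_events(n_train: int, n_valid: int) -> List[Event]:
    """Simplified order of per-batch events in one fastai-v1-style epoch."""
    events: List[Event] = []
    for _ in range(n_train):  # training loop inside fit()
        events += [("on_batch_begin", True), ("on_batch_end", True)]
    for _ in range(n_valid):  # validate() fires the same events with train=False
        events += [("on_batch_begin", False), ("on_batch_end", False)]
    return events

def naive_rule(name: str, train: bool, ratio: float) -> float:
    # Ratio on in on_batch_begin, off in on_batch_end; ignores the train flag.
    return {"on_batch_begin": 0.5, "on_batch_end": 0.0}.get(name, ratio)

def train_flag_rule(name: str, train: bool, ratio: float) -> float:
    # on_batch_begin sets the ratio from the train flag; other events leave it alone.
    if name == "on_batch_begin":
        return 0.5 if train else 0.0
    return ratio

def ratios_seen(events: List[Event],
                rule: Callable[[str, bool, float], float]) -> List[Tuple[bool, float]]:
    """Ratio in force while each batch runs, i.e. right after its on_batch_begin."""
    ratio, seen = 0.5, []
    for name, train in events:
        ratio = rule(name, train, ratio)
        if name == "on_batch_begin":
            seen.append((train, ratio))
    return seen

events = batch_events(n_train=2, n_valid=2)
print(ratios_seen(events, naive_rule))       # validation batches still see 0.5
print(ratios_seen(events, train_flag_rule))  # validation batches see 0.0
```

Because the train-flag rule recomputes the ratio at the start of every batch, nothing has to be reset after validation: the next training batch sets it back to the training value.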

Some remarks:

  1. You must include **kwargs (after self) in the signatures of on_batch_begin, on_batch_end and any other callback method, because cb_handler passes every entry of its state_dict (a dict) to those methods as keyword arguments; **kwargs collects the ones you don't name explicitly.

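  2. In fastai v1, a plain Callback does not get learn attached automatically, so self.learn only works if you store it yourself, as your __init__ does. If you subclass LearnerCallback instead, its __init__ stores learn for you, and you only need your own __init__ to accept extra arguments (calling super().__init__(learn) first). Passing the class through callback_fns when you create the Learner lets fastai instantiate it with learn.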
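
The reason for **kwargs can be reproduced without fastai. In the sketch below, dispatch is a hypothetical, simplified imitation of how fastai v1's CallbackHandler calls a callback method (every state_dict entry passed as a keyword argument); the keys in state are illustrative examples only.

```python
from typing import Any, Dict

class Strict:
    def on_batch_begin(self, train):  # no **kwargs: extra keywords are rejected
        return train

class Tolerant:
    def on_batch_begin(self, train, **kwargs):  # unused state_dict entries land in kwargs
        return train

def dispatch(cb: Any, state_dict: Dict[str, Any]) -> Any:
    # Hypothetical stand-in for the handler: unpack the whole state_dict as keywords.
    return cb.on_batch_begin(**state_dict)

state = {"train": False, "last_input": [1, 2], "last_target": [3], "epoch": 0}

print(dispatch(Tolerant(), state))  # False
try:
    dispatch(Strict(), state)
except TypeError as err:
    print("TypeError:", err)  # reports an unexpected keyword argument
```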


@riven314 Thanks for the quick response.