I am trying to implement the image captioning method described in the 'Show, Attend and Tell' paper.
For training, I have incorporated teacher forcing in the decoder (attribute name teacher_forcing_ratio). For validation, however, I want to skip teacher forcing, i.e. use the decoder output from the previous step to predict the current step.
I am not quite comfortable with the callbacks concept. I would like to know whether the following class will work.
try:
    for epoch in pbar:
        learn.model.train()
        cb_handler.set_dl(learn.data.train_dl)
        cb_handler.on_epoch_begin()
        # training batches: on_batch_begin / on_batch_end are called here
        for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
            xb, yb = cb_handler.on_batch_begin(xb, yb)
            loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
            if cb_handler.on_batch_end(loss): break
        # validation runs once per epoch, after the batch loop has finished
        if not cb_handler.skip_validate and not learn.data.empty_val:
            val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
                                cb_handler=cb_handler, pbar=pbar)
        else: val_loss=None
        if cb_handler.on_epoch_end(val_loss): break
Two things to notice:
The callback methods are dispatched by CallbackHandler (i.e. cb_handler). As you can see, cb_handler doesn't have on_valid_begin, so defining on_valid_begin in your callback would have no effect.
The validation step runs after the batch loop, i.e. after the last on_batch_end of the epoch, and it is handled by a predefined function called validate. You can't customize it.
So what you can do is set decoder.teacher_forcing_ratio to 0 before validate is called, for example in on_batch_end, and then turn it back on in on_batch_begin. Something like this:
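class TF_turnoff(Callback):
    def on_batch_begin(self, **kwargs):
        # re-enable teacher forcing for the upcoming training batch
        self.learn.model.decoder.teacher_forcing_ratio = 0.5 # input your desired ratio here
    def on_batch_end(self, **kwargs):
        # switch it off so it is already off when validate is called
        self.learn.model.decoder.teacher_forcing_ratio = 0.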
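To see the on/off pattern without running fastai, here is a minimal plain-Python sketch. ToyDecoder, ToggleTeacherForcing and run_epoch are hypothetical stand-ins, not fastai classes, and the toy loop assumes that validation batches do not trigger on_batch_begin. If validate in your installed version does call on_batch_begin, the ratio would be switched back on during validation, so that is worth checking in its source.

```python
class ToyDecoder:
    """Hypothetical stand-in for the real decoder; it only holds the ratio."""
    def __init__(self):
        self.teacher_forcing_ratio = 0.5


class ToggleTeacherForcing:
    """Same on/off logic as the callback above, written without fastai."""
    def __init__(self, decoder, train_ratio=0.5):
        self.decoder = decoder
        self.train_ratio = train_ratio

    def on_batch_begin(self, **kwargs):
        # turn teacher forcing on for the training batch
        self.decoder.teacher_forcing_ratio = self.train_ratio

    def on_batch_end(self, **kwargs):
        # turn it off again so it is off once the batch loop ends
        self.decoder.teacher_forcing_ratio = 0.0


def run_epoch(decoder, callback, n_train_batches=3, n_valid_batches=2):
    """Toy epoch. Assumption: only training batches fire the batch hooks."""
    seen = {"train": [], "valid": []}
    for _ in range(n_train_batches):
        callback.on_batch_begin()
        seen["train"].append(decoder.teacher_forcing_ratio)  # ratio used for this batch
        callback.on_batch_end()
    for _ in range(n_valid_batches):
        seen["valid"].append(decoder.teacher_forcing_ratio)  # ratio left by on_batch_end
    return seen


decoder = ToyDecoder()
seen = run_epoch(decoder, ToggleTeacherForcing(decoder))
print(seen)  # {'train': [0.5, 0.5, 0.5], 'valid': [0.0, 0.0]}
```

Under that assumption, every training batch sees the training ratio and every validation batch sees 0.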
Some remarks:
You must put **kwargs in the signatures of the on_batch_begin and on_batch_end methods. This is because cb_handler passes the entries of its state_dict (which is a dict) to those methods as keyword arguments, and **kwargs collects every keyword argument you don't name explicitly. Without it, the call fails with a TypeError.
The __init__ method you define is redundant: the callback automatically gets learn attached as an attribute.
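The **kwargs remark can be illustrated with a small plain-Python sketch. MiniHandler and the keys in its state_dict are hypothetical simplifications chosen for illustration, not the real CallbackHandler or its actual state.

```python
class MiniHandler:
    """Hypothetical, much-simplified stand-in for a callback handler."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        # illustrative state only; a real handler tracks more than this
        self.state_dict = {"epoch": 0, "iteration": 0, "train": True}

    def __call__(self, event_name):
        # the handler unpacks its state into keyword arguments for each callback
        return [getattr(cb, event_name)(**self.state_dict) for cb in self.callbacks]


class WithKwargs:
    def on_batch_begin(self, iteration, **kwargs):
        # `iteration` is picked out by name, everything else lands in kwargs
        return iteration, sorted(kwargs)


class WithoutKwargs:
    def on_batch_begin(self, iteration):
        return iteration


result = MiniHandler([WithKwargs()])("on_batch_begin")
print(result)  # [(0, ['epoch', 'train'])]

try:
    MiniHandler([WithoutKwargs()])("on_batch_begin")
    error = None
except TypeError as exc:
    # 'epoch' and 'train' have nowhere to go without **kwargs
    error = exc
print(type(error).__name__)  # TypeError
```

The method that omits **kwargs fails as soon as the handler passes a key it does not name, which is why the catch-all parameter is needed even when the method ignores its arguments.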