I'll try to answer to the best of my knowledge for fastai v1.
Below is the relevant code snippet from `fit()` (this is what ends up being called when you train):
```python
for epoch in pbar:
    for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
        xb, yb = cb_handler.on_batch_begin(xb, yb)
        loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
        if cb_handler.on_batch_end(loss): break

    if not cb_handler.skip_validate and not learn.data.empty_val:
        val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
                            cb_handler=cb_handler, pbar=pbar)
    else: val_loss=None
    if cb_handler.on_epoch_end(val_loss): break
```
Two things to notice:
- The callback functions are handled by `CallbackHandler` (the `cb_handler` above). As you can see, `cb_handler` has no `on_valid_begin`, so you can't define `on_valid_begin` in your callback (see the hook sketch right after this list).
- The validation step happens after the last `on_batch_end` of the epoch, and it is handled by a predefined function called `validate`, which you can't customize.
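For orientation (not from the snippet above, just the fastai v1 hook names as I remember them), a callback can override roughly these methods, and there is no `on_valid_begin` among them:

```python
from fastai.callback import Callback

class HookOverview(Callback):
    # Hooks dispatched by CallbackHandler in fastai v1; note there is no on_valid_begin.
    def on_train_begin(self, **kwargs): pass
    def on_epoch_begin(self, **kwargs): pass
    def on_batch_begin(self, **kwargs): pass      # also called for validation batches
    def on_loss_begin(self, **kwargs): pass
    def on_backward_begin(self, **kwargs): pass
    def on_backward_end(self, **kwargs): pass
    def on_step_end(self, **kwargs): pass
    def on_batch_end(self, **kwargs): pass
    def on_epoch_end(self, **kwargs): pass
    def on_train_end(self, **kwargs): pass
```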
So what you can do is turn teacher forcing off right before `validate` runs, i.e. in `on_batch_end`, and then turn it back on in `on_batch_begin`. Something like this:
```python
from fastai.basic_train import LearnerCallback

class TeacherForcingCallback(LearnerCallback):   # class name is just illustrative
    def on_batch_begin(self, **kwargs):
        self.learn.model.decoder.teacher_forcing_ratio = 0.5 # input your desired ratio here
    def on_batch_end(self, **kwargs):
        self.learn.model.decoder.teacher_forcing_ratio = 0.
```
Note that you must put `**kwargs` in the arguments of the `on_batch_begin` and `on_batch_end` methods. This is because `cb_handler` propagates its `state_dict` (which is a dict) to those methods as keyword arguments, and `**kwargs` collects the ones your signature doesn't name explicitly.
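One caveat (you can see it in the snippet above, where `validate` receives the same `cb_handler`): `on_batch_begin` and `on_batch_end` also fire for validation batches, with `train=False` in the `state_dict`. If you want to be certain teacher forcing stays off while validating, a variant is to read that flag from the keyword arguments; the class name below is again just illustrative:

```python
from fastai.basic_train import LearnerCallback

class TrainAwareTeacherForcing(LearnerCallback):
    def on_batch_begin(self, train, **kwargs):
        # `train` comes from cb_handler's state_dict: True for training batches,
        # False for the batches run by `validate`.
        self.learn.model.decoder.teacher_forcing_ratio = 0.5 if train else 0.
```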
Also, the `__init__` method you defined is redundant: `LearnerCallback` automatically attaches `learn` as an attribute, so `self.learn` is already available.
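To actually use the callback, a sketch (assuming your `Learner` is named `learn` and using the illustrative class from above):

```python
# Pass it for a single training run...
learn.fit_one_cycle(5, callbacks=[TeacherForcingCallback(learn)])

# ...or register it once so every subsequent fit picks it up.
learn.callbacks.append(TeacherForcingCallback(learn))
```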