Teacher Forcing

Hello,
I am attempting to re-implement the translation notebook from lesson 10 in fastai 1.0. I am currently trying to add in teacher forcing, and have some questions.

First of all, implementing the forcing step was much easier here, since all I had to do was write a small callback. (I include it because the post "Lesson 11 teacher forcing in fastai v1" asked how to do this.)

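from dataclasses import dataclass
from fastai.basic_train import Learner, LearnerCallback

@dataclass
class ForceAnneal(LearnerCallback):
    learn:Learner
    def on_epoch_begin(self, epoch, **kwargs):
        # Anneal the forcing probability linearly from 1.0 at epoch 0 to 0 at epoch 10
        self.learn.model.pr_force = (10 - epoch) * 0.1 if epoch < 10 else 0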
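
For reference, the schedule used in the callback can be checked on its own. This is a minimal sketch in plain Python with no fastai dependency; the function name pr_force_for_epoch is made up for illustration and is not part of the notebook or the library:

```python
def pr_force_for_epoch(epoch):
    """Same expression as in ForceAnneal.on_epoch_begin, pulled out for inspection."""
    return (10 - epoch) * 0.1 if epoch < 10 else 0

# Print the forcing probability for the first 12 epochs.
for epoch in range(12):
    print(epoch, round(pr_force_for_epoch(epoch), 2))
```

The probability starts at 1.0, drops by 0.1 per epoch, and stays at 0 from epoch 10 onward, so the model is always forced in the first epoch and never forced after the tenth.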

However, I am not sure how to get Learner.fit to pass the y values to the model so that the forcing can actually happen.

From reading the source code, it seems that Learner.fit calls fit, which calls loss_batch, which in turn calls the model:

    out = model(*xb)

Since the y values aren’t passed in here, it seems like the only option I have is to write my own fit function from scratch, and give up using a Learner. Is that indeed the case?

Edit: Could the y values be injected in an on_batch_begin callback, like this?

def on_batch_begin(self,last_target,**kwargs):
    self.learn.model.target = last_target

If you want to change the target, you just have to return custom_input, custom_target at the end of on_batch_begin.
Glad you like the callbacks and find them easy to use :wink:

I don’t want to change the target, just have it fed into (or available to) the forward pass of the model.

Edit: I tried the on_batch_begin callback from above and it seems to work.
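
For anyone wondering what the model then does with the stored target, here is a rough sketch of the per-step decision in plain Python. The helper name next_decoder_input and its arguments are hypothetical; in a real model this logic would presumably sit inside the decoder loop of forward and read self.target and self.pr_force instead:

```python
import random

def next_decoder_input(predicted, target, step, pr_force, rng=random):
    """Choose the token fed to the decoder at the next step.

    predicted: token the model just produced at this step
    target:    ground-truth sequence, or None when no target is available
    step:      index of the current decoding step
    pr_force:  probability of feeding the ground-truth token instead
    """
    has_truth = target is not None and step < len(target)
    if has_truth and rng.random() < pr_force:
        return target[step]   # teacher forcing: feed the true token
    return predicted          # otherwise feed back the model's own guess

target = [5, 8, 2]
# pr_force=1.0 always forces, pr_force=0.0 never does.
print(next_decoder_input(predicted=7, target=target, step=1, pr_force=1.0))
print(next_decoder_input(predicted=7, target=target, step=1, pr_force=0.0))
```

Falling back to the prediction when the target is missing or too short means the same code path can also be used when no y values are available.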

This was not 100% clear to me till I wrote the code out. Hopefully helpful to others:

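from fastai.callback import Callback

class AppendBatchTargs(Callback):
    "Pass the batch target to the model together with the input."
    def on_batch_begin(self, last_input, last_target, **kwargs):
        # The input becomes an (input, target) tuple; the target used for the loss is unchanged
        return {'last_input':(last_input, last_target), 'last_target':last_target}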
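
As far as I can tell, this works because the training loop quoted earlier calls model(*xb), so a tuple input is unpacked into separate arguments to forward. A toy illustration in plain Python; call_model and ToySeq2Seq are stand-ins invented for this sketch, not fastai code:

```python
def call_model(model, xb):
    """Mimic the unpacking step: a tuple or list input is spread over forward's arguments."""
    if not isinstance(xb, (list, tuple)):
        xb = [xb]
    return model(*xb)

class ToySeq2Seq:
    """Stand-in model whose forward accepts an optional target."""
    def __call__(self, inp, targ=None):
        return {"inp": inp, "got_target": targ is not None}

model = ToySeq2Seq()
print(call_model(model, "source batch"))                      # input only
print(call_model(model, ("source batch", "target batch")))    # input and target
```

Giving targ a default of None keeps the same forward usable when only the input is passed.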