"Cheating" in transformer notebook?

Hey community,

I’m currently working through the transformer notebook and I wonder why the shifted target sequence is passed to the model during inference as well.

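from fastai.text import *  # fastai v1: provides torch, DatasetType, progress_bar

def get_predictions(learn, ds_type=DatasetType.Valid):
    learn.model.eval()  # disable dropout for inference
    inputs, targets, outputs = [],[],[]
    with torch.no_grad():
        for xb,yb in progress_bar(learn.dl(ds_type)):
            # xb = (source, decoder input); the decoder input is the target shifted right by one
            assert (xb[1][:, 1:] == yb[:, :-1]).all()
            out = learn.model(*xb)  # the shifted ground-truth target is passed in here
            for x,y,z in zip(xb[0],xb[1],out):
                # loop body not shown in this post; minimally, keep source, target and predicted token ids
                inputs.append(x)
                targets.append(y)
                outputs.append(z.argmax(-1))
    return inputs, targets, outputs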

The assert statement does not fail.

Doesn’t this mean that we are always using teacher forcing, even during inference?

Glad if anyone could help me understand this! :slight_smile:

Thanks and best regards,



In which line do you see teacher forcing?

Only xb is fed into the model.

The assert checks for equality (==) and does not assign (=).

AFAIK teacher forcing is something you use during training and not for inference, but maybe I’m wrong?

Sorry for not being more explicit :slight_smile:
xb is a tuple where the first entry is x, the input sequence, and the second entry is y, the target sequence, shifted by one. I added the assert statement to show this.
This means that, in this notebook, the target sequence is always passed to the model, even during inference.
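To make the shift concrete, here is a minimal sketch of how a teacher-forcing decoder input can be built from a batch of targets. It assumes index 1 (xxpad) doubles as the start token, as in the forward code posted further down this thread; `shift_right` is a hypothetical helper for illustration, not the notebook's own function.

```python
import torch

def shift_right(y, start_idx=1):
    """Teacher-forcing decoder input: prepend a start token and drop the last target token."""
    # start_idx=1 (xxpad) is an assumption based on the code later in this thread
    start = torch.full((y.shape[0], 1), start_idx, dtype=y.dtype, device=y.device)
    return torch.cat([start, y[:, :-1]], dim=1)

y = torch.tensor([[5, 6, 7, 8],
                  [9, 10, 11, 1]])   # toy target ids (1 = padding)
dec_in = shift_right(y)
print(dec_in.tolist())  # [[1, 5, 6, 7], [1, 9, 10, 11]]

# The same relation the assert in get_predictions checks:
assert (dec_in[:, 1:] == y[:, :-1]).all()
```

Feeding `dec_in` to the decoder means every position sees the ground-truth prefix, which is exactly what happens when `learn.model(*xb)` is called at evaluation time.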

Hi Fabio,

I share your opinion. It seems that teacher forcing is always used, even at evaluation (in the notebook 8-translation-transformer).

This has several adverse effects:

  • This leaks the ground-truth context. It can inflate the BLEU metric and give a false picture of the model’s ability to produce a full translation (instead, the task becomes predicting the next word of an already known partial translation).

  • It can also produce low-quality output sentences (incoherence, repeated words), because each predicted token is conditioned on the ground-truth prefix rather than on what the model actually generated, so the per-position predictions need not form a consistent sentence.

I haven’t watched Rachel Thomas’s course videos yet; maybe using teacher forcing at evaluation was intended, or it simply wasn’t the focus of this lesson.

In my tests on this notebook, without teacher forcing, I observed lower BLEU scores but more coherent output sentences overall.

I am not sure how auto-regressive prediction is usually implemented with Transformers.
Here is my attempt, which only modifies the forward method of the Transformer class:

class Transformer(Module):
    def __init__(self, inp_vsz, out_vsz, n_layers=6, n_heads=8, d_model=256, d_head=32, 
                 d_inner=1024, p=0.1, bias=True, scale=True, double_drop=True, pad_idx=1):
        self.enc_emb = TransformerEmbedding(inp_vsz, d_model, p)
        self.dec_emb = TransformerEmbedding(out_vsz, d_model, 0.)
        args = (n_heads, d_model, d_head, d_inner, p, bias, scale, double_drop)
        self.encoder = nn.ModuleList([EncoderBlock(*args) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderBlock(*args) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, out_vsz)
        self.out.weight = self.dec_emb.embed.weight
        self.pad_idx = pad_idx
    def forward(self, inp, out):
        enc = self.enc_emb(inp)
        enc = compose(self.encoder)(enc)
        if self.training: # batch teacher forcing for training
            out = self.dec_emb(out)
            mask_out = get_output_mask(out, self.pad_idx)
            out = compose(self.decoder)(out, enc, mask_out)
            return self.out(out)
        else: # Auto-regressive (greedy) prediction for eval:
            out_seq_length = out.shape[1] # Only use the output length (TODO: make this a forward param)
            start_idx = self.pad_idx # xxpad is used as the start token (TODO: take it from the DataBunch vocab)
            # Start with a single xxpad token per sequence
            next_dec_input = torch.full((inp.shape[0], 1), fill_value=start_idx, dtype=inp.dtype, device=inp.device)
            # Predict one token at a time:
            for _ in range(out_seq_length):
                out = self.dec_emb(next_dec_input)
                mask_out = get_output_mask(out, self.pad_idx)
                out = compose(self.decoder)(out, enc, mask_out)
                out = self.out(out)
                # Prepare the next input (xxpad + all predictions so far):
                next_dec_input = F.pad(out.argmax(2), (1, 0), value=start_idx)
            return out

I haven’t tested this implementation thoroughly, but it runs (more slowly than with teacher forcing, because each batch needs one decoder pass per output token).

I think it could be improved. For example, could it be faster not to recompute the hidden representations of previously predicted tokens? (That is, applying dynamic programming to auto-regressive Transformer decoding.) Maybe I should look at how this is done in the PyTorch and TensorFlow implementations.


The Google implementation is interesting: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L289-L351

They use a beam search and a cache to reuse decoder attention K and V tensors for each layer.
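A toy sketch of the K/V caching idea (my own single-head, single-layer illustration with random weights, not the TensorFlow code): each new token computes only its own query, key and value, appends K and V to a cache, and attends over the cache. The final check confirms this matches recomputing causal attention over the whole prefix, which is why the earlier positions never need recomputing. `full_causal_attention` and `cached_step` are hypothetical names.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
Wq, Wk, Wv = (torch.randn(d, d) / d**0.5 for _ in range(3))  # toy projection weights

def full_causal_attention(x):
    """Recompute causal self-attention over the whole prefix x of shape (seq_len, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / d**0.5
    causal = torch.triu(torch.ones(len(x), len(x), dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float('-inf'))  # hide future positions
    return F.softmax(scores, dim=-1) @ v

def cached_step(x_t, cache):
    """Attend from the newest token only, reusing the cached K and V of earlier tokens."""
    q = x_t @ Wq                                      # (1, d)
    cache['k'] = torch.cat([cache['k'], x_t @ Wk])    # append this token's key
    cache['v'] = torch.cat([cache['v'], x_t @ Wv])    # append this token's value
    scores = q @ cache['k'].T / d**0.5                # (1, t): only past keys exist, so no mask
    return F.softmax(scores, dim=-1) @ cache['v']

x = torch.randn(5, d)
cache = {'k': torch.empty(0, d), 'v': torch.empty(0, d)}
incremental = torch.cat([cached_step(x[t:t+1], cache) for t in range(len(x))])
print(torch.allclose(incremental, full_causal_attention(x), atol=1e-5))  # True
```

In a real decoder this is done per layer for self-attention, while the cross-attention K/V over the encoder output can be computed once per source sentence.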

I don’t think that was intended. If I remember correctly, teacher forcing was only discussed in the context of the RNN-based seq2seq models. That implementation actually started with teacher forcing and gradually phased it out towards the end of training.
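Phasing out teacher forcing can be sketched like this. It is only an illustration under an assumed linear schedule (I haven’t checked it against the RNN notebook’s exact code), and `teacher_forcing_prob` / `choose_next_input` are hypothetical helpers: at each decoding step during training, the ground-truth token is fed with probability `pr_force`, otherwise the model’s own prediction.

```python
import random

def teacher_forcing_prob(epoch, end_epoch):
    """Assumed schedule: decay the teacher-forcing probability linearly from 1 to 0 by end_epoch."""
    return max(0.0, 1.0 - epoch / end_epoch)

def choose_next_input(target_tok, predicted_tok, pr_force, rng=random):
    """With probability pr_force feed the ground-truth token, otherwise the model's own prediction."""
    return target_tok if rng.random() < pr_force else predicted_tok

print([teacher_forcing_prob(e, 4) for e in range(6)])  # [1.0, 0.75, 0.5, 0.25, 0.0, 0.0]
```

Once `pr_force` reaches 0, training feeds back the model’s own predictions, just as at inference, which is the condition the transformer notebook never reaches in evaluation.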

That makes sense. When I applied the notebook to another dataset, I found it quite strange that, despite a relatively high BLEU score, the translations were pretty bad and repetitive.

I guess it will always be slower than with teacher forcing because you need to go sequentially in order to feed the previous output as current input.

Wouldn’t beam search make it even slower, since the model has to generate beam_width times more predictions at each step? On the other hand, the predictions at each step can be processed in parallel, so maybe it doesn’t matter too much.

Yes, beam search should never be faster than greedy search.
However, for both decoding strategies, caching the keys and values of previously predicted tokens should help.
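To make the cost comparison concrete, here is a toy beam search over a hypothetical `step_logprobs(prefix)` function that stands in for one decoder pass. With `beam_width=1` it reduces to greedy decoding; each step calls the model once per live beam, which is where the extra cost comes from (real implementations batch those calls). The probability table is made up so that greedy and beam search disagree; there is no length normalization.

```python
import math

def beam_search(step_logprobs, start_tok, eos_tok, max_len, beam_width):
    """Generic beam search; step_logprobs(prefix) returns {token: log_prob} for the next token."""
    beams = [([start_tok], 0.0)]            # (token sequence, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_tok:          # finished beams are carried over unchanged
                candidates.append((seq, score))
                continue
            for tok, lp in step_logprobs(seq).items():  # one "decoder pass" per live beam
                candidates.append((seq + [tok], score + lp))
        # keep the beam_width best partial sequences
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(seq[-1] == eos_tok for seq, _ in beams):
            break
    return beams[0]

# Made-up next-token probabilities that depend only on the last token (9 = EOS)
TABLE = {0: {1: 0.6, 2: 0.4}, 1: {9: 0.55, 5: 0.45}, 2: {9: 0.9, 5: 0.1}, 5: {9: 1.0}}

def toy_step(prefix):
    return {tok: math.log(p) for tok, p in TABLE[prefix[-1]].items()}

print(beam_search(toy_step, 0, 9, max_len=5, beam_width=1)[0])  # [0, 1, 9]  (greedy, p = 0.33)
print(beam_search(toy_step, 0, 9, max_len=5, beam_width=2)[0])  # [0, 2, 9]  (p = 0.36)
```

With a K/V cache, each beam would also carry its own cache, which has to be reordered whenever beams are pruned.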

I saw that the generate() method in Hugging Face Transformers has some support for this (the use_cache parameter).


That’s good to know. Speed was in fact a bit of a problem when I wanted to process large amounts of data with the Hugging Face open-source NMT models. Next time I’ll try caching.

It’s early days, but I have been using PyTorch’s nn.Transformer module with fastai for a translation project I am working on, in case it’s useful for anyone here. I had played around with the fastai NLP course code but couldn’t quite get it to work as expected (although it was still super useful to step through how the encoder/decoder/attention modules were defined).

I’m not using any clever generation methods for now, but you might find this notebook useful as a baseline / sanity check.
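For anyone going the nn.Transformer route, here is a minimal greedy-decoding sketch (not the linked notebook’s code): the encoder runs once, then the decoder is re-run on the growing prefix and only the last position’s prediction is appended. `TinySeq2Seq` and `greedy_decode` are hypothetical names; positional encodings, padding masks and EOS handling are omitted for brevity, there is no K/V cache, and `batch_first=True` assumes PyTorch 1.9 or newer.

```python
import torch
import torch.nn as nn

class TinySeq2Seq(nn.Module):
    """Hypothetical minimal model: embeddings + nn.Transformer + output projection."""
    def __init__(self, src_vsz, tgt_vsz, d_model=32, nhead=4, n_layers=2):
        super().__init__()
        self.src_emb = nn.Embedding(src_vsz, d_model)
        self.tgt_emb = nn.Embedding(tgt_vsz, d_model)
        self.transformer = nn.Transformer(d_model, nhead, n_layers, n_layers,
                                          dim_feedforward=64, dropout=0.1, batch_first=True)
        self.out = nn.Linear(d_model, tgt_vsz)

    def forward(self, src, tgt_in):
        # Teacher forcing: the whole shifted target is fed at once under a causal mask
        mask = self.transformer.generate_square_subsequent_mask(tgt_in.size(1)).to(src.device)
        h = self.transformer(self.src_emb(src), self.tgt_emb(tgt_in), tgt_mask=mask)
        return self.out(h)

@torch.no_grad()
def greedy_decode(model, src, start_idx, max_len):
    """Free-running decoding: encode once, then feed back the model's own predictions."""
    model.eval()
    memory = model.transformer.encoder(model.src_emb(src))  # encoder runs only once
    ys = torch.full((src.size(0), 1), start_idx, dtype=torch.long, device=src.device)
    for _ in range(max_len):
        mask = model.transformer.generate_square_subsequent_mask(ys.size(1)).to(src.device)
        h = model.transformer.decoder(model.tgt_emb(ys), memory, tgt_mask=mask)
        next_tok = model.out(h[:, -1]).argmax(-1, keepdim=True)  # only the last position is new
        ys = torch.cat([ys, next_tok], dim=1)
    return ys  # includes the start token in column 0

torch.manual_seed(0)
model = TinySeq2Seq(src_vsz=20, tgt_vsz=25)
src = torch.randint(0, 20, (3, 7))
ys = greedy_decode(model, src, start_idx=1, max_len=6)
print(ys.shape)  # torch.Size([3, 7])
```

Because of the causal mask, a single teacher-forced pass on the model’s own output, `model(src, ys[:, :-1])`, predicts the same tokens as `ys[:, 1:]`; with the ground-truth target in its place you get the leaky evaluation discussed at the start of the thread.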
