Hi Fabio,
I share your opinion. It seems that teacher forcing is always used, even at evaluation (in the notebook 8-translation-transformer).
This has several adverse effects:
This leaks ground-truth context. It could inflate the BLEU metric and give a false picture of the model's ability to perform full translation (instead, the task becomes predicting the next word of an already known partial translation).
This could also result in low-quality output sentences (incoherence, repeated words), because the transformer is given a context that does not match what was really generated.
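To make both effects concrete, here is a toy sketch (not the notebook's model). The "model" is a hypothetical lookup table, `NEXT_TOKEN`, that predicts the next token from the previous one only and makes a single mistake after "the". All names and sentences in it are made up for illustration.

```python
# Hypothetical toy "model": predicts the next token from the previous one.
# It makes one mistake: after "the" it predicts "dog" instead of "cat".
NEXT_TOKEN = {"<bos>": "the", "the": "dog", "cat": "meowed",
              "meowed": "loudly", "dog": "ran", "ran": "away"}

reference = ["the", "cat", "meowed", "loudly"]

def decode_teacher_forced(reference):
    """Each step is conditioned on the ground-truth previous token."""
    contexts = ["<bos>"] + reference[:-1]
    return [NEXT_TOKEN[prev] for prev in contexts]

def decode_free_running(length):
    """Each step is conditioned on the model's own previous prediction."""
    prev, output = "<bos>", []
    for _ in range(length):
        prev = NEXT_TOKEN[prev]
        output.append(prev)
    return output

def token_accuracy(predicted, reference):
    """Fraction of positions where the prediction matches the reference."""
    return sum(p == r for p, r in zip(predicted, reference)) / len(reference)

teacher_forced = decode_teacher_forced(reference)
free_running = decode_free_running(len(reference))
print(" ".join(teacher_forced), token_accuracy(teacher_forced, reference))
print(" ".join(free_running), token_accuracy(free_running, reference))
```

With teacher forcing the output is "the dog meowed loudly": 3 of 4 tokens match the reference, but the sentence mixes the model's mistake with ground-truth context. Free-running decoding gives "the dog ran away": only 1 of 4 tokens match, yet the sentence is consistent with what the model actually generated.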
I haven’t yet watched Rachel Thomas's course videos; maybe using teacher forcing at eval was intended, or was simply not the focus of this lesson.
In my tests on this notebook, without teacher forcing, I observed lower BLEU scores but more coherent output sentences overall.
I am not sure how auto-regressive prediction is usually implemented with Transformers.
Here is my attempt, only modifying the forward function of the Transformer:
# Relies on names expected to be defined in the notebook: Module, TransformerEmbedding,
# EncoderBlock, DecoderBlock, compose, get_output_mask (plus torch, nn and F).
class Transformer(Module):
    def __init__(self, inp_vsz, out_vsz, n_layers=6, n_heads=8, d_model=256, d_head=32,
                 d_inner=1024, p=0.1, bias=True, scale=True, double_drop=True, pad_idx=1):
        self.enc_emb = TransformerEmbedding(inp_vsz, d_model, p)
        self.dec_emb = TransformerEmbedding(out_vsz, d_model, 0.)
        args = (n_heads, d_model, d_head, d_inner, p, bias, scale, double_drop)
        self.encoder = nn.ModuleList([EncoderBlock(*args) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderBlock(*args) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, out_vsz)
        self.out.weight = self.dec_emb.embed.weight
        self.pad_idx = pad_idx

    def forward(self, inp, out):
        enc = self.enc_emb(inp)
        enc = compose(self.encoder)(enc)
        if self.training:  # Batched teacher forcing for training
            out = self.dec_emb(out)
            mask_out = get_output_mask(out, self.pad_idx)
            out = compose(self.decoder)(out, enc, mask_out)
            return self.out(out)
        else:  # Auto-regressive prediction for eval
            out_seq_len = out.shape[1]  # Only the target length is used (TODO: make this a forward param)
            xxpad_token = 1  # TODO: make this a module param (and use the databunch dictionary)
            # Start with a single xxpad token per sequence
            next_dec_input = torch.full((inp.shape[0], 1), fill_value=xxpad_token, dtype=inp.dtype, device=inp.device)
            # Predict one token at a time
            for _ in range(out_seq_len):
                out = self.dec_emb(next_dec_input)
                mask_out = get_output_mask(out, self.pad_idx)
                out = compose(self.decoder)(out, enc, mask_out)
                out = self.out(out)
                # Prepare the next input (xxpad + predictions so far)
                next_dec_input = F.pad(out.argmax(2), (1, 0), value=xxpad_token)
            return out
I haven’t tested this implementation thoroughly, but it runs (slower than with teacher forcing because of the multiple passes per batch).
I think it could be improved. For example, could it be faster to not recompute the hidden representations of previously predicted tokens? (Applying dynamic programming to Transformer auto-regressive decoder inference.) Maybe I should look at how this is done in the PyTorch and TensorFlow implementations.
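As a rough illustration of that idea, here is a framework-free sketch of caching keys and values in a single causal self-attention head. It is not the notebook's `DecoderBlock`: `attend`, `project`, `decode_full` and `CachedSelfAttention` are hypothetical names, and the weights are random. The assumption it rests on is the causal mask: an earlier position never attends to later ones, so its keys and values do not change when a new token is appended.

```python
import math
import random

def project(vec, matrix):
    """Multiply a vector by a matrix given as a list of rows."""
    return [sum(x * row[j] for x, row in zip(vec, matrix))
            for j in range(len(matrix[0]))]

def attend(query, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

def decode_full(tokens, w_q, w_k, w_v):
    """Recompute causal self-attention for every position from scratch."""
    keys = [project(t, w_k) for t in tokens]
    values = [project(t, w_v) for t in tokens]
    return [attend(project(t, w_q), keys[:i + 1], values[:i + 1])
            for i, t in enumerate(tokens)]

class CachedSelfAttention:
    """Stores keys/values of earlier positions; each step handles one new token."""
    def __init__(self, w_q, w_k, w_v):
        self.w_q, self.w_k, self.w_v = w_q, w_k, w_v
        self.keys, self.values = [], []

    def step(self, token):
        self.keys.append(project(token, self.w_k))
        self.values.append(project(token, self.w_v))
        return attend(project(token, self.w_q), self.keys, self.values)

random.seed(0)
d = 4  # toy model dimension
def rand_matrix():
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]

w_q, w_k, w_v = rand_matrix(), rand_matrix(), rand_matrix()
tokens = [[random.gauss(0, 1) for _ in range(d)] for _ in range(5)]

full = decode_full(tokens, w_q, w_k, w_v)
layer = CachedSelfAttention(w_q, w_k, w_v)
incremental = [layer.step(t) for t in tokens]
max_diff = max(abs(a - b) for fa, ia in zip(full, incremental)
               for a, b in zip(fa, ia))
print(max_diff)
```

Both paths give the same outputs, but the cached version projects each token once instead of re-projecting the whole prefix at every step. In a multi-layer decoder the same cache would presumably be kept per layer, and the encoder-side keys and values only need to be computed once per batch.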