Hi Fabio,
I share your opinion. It seems that teacher forcing is always used, even at evaluation (in the notebook 8-translation-transformer).
This has several adverse effects:
- It leaks ground-truth context. This can inflate the BLEU metric and give a false picture of the model's ability to perform a full translation (the task effectively becomes translating the next word given an already known partial translation) — see the sketch after this list.
- It can also result in low-quality output sentences (incoherence, repeated words), because the transformer is given a context that does not match what it actually generated.
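To make the difference concrete, here is a minimal sketch of what the decoder sees in each regime. `model(inp, dec_in)` is a placeholder for any encoder-decoder returning logits of shape (bs, seq_len, vocab), and pad_idx=1 stands for the xxpad token:

import torch

def teacher_forced_step(model, inp, targ, pad_idx=1):
    # The decoder always sees the ground-truth prefix: position t predicts targ[:, t].
    dec_in = torch.cat([torch.full_like(targ[:, :1], pad_idx), targ[:, :-1]], dim=1)
    return model(inp, dec_in)  # one pass, full ground truth visible

def greedy_decode(model, inp, max_len, pad_idx=1):
    # The decoder only ever sees its own previous predictions.
    dec_in = torch.full((inp.shape[0], 1), pad_idx, dtype=inp.dtype, device=inp.device)
    for _ in range(max_len):
        logits = model(inp, dec_in)
        next_tok = logits[:, -1].argmax(-1, keepdim=True)
        dec_in = torch.cat([dec_in, next_tok], dim=1)
    return dec_in[:, 1:]  # drop the initial xxpad token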
I haven’t watched Rachel Thomas’s course videos yet; maybe using teacher forcing at eval was intentional, or simply not the focus of this lesson.
In my tests on this notebook, without teacher forcing, I observed lower BLEU scores but more coherent output sentences overall.
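If it helps to reproduce the comparison, this is roughly how BLEU can be computed on the decoded token lists (I use nltk here; the notebook's own metric may differ, and preds_teacher_forced / preds_autoregressive / targets are placeholders for de-numericalized token lists):

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu_on_tokens(hypotheses, references):
    # hypotheses: list of predicted token lists; references: one reference token list per sentence
    smooth = SmoothingFunction().method1  # avoids zero scores on very short sentences
    return corpus_bleu([[ref] for ref in references], hypotheses, smoothing_function=smooth)

# score_tf = bleu_on_tokens(preds_teacher_forced, targets)
# score_ar = bleu_on_tokens(preds_autoregressive, targets)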
I am not sure how auto-regressive prediction is usually implemented with Transformers.
Here is my attempt, only modifying the forward function of the Transformer class:
class Transformer(Module):
    def __init__(self, inp_vsz, out_vsz, n_layers=6, n_heads=8, d_model=256, d_head=32,
                 d_inner=1024, p=0.1, bias=True, scale=True, double_drop=True, pad_idx=1):
        self.enc_emb = TransformerEmbedding(inp_vsz, d_model, p)
        self.dec_emb = TransformerEmbedding(out_vsz, d_model, 0.)
        args = (n_heads, d_model, d_head, d_inner, p, bias, scale, double_drop)
        self.encoder = nn.ModuleList([EncoderBlock(*args) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderBlock(*args) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, out_vsz)
        self.out.weight = self.dec_emb.embed.weight
        self.pad_idx = pad_idx

    def forward(self, inp, out):
        enc = self.enc_emb(inp)
        enc = compose(self.encoder)(enc)
        if self.training:  # batch teacher forcing for training
            out = self.dec_emb(out)
            mask_out = get_output_mask(out, self.pad_idx)
            out = compose(self.decoder)(out, enc, mask_out)
            return self.out(out)
        else:  # auto-regressive prediction for eval
            out_seq_length = out.shape[1]  # only the target length is used (TODO: pass it as a forward param)
            xxpad_token = 1  # TODO: set this as a module param (and use the databunch dictionary)
            # Start with a single xxpad token per sentence:
            next_dec_input = torch.full((inp.shape[0], 1), fill_value=xxpad_token,
                                        dtype=inp.dtype, device=inp.device)
            # Predict one token at a time:
            for _ in range(out_seq_length):
                out = self.dec_emb(next_dec_input)
                mask_out = get_output_mask(out, self.pad_idx)
                out = compose(self.decoder)(out, enc, mask_out)
                out = self.out(out)
                # Prepare the next decoder input (xxpad + all predictions so far):
                next_dec_input = F.pad(out.argmax(2), (1, 0), value=xxpad_token)
            return out
I haven’t tested this implementation thoroughly, but it runs (more slowly than with teacher forcing, because of the multiple decoder passes per batch).
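For reference, a minimal way to exercise the eval branch (assuming a fastai DataBunch named data and the trained model in model; adjust the names to your setup):

import torch

model.eval()  # make forward take the auto-regressive branch
with torch.no_grad():
    xb, yb = next(iter(data.valid_dl))  # one validation batch
    logits = model(xb, yb)              # yb is only used for its sequence length here
    preds = logits.argmax(dim=-1)       # greedy token ids, shape (bs, out_seq_length)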
I think it could be improved. For example, could it be faster not to recompute the hidden representations of previously predicted tokens? (That would be applying a form of dynamic programming to auto-regressive Transformer decoding.) Maybe I should look at how this is done in PyTorch and TensorFlow implementations.
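Following up on that last point: as far as I understand, the usual trick is to cache each decoder layer's keys and values so that every new step only computes attention for the latest token. Here is a minimal single-head sketch of the idea (not wired into the notebook's attention classes, just to illustrate the caching):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSelfAttention(nn.Module):
    "Minimal single-head self-attention that caches keys/values across decoding steps."
    def __init__(self, d_model):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.scale = d_model ** -0.5

    def forward(self, x_new, cache=None):
        # x_new: (bs, 1, d_model), the embedding of the newly generated token only
        k_new, v_new = self.k(x_new), self.v(x_new)
        if cache is not None:
            k = torch.cat([cache['k'], k_new], dim=1)  # reuse keys from previous steps
            v = torch.cat([cache['v'], v_new], dim=1)  # reuse values from previous steps
        else:
            k, v = k_new, v_new
        q = self.q(x_new)
        attn = F.softmax((q @ k.transpose(1, 2)) * self.scale, dim=-1)
        return attn @ v, {'k': k, 'v': v}  # (bs, 1, d_model) and the updated cache

Since only the newest token's query attends over the cached keys/values, step t costs O(t) attention instead of re-encoding the whole prefix, and no causal mask is needed because the cache only contains past positions. I believe fairseq (PyTorch) and tensor2tensor (TensorFlow) do something along these lines under the name "incremental decoding", so those implementations are probably the place to look.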