The reason that the labels are, as you put it, the sequence “just shifted over by one,” is that our language model needs to make a prediction for the most appropriate “subsequent word” for each word in an input sequence of `bptt` (back-propagation through time) length.
To zoom out to a slightly higher level: in one iteration of one epoch, the model makes `bptt` of these “next word” predictions across all batches in parallel. Cross-entropy loss is then calculated between each predicted word and the real word that actually came next in the training corpus.
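For concreteness, here’s a minimal sketch of both points, using made-up token ids and sizes and PyTorch-style tensors (not the notebook’s actual code):

```python
import torch
import torch.nn.functional as F

# Toy corpus of token ids (0..6 standing in for real words); bptt and
# vocab_size are made up just to keep the shapes small.
corpus = torch.arange(7)
bptt, vocab_size = 3, 7

x = corpus[0:bptt]           # inputs: tensor([0, 1, 2])
y = corpus[1:bptt + 1]       # labels: tensor([1, 2, 3])  <- same text, shifted by one

# Stand-in for the model's output: one score per vocab word for each of
# the bptt input positions.
logits = torch.randn(bptt, vocab_size)

# One cross-entropy term per position, comparing each predicted "next word"
# against the word that actually came next in the corpus.
loss = F.cross_entropy(logits, y)
```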
Now, how do we ensure that the text we look at in the next iteration directly follows the text seen in the previous iteration? (If we don’t ensure this, all the information stored in our model’s hidden state at the end of each iteration will be wasted.) The approach is this: before training begins, we divide the corpus into `batch_size` text chunks. The length of these chunks is chosen so that if each batch handles, in parallel, consecutive `bptt`-length sub-sequences of these chunks across iterations (one sub-sequence per iteration), there will still be a target word to predict for the final word of the final `bptt` sequence seen by the final batch.
To put it more succinctly, we batch the corpus in a way that ensures:
- The text sequences that the model sees in one iteration directly precede (in the original corpus) the sequences the model sees in the next iteration.
- Each word our model sees as a training input will always have a corresponding “next word” that can be used as its label.
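If it helps, here’s a minimal sketch of that batching scheme, again assuming PyTorch tensors of token ids. The helper names `batchify` and `get_batch`, and all the sizes, are made up for illustration; the slicing is what gives you the two properties above.

```python
import torch

def batchify(token_ids, batch_size):
    """Lay the corpus out as batch_size chunks, one per column, so that
    row i + 1 directly follows row i inside every column."""
    n = len(token_ids) // batch_size
    token_ids = token_ids[: n * batch_size]          # drop the leftover tail
    return token_ids.view(batch_size, n).t().contiguous()

def get_batch(data, i, bptt):
    """Take the bptt-length sub-sequences starting at row i, plus their
    shifted-by-one targets."""
    seq_len = min(bptt, len(data) - 1 - i)           # leave room for the last label
    x = data[i : i + seq_len]
    y = data[i + 1 : i + 1 + seq_len]
    return x, y

# Toy run: 20 "tokens", 4 chunks, bptt of 3 (all sizes made up).
corpus = torch.arange(20)
data = batchify(corpus, batch_size=4)                # shape (5, 4)
for i in range(0, data.size(0) - 1, 3):
    x, y = get_batch(data, i, bptt=3)
    # Within one iteration, y really is x shifted down by one row...
    assert torch.equal(y[:-1], x[1:])
    # ...and the rows seen in the next iteration pick up exactly where
    # this iteration left off, within each column/chunk.
```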
When I first went through the lesson 12 text notebooks, I ran into confusion/frustration similar to what you’ve described. On the off chance you find it helpful, here’s my own exploration of language model batching, in which I created several toy examples and accompanying explanations to convince myself I understood what was going on.