I’m trying to understand how the data preparation for a language model works. The get_batch function in LanguageModelLoader returns
I know the purpose of the model is to predict the next word given the preceding sequence, so I was expecting a sample to be a sequence and the label to be the word following that sequence, say data[0:50] and data[50]. However, it seems that the sample and label have the same length, just shifted over by 1, so something like data[0:50] and data[1:51]. I can’t quite wrap my mind around how this is working.
The language model tries to predict the next word after each word in the sequence.
You can find a nice illustration of the language model setup on the left side of the figure here: https://twitter.com/thom_wolf/status/1186225108282757120?s=21
The labels are, as you put it, the sequence “just shifted over by one” because our language model needs to predict the most appropriate “subsequent word” for each word in an input sequence of bptt (back-propagation through time) length.
To zoom out to a slightly higher level: in one iteration of one epoch, the model makes bptt number of these “next word” predictions on all batches in parallel. Cross-entropy loss is then calculated between each predicted word and the real word that actually came next in the training corpus.
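To make the loss step concrete, here is a minimal sketch in plain Python. The probabilities are made up for illustration, and mean_next_token_loss is a hypothetical helper, not a fastai function; a real model would output a distribution over the whole vocabulary at each position.

```python
import math

def mean_next_token_loss(predicted_probs, targets):
    """Average cross-entropy over every position of one sequence.

    predicted_probs[t] maps candidate tokens to the probability the
    (hypothetical) model assigns them as the token following inputs[t].
    """
    losses = [-math.log(probs[target])
              for probs, target in zip(predicted_probs, targets)]
    return sum(losses) / len(losses)

tokens = ["the", "cat", "sat", "down"]
inputs, targets = tokens[:-1], tokens[1:]  # targets are inputs shifted by one

# Made-up model outputs, one distribution per input position.
predicted_probs = [
    {"cat": 0.5, "dog": 0.5},    # after "the"
    {"sat": 0.25, "ran": 0.75},  # after "cat"
    {"down": 1.0},               # after "sat"
]

loss = mean_next_token_loss(predicted_probs, targets)
print(loss)
```

Every input position contributes its own term to the loss, which is why the label has to be as long as the input.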
Now how do we ensure that the text we look at in the next iteration directly follows the text that was seen in the previous iteration? (If we don’t ensure this, all the information that’s stored in our model’s hidden state at the end of an epoch’s iteration will be wasted.) The approach is that before training begins, we divide the corpus into batch_size number of text chunks. The length of these chunks is chosen such that if each batch handles, in parallel, consecutive bptt length sub-sequences of these chunks across all iterations (one sequence is handled in one iteration), then there will still be a target word to be predicted for the final word appearing in the final bptt sequence seen by the final batch.
To put it more succinctly, we batch the corpus in such a way that ensures that:
- The text sequences that the model sees in one iteration directly precede (in the original corpus) the sequences the model sees in the next iteration.
- Each word our model sees as a training input will always have a corresponding “next word” that can be used as its label.
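The two properties above can be checked on a toy corpus. This is a simplified sketch under my own assumptions, not fastai's actual LanguageModelLoader: lm_batches is a hypothetical name, bptt is held fixed, and any leftover tokens that don't fill a chunk are simply dropped.

```python
def lm_batches(corpus, batch_size, bptt):
    """Yield (inputs, targets); each is a list of batch_size rows."""
    chunk_len = len(corpus) // batch_size  # leftover tokens are dropped
    chunks = [corpus[r * chunk_len:(r + 1) * chunk_len]
              for r in range(batch_size)]
    # Stop one token early so the last input always has a "next word".
    for i in range(0, chunk_len - 1, bptt):
        seq_len = min(bptt, chunk_len - 1 - i)
        inputs = [chunk[i:i + seq_len] for chunk in chunks]
        targets = [chunk[i + 1:i + 1 + seq_len] for chunk in chunks]
        yield inputs, targets

corpus = list(range(20))  # stand-in for 20 token ids
batches = list(lm_batches(corpus, batch_size=2, bptt=4))
for inputs, targets in batches:
    print(inputs, targets)
```

With these toy numbers, row 0 of the first iteration is [0, 1, 2, 3] and row 0 of the second is [4, 5, 6, 7], so each row picks up exactly where it left off. The final iteration is shortened to a single token so that its target still exists.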
When I first went through the lesson 12 text NBs, I had a confusion/frustration similar to what you described. On the outside chance you find it helpful: here’s my own exploration of language model batching, in which I created several toy examples and accompanying explanations to convince myself I understood what was going on.
Thanks for taking the time to lay it out like this! I think I’m going to have to go through the video and your notebook a few times before I fully wrap my mind around it, but this is super helpful. Good to know it was confusing to someone else and I’m not just an idiot!
Given token(n), a language model predicts token(n+1).
So it would be wasteful if we used an input sequence of 50 tokens only to predict the 51st token!
The sequence created by shifting the input sequence leftward by 1 token is the output expected from a perfect language model, and is thus the ideal training ‘target’. This way we use every token of the input to teach our language model to predict the next token.
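Here is a small sketch of that shift using the slice sizes from the original question. The data list is just a stand-in for token ids, and bptt = 50 is taken from the example above, not from any particular configuration.

```python
data = list(range(100))  # stand-in for a tokenized corpus
bptt = 50

x = data[0:bptt]       # input tokens:  data[0:50]
y = data[1:bptt + 1]   # target tokens: data[1:51]

# Each input token is paired with the token that follows it in the corpus.
pairs = list(zip(x, y))
print(len(pairs))
print(pairs[:3])
```

One window of 50 tokens yields 50 (token, next token) training pairs, rather than the single pair you would get from predicting only data[50].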