How does BPT3C work?

I’m trying to wrap my head around how BPT3C (backpropagation through time for text classification, implemented in MultiBatchRNN) works…

When we have a normal simple RNN-based network, we pass in a minibatch of size (bptt, bs, emb_sz) and backprop from the loss to update all the weights in the network. Since we preserve the hidden state of our rnn layer between minibatches, we need to call repackage_var to make sure we’re not effectively creating a network that is thousands of layers deep. That is, each time we backprop the loss, the network weights are updated based only on a history of bptt timepoints into the past.
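To make sure I have that pattern right, here’s a minimal sketch of truncated BPTT in plain PyTorch (my own toy, not the fastai source; the sizes are arbitrary, and `detach()` plays the role of repackage_var):

```python
import torch
import torch.nn as nn

emb_sz, nh, bs, bptt = 8, 16, 4, 5
rnn = nn.RNN(emb_sz, nh)
h = torch.zeros(1, bs, nh)      # hidden state preserved across minibatches

for _ in range(3):              # three consecutive minibatches
    x = torch.randn(bptt, bs, emb_sz)   # stand-in for an embedded minibatch
    out, h = rnn(x, h)
    loss = out.sum()            # stand-in for the real loss
    loss.backward()             # gradients flow back at most bptt timesteps
    h = h.detach()              # "repackage_var": keep the values, drop the graph
```

Because `h` is detached after each minibatch, the next `backward()` can only reach the current bptt-length chunk, even though the hidden state’s values carry over.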

The point of BPT3C is to extend the number of timepoints into the past that we are considering. It does this through repeated calls to RNN_Encoder, appending the results to a list. The output of RNN_Encoder for a given rnn-layer is of size (bptt, bs, nh), so if we call RNN_Encoder e.g. 10 times, the overall result is a tensor of size (bptt*10, bs, nh). The number of times we append to the list is controlled by the max_seq parameter.
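The chunk-and-concatenate idea can be sketched like this (a toy approximation of MultiBatchRNN, not the actual fastai source; in particular I’m assuming the keep-only-the-last-max_seq test looks like `i > sl - max_seq`, and all names and sizes are illustrative):

```python
import torch
import torch.nn as nn

emb_sz, nh, bs, bptt, max_seq = 8, 16, 4, 5, 15
rnn = nn.RNN(emb_sz, nh)

def encode(x, sl):
    h = torch.zeros(1, bs, nh)          # fresh hidden state per document
    outputs = []
    for i in range(0, sl, bptt):
        out, h = rnn(x[i:min(i + bptt, sl)], h)
        h = h.detach()                  # cut the gradient history between chunks
        if i > sl - max_seq:            # keep only the last max_seq timesteps
            outputs.append(out)
    return torch.cat(outputs)           # (<= max_seq, bs, nh)

x = torch.randn(20, bs, emb_sz)         # stand-in for an embedded document
enc = encode(x, 20)
```

Note that each appended `out` still carries its own graph back through its bptt-length chunk; only the hidden state handed to the *next* chunk is detached.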

Since we want to consider a greater number of timepoints, my first question is: why not just set bptt to a larger value? If bptt is 100 and max_seq is 1000, then we append to the list 10 times and wind up with a tensor of size (1000, bs, nh). Why not just set bptt to 1000? We’d still wind up with an output tensor of size (1000, bs, nh).

Also, each time we call RNN_Encoder it uses repackage_var() to detach the rnn’s hidden state from its gradient history. If we want to preserve the history over a greater number of timepoints, why are we still calling repackage_var with each call to RNN_Encoder?

Any insight is appreciated, thanks.

These are good questions - and thinking about them is a good way to deepen your understanding of training RNNs.

Setting bptt that large would use up all our GPU RAM and compute, and lead to gradient explosion.

It’s related to the answer above. The issue is that we can’t backprop through that many timesteps, due to compute and gradient-explosion problems. Backpropping through thousands of timesteps is really slow!

But we still want the earlier paragraphs to impact the weights. So we backprop the classifier results to each BPTT sequence individually. It doesn’t backprop through the entire sequence, just through each BPTT section.
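Here’s a toy check of that claim (my own sketch, not fastai code): detaching the hidden state between chunks cuts the graph *between* sections, but each section’s output keeps its own graph, so a loss on the concatenated output backprops into every section separately.

```python
import torch
import torch.nn as nn

emb_sz, nh, bs, bptt = 8, 16, 4, 5
rnn = nn.RNN(emb_sz, nh)
h = torch.zeros(1, bs, nh)

# three bptt-length sections, each a leaf so we can inspect its gradient
chunks = [torch.randn(bptt, bs, emb_sz, requires_grad=True) for _ in range(3)]
outs = []
for x in chunks:
    out, h = rnn(x, h)
    h = h.detach()              # no gradient flows between sections via h
    outs.append(out)

loss = torch.cat(outs).sum()    # stand-in for the classifier head + loss
loss.backward()

# every section received a gradient, each through at most bptt timesteps
assert all(c.grad is not None for c in chunks)
```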

Ok that helps, but I’m still confused:

Normally (without BPT3C) the process is:

1. a minibatch is loaded and passed through the embeddings, giving size (bptt, bs, emb_sz)
2. the bptt-minibatch is passed through the encoder
3. the bptt-minibatch is passed through the decoder
4. the loss is calculated
5. backprop updates all the weights in the NN

So, without BPT3C, the following weights are updated all at once: 1) the embeddings for all the words in the bptt-sequence, 2) the rnn encoder’s weights through the past bptt timesteps, and 3) the decoder’s weights.

With BPT3C the process appears to be:

1. a minibatch is loaded and passed through the embeddings, giving size (bptt, bs, emb_sz)
2. the bptt-minibatch is passed through RNN_Encoder
3. the rnn’s gradient history is erased through repackage_var
4. steps 2 and 3 are repeated n_loops times, with the results appended to a list
5. the concatenated list, of size (bptt*n_loops, bs, nh), is passed through the decoder
6. the loss is calculated
7. backprop updates all the weights in the NN

So, with BPT3C, the following weights are updated all at once: 1) the word-embeddings for ONLY the past bptt timesteps, 2) the rnn encoder’s weights through ONLY the previous bptt timesteps, and 3) the decoder’s weights.

For example, if the document length is 2000, bptt is 100, and n_loops is 10 (so n_loops*bptt = 1000), then this is what happens. For the first 1000 words, we loop through RNN_Encoder 10 times; the only effect of this is to build up the RNN encoder’s hidden state, since after each loop we discard the gradient history and do not append the results to our encoder output. For the next 900 words, we are still discarding the gradient history, but now appending the results to the encoder’s output. Finally, for the last 100 words, we maintain the rnn encoder’s gradients for those 100 timesteps, and append the results. We then pass the encoder’s output, of size (1000, bs, nh), to the decoder for the classifier prediction, and finally compute the loss. Backprop only computes gradients for the most recent 100 timesteps of the rnn, and therefore only the embeddings for the most recent 100 words are updated as well.
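The chunk bookkeeping in that example can be written out in a few lines of plain Python (a sketch of my own; I’m assuming here that the keep-window test is `i >= sl - max_seq`, which is an assumption about the implementation, not a quote from it):

```python
# document length 2000, bptt 100, keep window max_seq 1000
sl, bptt, max_seq = 2000, 100, 1000

# chunk start positions whose outputs get appended to the encoder output
kept = [i for i in range(0, sl, bptt) if i >= sl - max_seq]
# kept == [1000, 1100, ..., 1900]: 10 chunks of 100 words = 1000 timesteps
n_kept_words = len(kept) * bptt
```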

Is all that correct? It doesn’t seem to square with what you said, that “we backprop the classifier results to each BPTT sequence individually”. What am I missing?


I’m confused about the same thing…