`MultiBatchRNN` and `PoolingLinearClassifier`


I was reading the code that defines the classes `MultiBatchRNN` and `PoolingLinearClassifier` (link), and I just wanted to confirm my understanding of how they work. To illustrate, this is what I think happens:

Imagine our input sentences are all 12 words long and our batch size is 5, so the block is shaped (12, 5). Then assume we set bptt to 4, so the first step is to divide each sentence into 3 parts, each 4 words long, colored blue, yellow, and red.
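The splitting step can be sketched in PyTorch. This is a minimal illustration, not the library code: it assumes a sequence-first layout (row t holds word t of every sentence in the batch), uses made-up token ids where each id equals its position, and the names `sl`, `bs`, `bptt`, and `chunks` are chosen only for illustration.

```python
import torch

sl, bs, bptt = 12, 5, 4
# Sequence-first block: row t holds word t of every sentence; ids are just positions 1..12.
x = torch.arange(1, sl + 1).unsqueeze(1).expand(sl, bs)  # shape (12, 5)

# Walk the time dimension in bptt-sized steps; the last chunk is clipped at sl.
chunks = [x[i:min(i + bptt, sl)] for i in range(0, sl, bptt)]
print([tuple(c.shape) for c in chunks])  # [(4, 5), (4, 5), (4, 5)]
print(chunks[0][:, 0].tolist())  # [1, 2, 3, 4]
```

Each chunk corresponds to one colored block: words 1 to 4, 5 to 8, and 9 to 12 for all 5 sentences at once.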

Then each colored block is fed into `RNN_Encoder` one by one, so we end up with a list of 3 outputs. Next, we concatenate them so that the first rows of each output are regrouped into one list, the second rows into another, and so on. We end up with a list of 4 outputs, each shaped (3, 5), like the middle block above.
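For comparison, here is a sketch of the regrouping under one assumption worth checking against the linked source: that for every chunk the encoder returns a list holding one output tensor per LSTM layer, and that the concatenation joins those lists layer by layer along the time dimension. `fake_encoder` and `hidden_sizes` are hypothetical stand-ins; the sizes are borrowed from the model printout later in this thread.

```python
import torch

sl, bs, bptt = 12, 5, 4
hidden_sizes = [1150, 1150, 400]  # hypothetical per-layer output sizes (from the printout)

def fake_encoder(chunk):
    # Hypothetical stand-in for the encoder's forward pass: one output per LSTM layer,
    # each shaped (chunk_len, bs, nh). Every value simply records its time step.
    t = chunk.float().unsqueeze(-1)  # (chunk_len, bs, 1)
    return [t.expand(-1, -1, nh) for nh in hidden_sizes]

def concat(arrs):
    # arrs[chunk][layer]: for each layer index si, join that layer's outputs
    # from every chunk along dim 0 (the time dimension).
    return [torch.cat([l[si] for l in arrs]) for si in range(len(arrs[0]))]

x = torch.arange(1, sl + 1).unsqueeze(1).expand(sl, bs)  # (12, 5)
per_chunk = [fake_encoder(x[i:min(i + bptt, sl)]) for i in range(0, sl, bptt)]
merged = concat(per_chunk)
print([tuple(o.shape) for o in merged])  # [(12, 5, 1150), (12, 5, 1150), (12, 5, 400)]
```

Under that assumption, "first rows" would mean the layer-0 outputs of every chunk, and each regrouped entry spans all 12 time steps of one layer, so the list has one entry per layer rather than one per time step.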

Moving on to `PoolingLinearClassifier`: it keeps only the last output (i.e., 4, 8, 12 above) and applies average and max pooling to it. Finally, it concatenates the two pooled results with the last row (12) into one long vector and feeds it into the linear layers (i.e., `torch.cat([output[-1], mxpool, avgpool], 1)`).
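A sketch of this concat-pooling step, under stated assumptions: the last layer's output is sequence-first with shape (sl, bs, nh), pooling runs over the time dimension, and `output` is random data standing in for real activations. The concatenation order mirrors the `torch.cat([output[-1], mxpool, avgpool], 1)` call quoted above; the `pool` helper is illustrative, not the library's.

```python
import torch
import torch.nn.functional as F

sl, bs, nh = 12, 5, 400
output = torch.randn(sl, bs, nh)  # hypothetical last-layer output, sequence-first

def pool(x, is_max):
    # (sl, bs, nh) -> (bs, nh, sl) so time is the pooled axis, then shrink it to length 1.
    f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
    return f(x.permute(1, 2, 0), 1).squeeze(-1)  # (bs, nh)

avgpool = pool(output, is_max=False)
mxpool = pool(output, is_max=True)
features = torch.cat([output[-1], mxpool, avgpool], 1)  # last time step + max + mean
print(features.shape)  # torch.Size([5, 1200])
```

The three nh-wide pieces give 3 × 400 = 1200 features per example when nh = 400; the max and mean summarize every time step, while `output[-1]` keeps only the final one.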

Phew… Am I right in my interpretation?



What is the meaning of “regroup”?

That’s right! Very good explanation!

If you look at the final shape of the tensor, it should be [bs, 1200].
That 1200 comes from 400 + 400 + 400, the concatenation of [last(lastOut), maxPool(lastOut), avgPool(lastOut)].
WHERE: lastOut is the last output ([4,8,12] in your explanation).

NOTE: 400 is the output size of the last LSTM layer in the encoder:

```
(0): MultiBatchEncoder(
  (module): AWD_LSTM(
    (encoder): Embedding(60004, 400, padding_idx=1)
    (encoder_dp): EmbeddingDropout(
      (emb): Embedding(60004, 400, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (module): LSTM(400, 1150, batch_first=True)
      )
      (1): WeightDropout(
        (module): LSTM(1150, 1150, batch_first=True)
      )
      (2): WeightDropout(
        (module): LSTM(1150, 400, batch_first=True)
      )
    )
    (input_dp): RNNDropout()
    (hidden_dps): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
      (2): RNNDropout()
    )
  )
)
```
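To see where the 400 comes from, here is a sketch that rebuilds only the three LSTMs listed in the printout with plain `torch.nn.LSTM`. Assumption: the dropout wrappers are left out, since dropout does not change tensor shapes. Because the printout shows `batch_first=True`, the time dimension is dim 1 here, so the last step is `last_out[:, -1]` and pooling reduces over dim 1.

```python
import torch
import torch.nn as nn

# Hypothetical re-creation of the LSTM stack from the printout, without dropout wrappers.
rnns = nn.ModuleList([
    nn.LSTM(400, 1150, batch_first=True),
    nn.LSTM(1150, 1150, batch_first=True),
    nn.LSTM(1150, 400, batch_first=True),
])

bs, sl = 5, 12
x = torch.randn(bs, sl, 400)  # stands in for the 400-dim embeddings, batch-first
with torch.no_grad():
    for rnn in rnns:
        x, _ = rnn(x)  # each layer's output sequence feeds the next layer

last_out = x  # (bs, sl, 400): the last LSTM's hidden size sets the feature width
features = torch.cat([last_out[:, -1], last_out.max(1).values, last_out.mean(1)], 1)
print(last_out.shape, features.shape)  # torch.Size([5, 12, 400]) torch.Size([5, 1200])
```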

AFAIK, regrouping is a kind of “activations engineering” (feature engineering in the domain of activations :wink:, something similar to DenseNet) that takes into account information from all the important parts of lastOut, not only the very last tensor.