Tensor Manipulation: Splitting then Padding a New Dimension

Within PyTorch, let's say I have a tensor of torch.Size([32, 53, 768]). How would I go about converting it to torch.Size([32, 12, 5, 768])? During this conversion, the dimension of size 53 should be split at variable indexes (there will be 11 indexes, so that 12 new sequences are formed). If the distance between two indexes is less than 5, zeros should be added so that each sequence is 5 units long.
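A minimal sketch of that conversion, assuming the same 11 split indexes are used for every item in the batch and that no piece is longer than 5 (the split points below are hypothetical, and `split_and_pad` is an illustrative helper, not a PyTorch function). It still contains a Python loop, but only one iteration per piece (12 here), not per element: `torch.tensor_split` does the splitting, `F.pad` adds the zeros, and `torch.stack` creates the new dimension.

```python
import torch
import torch.nn.functional as F


def split_and_pad(x, split_points, pad_len, dim=1):
    """Split `x` along `dim` at `split_points`, zero-pad each piece to `pad_len`,
    and stack the pieces into a new dimension placed at `dim`.

    Assumptions: the same split points for every batch item, and no piece
    longer than `pad_len`.
    """
    dim = dim % x.dim()
    pieces = torch.tensor_split(x, split_points, dim=dim)
    padded = []
    for piece in pieces:  # one iteration per piece, not per element
        missing = pad_len - piece.size(dim)
        if missing < 0:
            raise ValueError(f"piece of length {piece.size(dim)} exceeds pad_len={pad_len}")
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # leave every dimension after `dim` untouched and pad `dim` on the right.
        pad_spec = [0, 0] * (x.dim() - 1 - dim) + [0, missing]
        padded.append(F.pad(piece, tuple(pad_spec)))
    return torch.stack(padded, dim=dim)


x = torch.randn(32, 53, 768)
# Hypothetical split points: 11 indexes -> 12 pieces of length 3 to 5
split_points = [5, 10, 14, 19, 23, 28, 32, 37, 41, 46, 50]
out = split_and_pad(x, split_points, pad_len=5)
print(out.shape)  # torch.Size([32, 12, 5, 768])
```

Because every piece is padded to the same length before stacking, the result is an ordinary rectangular tensor, and gradients flow back through `F.pad` and the slices to `x`.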

Essentially, I would like to split a sequence at certain indexes, pad the new sequences to a fixed length, and end up with a tensor of torch.Size([15, 10]) (for example, if starting with torch.Size([100]), splitting to create 15 sequences, and padding each to length 10).
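For the 1-D case, one possible sketch uses `torch.nn.utils.rnn.pad_sequence`, which pads a list of tensors to the length of the longest one; a final `F.pad` then extends that to the fixed length. The split points are hypothetical, and `split_and_pad_1d` is an illustrative name.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def split_and_pad_1d(seq, split_points, pad_len):
    """Split a 1-D tensor at `split_points` and zero-pad every piece to `pad_len`."""
    pieces = list(torch.tensor_split(seq, split_points))
    # pad_sequence pads every piece to the length of the LONGEST piece ...
    padded = pad_sequence(pieces, batch_first=True, padding_value=0.0)
    longest = padded.size(1)
    if longest > pad_len:
        raise ValueError(f"longest piece ({longest}) exceeds pad_len={pad_len}")
    # ... so pad the last dimension further, up to the fixed length
    return F.pad(padded, (0, pad_len - longest))


seq = torch.arange(1, 101, dtype=torch.float32)  # torch.Size([100])
split_points = list(range(7, 100, 7))            # 14 hypothetical split points -> 15 pieces
out = split_and_pad_1d(seq, split_points, pad_len=10)
print(out.shape)  # torch.Size([15, 10])
```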

I would like to do this without a loop, since it will happen in the forward pass of a model, and my understanding is that loops dramatically decrease performance. A loop would be fine if it did not greatly impact model training time or performance.

My end goal is to take the mean across the newly formed padded dimension and, in the case of the first example, end up with torch.Size([32, 12, 1, 768]).
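A small sketch of that last step, using a scaled-down stand-in for the [32, 12, 5, 768] tensor: `keepdim=True` keeps the reduced dimension as size 1, which gives the [..., 1, ...] shape. It also shows the catch with a plain `mean`: the padding zeros are included in the denominator.

```python
import torch

# Hypothetical stand-in for [32, 12, 5, 768]: [batch, pieces, pad_len, features]
padded = torch.zeros(2, 3, 5, 4)
# Every piece holds three real rows (values 1, 2, 3) followed by two zero-padding rows
padded[:, :, :3] = torch.tensor([1.0, 2.0, 3.0]).view(1, 1, 3, 1)

means = padded.mean(dim=2, keepdim=True)  # keepdim=True keeps dim 2 as size 1
print(means.shape)  # torch.Size([2, 3, 1, 4])
print(means[0, 0, 0, 0].item())  # approximately 1.2 (= 6 / 5), not the true mean 2.0
```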

For the padding, all sequences should be padded to the same length. So, once the original sequence is split at the provided indexes, each sub-sequence will be padded to that constant length.


Thank you. Any tips or suggestions are welcome.

It’s always risky to argue that a solution cannot exist. A real expert (or an innocent) could always come up with a black swan.

But anyway, although I am not a PyTorch expert, IMHO you will need some kind of loop to break up the sequence, whether directly or with indexes. PyTorch is specifically designed to work with rectangular tensors, not ragged arrays, so I doubt there’s a loop-free way to do it.

As for the means, there’s no way to tell PyTorch to ignore a specific value. However, you could pad with zeros and calculate the correct means directly:

xxx.sum(dim=0) / (xxx != 0).sum(dim=0).float()
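Applied to the layout from the question, that formula would reduce over the padded dimension (dim 2) instead of dim 0. A minimal sketch, with `nonzero_mean` as a hypothetical helper name; the `clamp` is an addition that avoids 0 / 0 = NaN when a column is nothing but padding.

```python
import torch


def nonzero_mean(padded, dim=2):
    """Mean over `dim` counting only non-zero entries: sum / (number of non-zeros)."""
    total = padded.sum(dim=dim, keepdim=True)
    count = (padded != 0).sum(dim=dim, keepdim=True).to(padded.dtype)
    # clamp avoids 0 / 0 = NaN when a column is entirely padding
    return total / count.clamp(min=1)


# Hypothetical [batch, pieces, pad_len, features] tensor: 3 real rows, 2 padding rows
padded = torch.zeros(2, 3, 5, 4)
padded[:, :, :3] = torch.tensor([1.0, 2.0, 3.0]).view(1, 1, 3, 1)
result = nonzero_mean(padded)
print(result.shape)  # torch.Size([2, 3, 1, 4])
```

Note that this counts per column, so a genuine 0.0 in the data is treated as padding and left out of the count.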

You might even save the number of valid (non-padding) entries in each column during the splitting loop.
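If the piece lengths are kept, they follow directly from the split points, and dividing by them gives the correct mean even when the data itself contains zeros. A sketch assuming the [batch, pieces, pad_len, features] layout and one set of split points for the whole batch; `lengths_mean`, the split points, and the values are hypothetical.

```python
import torch


def lengths_mean(padded, split_points, total_len):
    """Mean over the padded dim (dim 2) using the true length of each piece.

    The piece lengths are the gaps between consecutive boundaries
    0, *split_points, total_len.
    """
    bounds = torch.tensor([0, *split_points, total_len], device=padded.device)
    lengths = (bounds[1:] - bounds[:-1]).to(padded.dtype)  # one length per piece
    total = padded.sum(dim=2, keepdim=True)                 # [batch, pieces, 1, features]
    # reshape so lengths broadcast over batch and feature dims; clamp guards empty pieces
    return total / lengths.clamp(min=1).view(1, -1, 1, 1)


# Pieces of length 3, 2 and 4 from a sequence of 9, padded to length 5
padded = torch.zeros(2, 3, 5, 4)
padded[:, 0, :3] = torch.tensor([0.0, 1.0, 2.0]).view(1, 3, 1)  # contains a real 0.0
padded[:, 1, :2] = 5.0
padded[:, 2, :4] = torch.tensor([1.0, 2.0, 3.0, 4.0]).view(1, 4, 1)
result = lengths_mean(padded, split_points=[3, 5], total_len=9)
print(result.shape)  # torch.Size([2, 3, 1, 4])
```

For the first piece this gives (0 + 1 + 2) / 3 = 1.0, whereas the non-zero count formula above would divide by 2 and give 1.5.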

HTH, Malcolm