How do you code this in PyTorch?

(Malcolm McLean) #1

Would someone show me how to code this in pure PyTorch, without loops of course? I have read about torch.gather and torch.index_select, but just can't figure it out.

import torch

xcf = torch.rand((64,2,1024))

mx,ix = xcf[:,1,:].max(dim=1) # Find the maximum value and its index along the last dimension, using channel 1 of the middle dimension

# Here I want to index the ix position of the last dimension of xcf,
# returning a 64 x 2 matrix, M, where
# M[i,j] = xcf[i,j,ix[i]]
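A likely reason torch.index_select is hard to apply here: it uses one index list for the whole tensor, so every row i receives all 64 positions ix[0..63] rather than only its own ix[i]. The minimal sketch below illustrates that; recovering M from the result with torch.diagonal is only an illustration of the semantics (it builds a wasteful 64 x 64 intermediate), not a recommended solution.

```python
import torch

xcf = torch.rand((64, 2, 1024))
mx, ix = xcf[:, 1, :].max(dim=1)  # ix has shape (64,)

# index_select applies the SAME index list to every row:
# sel[i, j, k] == xcf[i, j, ix[k]], giving shape (64, 2, 64).
sel = torch.index_select(xcf, dim=2, index=ix)

# The wanted M[i, j] = xcf[i, j, ix[i]] lies on the diagonal k == i.
# diagonal over dims 0 and 2 gives shape (2, 64); transpose to (64, 2).
M = torch.diagonal(sel, dim1=0, dim2=2).T
```

Because of that intermediate, a per-row selection such as torch.gather is the better fit for this problem.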

Thanks for helping me pass this obstacle!

(youali) #2


So, from what I understand, you want to take, for each row i, the elements of both channels at position ix[i] in the last dimension. Try this:

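m1 = torch.gather(xcf[:, 0, :], dim=1, index=ix.unsqueeze(1))  # channel 0 at ix, shape (64, 1)
m2 = torch.gather(xcf[:, 1, :], dim=1, index=ix.unsqueeze(1))  # channel 1 at ix, shape (64, 1)
m = torch.cat((m1, m2), dim=1)  # join the two columns into shape (64, 2)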

m.shape #torch.Size([64, 2])
m[10, 1] == xcf[10,1,ix[10]]  # tensor(True); older PyTorch versions print tensor(1, dtype=torch.uint8)
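If there are more than two channels, the per-channel gathers above can be collapsed into one call. The sketch below shows two such variants, under the assumption that ix comes from max as in the question: (1) a single torch.gather with ix reshaped and expanded across the middle dimension, and (2) plain advanced indexing with a row-index tensor. Both are checked against the two-gather result above.

```python
import torch

xcf = torch.rand((64, 2, 1024))
mx, ix = xcf[:, 1, :].max(dim=1)

# Reference: the two-gather answer from the thread.
m1 = torch.gather(xcf[:, 0, :], dim=1, index=ix.unsqueeze(1))
m2 = torch.gather(xcf[:, 1, :], dim=1, index=ix.unsqueeze(1))
m = torch.cat((m1, m2), dim=1)

# Variant 1: one gather over all channels. gather needs an index with the
# same number of dims as the input, so reshape ix to (64, 1, 1) and expand
# it across the channel dimension to (64, 2, 1).
idx = ix.view(-1, 1, 1).expand(-1, xcf.size(1), 1)
m_gather = torch.gather(xcf, dim=2, index=idx).squeeze(2)  # (64, 2)

# Variant 2: advanced indexing. rows and ix (both shape (64,)) are paired
# element-wise; since they are separated by a slice, their dimension comes
# first in the result, so m_index[i, j] == xcf[i, j, ix[i]].
rows = torch.arange(xcf.size(0))
m_index = xcf[rows, :, ix]  # (64, 2)
```

Variant 2 is arguably the most readable; variant 1 keeps everything in a single gather call, matching the approach of the answer.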

(Malcolm McLean) #3

Thanks! Exactly what I needed.