Could someone show me how to code this in pure PyTorch? Without loops, of course. I have read about torch.gather and torch.index_select, but just can't figure it out.
```python
import torch

xcf = torch.rand((64, 2, 1024))
xcf.shape  # torch.Size([64, 2, 1024])

# Find the maximum value and its index along the last dimension,
# looking only at index 1 of the middle dimension
mx, ix = xcf[:, 1, :].max(dim=1)
ix.shape  # torch.Size([64])

# Here I want to index the ix position of the last dimension of xcf,
# returning a 64 x 2 matrix M, where
# M[i, j] = xcf[i, j, ix[i]]
```
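A minimal sketch of two loop-free ways to build M, assuming the shapes from the question (64 x 2 x 1024, with ix of shape (64,) and dtype int64 as returned by max). The names idx, batch, M_gather and M_index are illustrative only. The first option uses torch.gather, which needs an index tensor with the same number of dimensions as the input, so ix is reshaped and expanded first; the second uses advanced indexing with a batch-index vector.

```python
import torch

# Same setup as in the question
xcf = torch.rand((64, 2, 1024))
mx, ix = xcf[:, 1, :].max(dim=1)  # mx, ix: shape (64,)

# Option 1: torch.gather along the last dimension (dim=2).
# For dim=2, gather computes out[i][j][k] = xcf[i][j][idx[i][j][k]],
# so idx must be 3-D: reshape ix to (64, 1, 1), then expand to (64, 2, 1).
idx = ix.view(-1, 1, 1).expand(-1, xcf.size(1), 1)
M_gather = torch.gather(xcf, 2, idx).squeeze(2)  # shape (64, 2)

# Option 2: advanced indexing.
# batch[i] is paired with ix[i], and the slice keeps the middle dimension,
# so M_index[i, j] = xcf[i, j, ix[i]] with shape (64, 2).
batch = torch.arange(xcf.size(0))
M_index = xcf[batch, :, ix]
```

As a quick sanity check, both results should match, and column 1 should equal mx, since ix was taken from that row of the middle dimension.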
Thanks for helping me get past this obstacle!