Could someone show me how to code this in pure PyTorch, without loops of course? I have read about torch.gather and torch.index_select, but just can’t figure it out.

import torch
xcf = torch.rand((64,2,1024))
xcf.shape
mx, ix = xcf[:, 1, :].max(dim=1)  # Max value and its index along the last dimension, using index 1 of dim 1 (the second of the two rows)
ix.shape
# Here I want to pick, for each i, position ix[i] of the last dimension of xcf,
# returning a 64 x 2 matrix M, where
# M[i, j] = xcf[i, j, ix[i]]
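A minimal sketch of one loop-free approach, assuming the shapes above. `torch.gather` needs an index tensor with the same number of dimensions as `xcf`, so `ix` is reshaped to (64, 1, 1) and expanded to (64, 2, 1) before gathering along the last dimension, and the trailing singleton dimension is then squeezed away. The second variant uses advanced indexing instead. The names `idx`, `batch`, `M` and `M_adv` are just illustrative. `torch.index_select` is a poor fit here because it applies the same 1-D index to every batch element rather than a different position per row.

```python
import torch

xcf = torch.rand((64, 2, 1024))
mx, ix = xcf[:, 1, :].max(dim=1)  # ix has shape (64,), dtype int64

# Option 1: torch.gather along the last dimension.
# gather requires index.dim() == xcf.dim(), so reshape ix to (64, 1, 1)
# and expand it to (64, 2, 1): the same position ix[i] for both j.
idx = ix.view(-1, 1, 1).expand(-1, xcf.size(1), 1)
# out[i, j, 0] = xcf[i, j, idx[i, j, 0]] = xcf[i, j, ix[i]]
M = xcf.gather(2, idx).squeeze(2)  # shape (64, 2)

# Option 2: advanced indexing with a batch index plus ix.
# The two index tensors broadcast to shape (64,); because they are
# separated by a slice, that dimension comes first, giving (64, 2).
batch = torch.arange(xcf.size(0))
M_adv = xcf[batch, :, ix]  # M_adv[i, j] = xcf[i, j, ix[i]]

print(M.shape)  # torch.Size([64, 2])
```

As a quick sanity check, column 1 of `M` should reproduce `mx`, since that row is where the maxima were taken from.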