An interesting find for me: simply reducing the number of loops in my own matmul implementation did not necessarily improve performance. Using broadcasting, I was able to remove ALL the Python loops, but my version was still about 35x slower than torch.matmul (3.06 ms vs 88.3 µs on the shapes below). So when testing, try a few different input sizes, including large ones, and use timeit!
```python
import torch

def my_matmul(m1, m2):
    # [10,56,85,1] * [10,1,85,56] broadcasts to [10,56,85,56]; then sum over the 85-dim
    return (m1[..., None] * m2[:, None]).sum(-2)

m1 = torch.rand([10, 56, 85])
m2 = torch.rand([10, 85, 56])

%timeit torch.matmul(m1, m2)  # 88.3 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
print(torch.matmul(m1, m2))

%timeit my_matmul(m1, m2)  # 3.06 ms ± 28.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
print(my_matmul(m1, m2))

print(m1[..., None].shape)  # torch.Size([10, 56, 85, 1])
print(m2[:, None].shape)    # torch.Size([10, 1, 85, 56])
```
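A minimal sketch for following the advice above about trying several input sizes: it checks that the broadcast version agrees with torch.matmul (within float32 rounding) and times both with the standard-library timeit module. The `compare` helper and the list of sizes are hypothetical choices for illustration, not something from the lesson, and absolute timings will depend on your machine.

```python
import timeit

import torch


def my_matmul(m1, m2):
    # Broadcast: [b, n, k, 1] * [b, 1, k, m] -> [b, n, k, m], then sum over k
    return (m1[..., None] * m2[:, None]).sum(-2)


def compare(batch, n, k, m, number=10):
    """Hypothetical helper: check my_matmul against torch.matmul, return seconds per call for each."""
    a = torch.rand(batch, n, k)
    b = torch.rand(batch, k, m)
    # Same result up to float32 rounding (summation order differs)
    assert torch.allclose(my_matmul(a, b), torch.matmul(a, b), rtol=1e-4, atol=1e-5)
    t_torch = timeit.timeit(lambda: torch.matmul(a, b), number=number) / number
    t_mine = timeit.timeit(lambda: my_matmul(a, b), number=number) / number
    return t_torch, t_mine


if __name__ == "__main__":
    # Hypothetical size sweep: tiny, the shapes from the post, and a larger case
    for shape in [(10, 8, 8, 8), (10, 56, 85, 56), (10, 128, 128, 128)]:
        t_torch, t_mine = compare(*shape)
        print(f"{shape}: torch.matmul {t_torch * 1e6:9.1f} µs | "
              f"my_matmul {t_mine * 1e6:9.1f} µs | ratio {t_mine / t_torch:6.1f}x")
```

The ratio column is the number to watch; whether the gap grows or shrinks with size is exactly what a sweep like this is meant to reveal, rather than something to assume from one shape.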
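My understanding is that this is because each row of m1 (shape [85, 1]) gets multiplied elementwise against the whole second [85, 56] array: the broadcast builds the full [10, 56, 85, 56] intermediate tensor (about 2.7 million floats) before summing over the 85 dimension, whereas torch.matmul hands the work to an optimized batched matrix-multiply kernel that never materializes that tensor. I thought I was done with this lesson until I actually timed my own version!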
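To make the shape argument concrete, here is a small pure-Python sketch of the broadcasting rule (shapes are aligned from the right, and a size-1 dimension stretches to match the other) applied to the two shapes printed in the code. `broadcast_shape` is a hypothetical helper written for illustration, not a torch function. The intermediate ends up 85x larger than the [10, 56, 56] result, while the multiply-add count (b·n·k·m) is roughly what a standard matmul kernel performs anyway, so the gap likely comes from allocating and traversing that big tensor rather than from extra arithmetic.

```python
from math import prod


def broadcast_shape(a, b):
    """Hypothetical helper: result shape of broadcasting two shapes (NumPy/PyTorch rules)."""
    # Align from the right by padding the shorter shape with leading 1s
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        # A size-1 dim takes the other size; otherwise the sizes already match
        out.append(y if x == 1 else x)
    return tuple(out)


lhs = (10, 56, 85, 1)  # m1[..., None].shape
rhs = (10, 1, 85, 56)  # m2[:, None].shape
inter = broadcast_shape(lhs, rhs)
n_elems = prod(inter)
out_elems = prod((10, 56, 56))  # shape of the final matmul result

print(inter)  # (10, 56, 85, 56)
print(f"{n_elems:,} elements, {n_elems * 4 / 1e6:.1f} MB as float32")  # 2,665,600 elements, 10.7 MB as float32
print(f"{n_elems // out_elems}x the {out_elems:,}-element result")     # 85x the 31,360-element result
```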