Simply decreasing loop count does not necessarily increase performance

An interesting find for me: simply decreasing the number of loops in my own implementation of matmul did not necessarily improve performance. I was able to use broadcasting to remove ALL loops in Python, but still ended up well over an order of magnitude slower than torch.matmul (about 35x in the timings below). So when testing, use a few different input values, including large ones, and time it! (A quick size sweep is sketched after the code below.)

def my_matmul(m1,m2): return (m1[...,None]*m2[:][:,None]).sum(-2)

m1=torch.rand([10,56,85])
m2=torch.rand([10,85,56])

%timeit torch.matmul(m1,m2) #88.3 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
print(torch.matmul(m1,m2))
%timeit my_matmul(m1,m2) #3.06 ms ± 28.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
print(my_matmul(m1,m2))

print(m1[...,None].shape) #torch.Size([10, 56, 85, 1])
print(m2[:][:,None].shape) #torch.Size([10, 1, 85, 56])
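
To actually try a few sizes (a quick sketch, reusing the my_matmul defined above; the shapes are arbitrary examples), the standard timeit module works outside of the %timeit magic too:

import timeit
import torch

for n in (8, 32, 128):
    a = torch.rand([10, n, n])
    b = torch.rand([10, n, n])
    t_torch = timeit.timeit(lambda: torch.matmul(a, b), number=100)
    t_mine = timeit.timeit(lambda: my_matmul(a, b), number=100)
    print(f"n={n:3d}: torch.matmul {t_torch:.4f}s   my_matmul {t_mine:.4f}s")

Note that the broadcast version materializes a [10, n, n, n] intermediate on every call, so its memory use grows quickly with n.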

My understanding is that this is because I am basically broadcasting each row of m1 against the whole [85, 56] slice of m2, so the multiply materializes a large intermediate tensor before the sum. I thought I was done with this lesson until I actually timed my own version!
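
One way to see the cost (a quick check, using the same m1 and m2 as above): the elementwise broadcast has to materialize the full intermediate before the .sum(-2), while torch.matmul accumulates the reduction inside an optimized batched kernel.

intermediate = m1[...,None] * m2[:][:,None]
print(intermediate.shape)   # torch.Size([10, 56, 85, 56])
print(intermediate.numel()) # 2665600 elements (~10 MB of float32) for a single multiply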

Sorry for not addressing your question, but I think there’s a bug in your implementation:

def my_matmul(m1,m2): return (m1[...,None]*m2[:][:,None]).sum(-2)
m1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
m2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
test_near(m1.matmul(m2), my_matmul(m1, m2))

AssertionError: near:
tensor([[ 30,  36,  42],
        [ 66,  81,  96],
        [102, 126, 150]])
tensor([[  6,  12,  18],
        [ 60,  75,  90],
        [168, 192, 216]])

My tensors include a batch dimension as well, so they are rank-3 tensors.
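
A rank-agnostic variant is possible, though (a sketch, not taken from the notebook; my_matmul_anyrank is just an illustrative name): inserting the new axes relative to the last two dimensions makes the same expression work for plain matrices and batched tensors alike.

import torch

def my_matmul_anyrank(m1, m2):
    # m1[..., i, k, None] * m2[..., None, k, j], then sum over k
    return (m1[..., :, :, None] * m2[..., None, :, :]).sum(-2)

a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
print(torch.allclose(my_matmul_anyrank(a, a), a.matmul(a)))  # True (rank 2)

b1 = torch.rand([10, 56, 85]); b2 = torch.rand([10, 85, 56])
print(torch.allclose(my_matmul_anyrank(b1, b2), torch.matmul(b1, b2), atol=1e-5))  # True (rank 3)

It still materializes the big intermediate, so the timing gap from the first post does not go away.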

You can find a full notebook here: https://github.com/marii-moe/basics/blob/master/Matrix%20Multiply.ipynb