I found a good ol' stage-by-stage print helpful for understanding broadcasting matmul:
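```python
import torch
from torch import tensor

def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    c = torch.zeros(ar, bc)
    print(a, "a")
    print(b, "b")
    for i in range(ar):
        # c[i,j] = (a[i,:] * b[:,j]).sum()  # previous version (needed an inner loop over j)
        # a[i] has shape (ac,); unsqueeze(-1) makes it (ac, 1), which broadcasts
        # against b (ac, bc); summing over dim 0 gives row i of the result
        c[i] = (a[i].unsqueeze(-1) * b).sum(dim=0)
        print(f"\na[{i}].unsqueeze(-1)")
        print(a[i].unsqueeze(-1), "\n")
        print(a[i].unsqueeze(-1).expand_as(b))  # the implicit broadcast, made explicit
        print(b)
        print(c[i])
    return c

m1 = tensor([[1., 1., 1.],
             [2., 2., 2.]])
matmul(m1, m1.t())
```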
gives:
```
tensor([[1., 1., 1.],
        [2., 2., 2.]]) a
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]]) b

a[0].unsqueeze(-1)
tensor([[1.],
        [1.],
        [1.]])

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]])
tensor([3., 6.])

a[1].unsqueeze(-1)
tensor([[2.],
        [2.],
        [2.]])

tensor([[2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]])
tensor([ 6., 12.])
Out[176]:
tensor([[ 3.,  6.],
        [ 6., 12.]])
```
So you can easily see how `a[i]` is expanded into a matrix, multiplied element-wise with `b`, and then summed over dim 0 to give `c[i]` (the last three tensor printouts of each loop iteration).
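The same broadcasting idea can be pushed one step further to drop the loop over rows entirely. This is a minimal sketch, not the version from the post; the name `matmul_broadcast` is hypothetical. Adding an axis to both operands makes `a[:, :, None]` (shape `(ar, ac, 1)`) and `b[None]` (shape `(1, ac, bc)`) broadcast to `(ar, ac, bc)`, and summing over the shared `ac` axis gives the product. It also checks that implicit broadcasting matches the explicit `expand_as` used in the printouts.

```python
import torch

def matmul_broadcast(a, b):
    """Hypothetical loop-free variant of the row-wise broadcast matmul."""
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    # (ar, ac, 1) * (1, ac, bc) -> (ar, ac, bc); summing over ac gives (ar, bc)
    return (a[:, :, None] * b[None]).sum(dim=1)

m1 = torch.tensor([[1., 1., 1.],
                   [2., 2., 2.]])
b = m1.t()

# Implicit broadcasting gives the same product as the explicit expand_as
row = m1[0].unsqueeze(-1)                  # shape (3, 1)
assert torch.equal(row * b, row.expand_as(b) * b)

result = matmul_broadcast(m1, b)
print(result)                              # values [[3., 6.], [6., 12.]]
assert torch.allclose(result, m1 @ b)      # agrees with PyTorch's built-in matmul
```

The trade-off is memory: the intermediate `(ar, ac, bc)` tensor is materialized before the sum, so this is for understanding, not for large matrices.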
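To go further and understand `c[None] > c[:,None]`, the same expand-and-print trick works (here `c` is a fresh 1-D tensor, not the matmul result):

```python
c = tensor([10., 20., 30.])
x = c[None,:] * c[:,None]   # (1,3) * (3,1) -> (3,3); only used for its shape
c[None].expand_as(x)        # c as a row, repeated down
c[:,None].expand_as(x)      # c as a column, repeated across
c[None] > c[:,None]         # entry [i,j] is c[j] > c[i]
```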
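gives these easy-to-understand results (the `uint8` mask comes from an older PyTorch; 1.2 and later return a `torch.bool` mask that prints as `True`/`False`):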
```
tensor([[10., 20., 30.],
        [10., 20., 30.],
        [10., 20., 30.]])
tensor([[10., 10., 10.],
        [20., 20., 20.],
        [30., 30., 30.]])
tensor([[0, 1, 1],
        [0, 0, 1],
        [0, 0, 0]], dtype=torch.uint8)
```
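A small sketch that checks the reading of that mask: `c[None]` has shape `(1, 3)` and `c[:,None]` has shape `(3, 1)`, so the comparison broadcasts to `(3, 3)` and entry `[i, j]` is `c[j] > c[i]`. The double loop is only there for comparison and is not from the original post; it assumes a PyTorch recent enough to return `torch.bool` masks.

```python
import torch

c = torch.tensor([10., 20., 30.])
mask = c[None] > c[:, None]          # (1,3) vs (3,1) -> (3,3)
print(c[None].shape, c[:, None].shape, mask.shape)

# Same result with an explicit double loop: entry [i, j] is c[j] > c[i]
loop_mask = torch.tensor([[bool(c[j] > c[i]) for j in range(len(c))]
                          for i in range(len(c))])
assert torch.equal(mask, loop_mask)

# For a strictly increasing c, the mask is exactly the strict upper triangle
assert torch.equal(mask, torch.ones(3, 3).triu(diagonal=1).bool())
```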
(I have Jupyter configured to print all outputs in a cell, not just the last one. Handy!)
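The post doesn't say exactly how that was set up; one common way (an assumption here) is IPython's `ast_node_interactivity` option, run once in a notebook cell. The equivalent line `c.InteractiveShell.ast_node_interactivity = "all"` can go in an IPython config file to make it permanent.

```python
# Display the value of every top-level expression in a cell, not just the last one
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
```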