Lesson 8 (2019) discussion & wiki

I found a good ol’ stage-by-stage print helpful for better understanding the broadcasting matmul:

import torch
from torch import tensor

def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    c = torch.zeros(ar, bc)
    print(a, "a")
    print(b, "b")
    for i in range(ar):
#       c[i,j] = (a[i,:] * b[:,j]).sum()           # previous, one element at a time
        c[i]   = (a[i].unsqueeze(-1) * b).sum(dim=0)  # whole row at once via broadcasting
        print(f"\na[{i}].unsqueeze(-1)")
        print(a[i].unsqueeze(-1), "\n")
        print(a[i].unsqueeze(-1).expand_as(b))     # what broadcasting expands the row to
        print(b)
        print(c[i])
    return c

m1 = tensor([[1., 1., 1.],
             [2., 2., 2.]])
matmul(m1, m1.t())

gives:

tensor([[1., 1., 1.],
        [2., 2., 2.]]) a
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]]) b

a[0].unsqueeze(-1)
tensor([[1.],
        [1.],
        [1.]]) 

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]])
tensor([3., 6.])

a[1].unsqueeze(-1)
tensor([[2.],
        [2.],
        [2.]]) 

tensor([[2., 2.],
        [2., 2.],
        [2., 2.]])
tensor([[1., 2.],
        [1., 2.],
        [1., 2.]])
tensor([ 6., 12.])
Out[176]:
tensor([[ 3.,  6.],
        [ 6., 12.]])

So you can easily see how a[i] gets broadcast into a matrix, and then how the element-wise multiplication followed by the sum leads to c[i] (the last three tensor printouts of each loop iteration).
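As a quick sanity check (my own addition, assuming the matmul and m1 defined above are still in scope), the broadcasting loop should agree with PyTorch’s built-in @ operator:

assert torch.allclose(matmul(m1, m1.t()), m1 @ m1.t())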

And further, to understand c[None] > c[:,None]:

c = tensor([10., 20., 30.])
x = c[None,:] * c[:,None]
c[None].expand_as(x)
c[:,None].expand_as(x)
c[None] > c[:,None]

gives these easy-to-understand outputs:

tensor([[10., 20., 30.],
        [10., 20., 30.],
        [10., 20., 30.]])
tensor([[10., 10., 10.],
        [20., 20., 20.],
        [30., 30., 30.]])
tensor([[0, 1, 1],
        [0, 0, 1],
        [0, 0, 0]], dtype=torch.uint8)
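
By the way, torch.broadcast_tensors does the same expansion in one call, so you can inspect both broadcast operands without spelling out expand_as yourself (a small extra sketch, not from the lesson):

rows, cols = torch.broadcast_tensors(c[None], c[:, None])
print(rows)         # every row is [10., 20., 30.]
print(cols)         # every column is [10., 20., 30.]
print(rows > cols)  # the same comparison mask as above

(On newer PyTorch releases the comparison comes back as dtype=torch.bool rather than uint8.)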

(I have Jupyter configured to print all outputs, not just the last one - handy!)
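In case anyone wants the same setup, this is the setting I mean - it goes in a notebook cell or in your IPython config:

from IPython.core.interactiveshell import InteractiveShell
# make Jupyter display every expression's value in a cell, not just the last one
InteractiveShell.ast_node_interactivity = "all"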
