The comments apply to the following lines of code:

convx = self.conv(x)                                 # (C,C) * (C,N) = (C,N) => O(NC^2)
xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous())  # (C,N) * (N,C) = (C,C) => O(NC^2)
o = torch.bmm(xxT, convx)                            # (C,C) * (C,N) = (C,N) => O(NC^2)

Originally we were doing the operations in this order (note that conv(x) is analogous to a matrix multiplication W*x in this case, where W has shape (C,C)):

x * (x^T * (conv(x)))

- conv(x) (dims: (C,C) and (C,N), result (C,N))
- x^T * (conv(x)) (dims: (N,C) and (C,N), result (N,N))
- x * (x^T * (conv(x))) (dims: (C,N) and (N,N), result (C,N))
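To make the shapes concrete, here is a minimal numpy sketch of the naive order. It assumes a random (C,C) weight `W` standing in for the conv, drops the batch dimension, and uses made-up sizes; the point is that the middle step materializes an (N,N) intermediate:

```python
import numpy as np

C, N = 8, 100                  # channels, spatial positions (N = H*W); arbitrary sizes
W = np.random.randn(C, C)      # stand-in for the conv's (C,C) weight
x = np.random.randn(C, N)

convx = W @ x                  # (C,C) @ (C,N) -> (C,N)
t = x.T @ convx                # (N,C) @ (C,N) -> (N,N)  <- the big intermediate
o = x @ t                      # (C,N) @ (N,N) -> (C,N)

assert t.shape == (N, N) and o.shape == (C, N)
```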

This is the naive/“natural” order of implementing those operations.

Check out the complexity of matrix multiplication: https://en.wikipedia.org/wiki/Computational_complexity_of_mathematical_operations#Matrix_algebra

Complexity of those 3 operations:

- O(C^2 * N)
- O(N^2 * C)
- O(C * N^2)

Now, unless we increase the number of channels a lot, the main problem is the terms proportional to N^2, because N = H*W. If you double both the height and the width of the image, N grows by a factor of 4, so the N^2 terms grow by a factor of 2^4 = 16.

By changing the order of operations to (x * x^T) * (W * x), we do:

- convx = conv(x)
- xxT = x * x^T
- o = xxT * convx

And, as commented in the code at the top of this post, each of those three operations is O(NC^2), so the run time grows only linearly with the number of pixels instead of quadratically.
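Since matrix multiplication is associative, both orders produce the same result; only the cost differs. A quick numpy check of that claim, again with a random (C,C) stand-in for the conv weight and no batch dimension:

```python
import numpy as np

rng = np.random.default_rng(0)
C, N = 8, 50                       # arbitrary illustrative sizes
W = rng.standard_normal((C, C))    # stand-in for the (C,C) conv weight
x = rng.standard_normal((C, N))

naive = x @ (x.T @ (W @ x))        # builds an (N,N) intermediate
fast = (x @ x.T) @ (W @ x)         # only (C,C) and (C,N) intermediates

assert np.allclose(naive, fast)    # associativity: same result, different cost
```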

Let me know if you have any other questions.