The comment applies to the following lines of code (x has shape (B, C, N) with N = H*W; the shape comments omit the batch dimension B):
convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2)
xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2)
o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2)
Originally, we were doing the operations in this order (note that conv(x) is analogous to a matrix multiplication W*x in this case, where W has dimensions (C,C)):
x * (x^T * (conv(x)))
- conv(x) (dims: (C,C) and (C,N))
- x^T * (conv(x)) (dims: (N,C) and (C,N))
- x * (x^T * (conv(x))) (dims: (C,N) and (N,N))
This is the naive ("natural") order in which to implement these operations.
For the cost of matrix multiplication, see the Wikipedia article "Computational complexity of mathematical operations": with the schoolbook algorithm, multiplying an (a,b) matrix by a (b,c) matrix costs O(a*b*c).
The complexities of these 3 operations are:
- O(C^2*N)
- O(C*N^2)
- O(C*N^2)
Unless we increase the number of channels a lot, the main problem is the terms proportional to N^2, because N = H*W. If you double the image size (both H and W), N grows by a factor of 4 and N^2 by a factor of 2^4 = 16.
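Using the schoolbook rule that an (a,b) by (b,c) product costs a*b*c multiply-adds, the naive order and its growth when both H and W are doubled work out as follows (the approximation assumes N is much larger than C):

```latex
\begin{aligned}
T_{\text{naive}}(C, N) &= \underbrace{C \cdot C \cdot N}_{\text{conv}(x)}
  + \underbrace{N \cdot C \cdot N}_{x^T \cdot \text{conv}(x)}
  + \underbrace{C \cdot N \cdot N}_{x \cdot (x^T \cdot \text{conv}(x))}
  = C^2 N + 2 C N^2, \\
N' &= (2H)(2W) = 4N, \qquad (N')^2 = 16 N^2 = 2^4 N^2, \\
T_{\text{naive}}(C, 4N) &= 4 C^2 N + 32 C N^2 \approx 16 \, T_{\text{naive}}(C, N)
  \quad \text{when } N \gg C.
\end{aligned}
```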
Because matrix multiplication is associative, x * (x^T * (W*x)) = (x*x^T) * (W*x), so we can change the order of operations to (xxT)(W*x):
- convx = conv(x)
- xxT = x*x^T
- o = xxT * convx
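A minimal pure-Python sketch of both orders, written without PyTorch so it runs anywhere. It assumes a single sample from the batch and small integer matrices chosen only for illustration; `W` stands in for the 1x1 conv weight, and `matmul`, `transpose` and `random_matrix` are hypothetical helpers, not library calls. It checks that both orders give the same output and counts the multiply-adds each one performs:

```python
import random


def matmul(a, b):
    """Schoolbook product of list-of-lists matrices; returns (product, multiply-add count)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    assert len(a[0]) == inner, "inner dimensions must match"
    out = [[0] * cols for _ in range(rows)]
    macs = 0
    for i in range(rows):
        for j in range(cols):
            s = 0
            for k in range(inner):
                s += a[i][k] * b[k][j]
                macs += 1
            out[i][j] = s
    return out, macs


def transpose(m):
    return [list(row) for row in zip(*m)]


def random_matrix(rows, cols, rng):
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]


C, N = 3, 8                     # small channel count and pixel count (N = H*W)
rng = random.Random(0)
x = random_matrix(C, N, rng)    # one sample: (C, N)
W = random_matrix(C, C, rng)    # stands in for the 1x1 conv weight: (C, C)

# Naive order: x * (x^T * (W x)) builds an (N, N) intermediate.
wx, c1 = matmul(W, x)                   # (C, C) * (C, N) -> (C, N)
xt_wx, c2 = matmul(transpose(x), wx)    # (N, C) * (C, N) -> (N, N)
naive, c3 = matmul(x, xt_wx)            # (C, N) * (N, N) -> (C, N)
naive_macs = c1 + c2 + c3

# Reordered: (x x^T) * (W x) only builds (C, C) and (C, N) matrices.
convx, d1 = matmul(W, x)                # (C, C) * (C, N) -> (C, N)
xxT, d2 = matmul(x, transpose(x))       # (C, N) * (N, C) -> (C, C)
reordered, d3 = matmul(xxT, convx)      # (C, C) * (C, N) -> (C, N)
reordered_macs = d1 + d2 + d3

print(naive == reordered)               # True (exact, since entries are integers)
print(naive_macs, reordered_macs)       # 456 216
```

The counts match C^2*N + 2*C*N^2 for the naive order and 3*C^2*N for the reordered one.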
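As the comments in the code at the top of this post show, each of these 3 operations is O(NC^2), so the total cost grows only linearly with N and no (N,N) intermediate is ever formed. Doubling both H and W now multiplies the cost by 4 instead of roughly 16, so run time is much less sensitive to image size.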
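A quick back-of-the-envelope comparison of the two orders, assuming C = 64 channels and a 64x64 feature map (example values, not taken from the original code), using the closed-form multiply-add counts of each ordering:

```python
def naive_macs(C, N):
    # x * (x^T * (W x)): C*C*N + N*C*N + C*N*N multiply-adds
    return C * C * N + 2 * C * N * N


def reordered_macs(C, N):
    # (x x^T) * (W x): three products of C*C*N multiply-adds each
    return 3 * C * C * N


C = 64                          # example channel count (assumption)
height, width = 64, 64          # example feature-map size (assumption)
N = height * width
N_doubled = (2 * height) * (2 * width)

naive_growth = naive_macs(C, N_doubled) / naive_macs(C, N)
reordered_growth = reordered_macs(C, N_doubled) / reordered_macs(C, N)
print(f"naive: x{naive_growth:.2f}, reordered: x{reordered_growth:.2f}")
# naive: x15.91, reordered: x4.00
```

The naive ratio approaches 16 as N grows relative to C, while the reordered ratio is exactly 4 because every term is linear in N.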
Let me know if you have any other questions.