I have a problem that I cannot figure out how to fix, and probably some of you have already had to work around it.
I have two very separate ‘paths’ in my neural network. One path analyzes an image, so it is a dilated convolutional network. The other path comes from a fully connected sequence (it could be an LSTM too). The problem is that whenever I merge the outputs of those paths to feed a prediction head, the network collapses. I have tried almost everything: multiplicative joining (for conditioning), concatenation, summation, and different dimensions (compression). Whenever I try any of these, the network collapses to class 0.
It is very weird, but it is very consistent. If I remove one of the paths, the network starts behaving correctly, but that defeats the purpose, which is to condition the later part of the network on the output of the second path.
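For concreteness, here is a minimal sketch of the three merge variants listed above (concatenation, summation, multiplicative joining). It assumes both paths have already been reduced to flat feature vectors of shape `(batch, features)`. The `Merge` class, its projection layers, and all dimensions are hypothetical and are not taken from the actual network.

```python
import torch
import torch.nn as nn


class Merge(nn.Module):
    """Hypothetical merge layer: project both paths to a common width, then join them."""

    def __init__(self, dim_img, dim_seq, dim_out, mode="concat"):
        super().__init__()
        self.mode = mode
        # Linear projections so that sum / product see tensors of equal width
        self.proj_img = nn.Linear(dim_img, dim_out)
        self.proj_seq = nn.Linear(dim_seq, dim_out)

    def forward(self, f_img, f_seq):
        a = self.proj_img(f_img)
        b = self.proj_seq(f_seq)
        if self.mode == "concat":
            return torch.cat((a, b), dim=-1)  # (batch, 2 * dim_out)
        if self.mode == "sum":
            return a + b                      # (batch, dim_out)
        if self.mode == "mul":
            return a * b                      # (batch, dim_out), multiplicative conditioning
        raise ValueError(f"unknown merge mode: {self.mode}")


# Dummy features standing in for the outputs of the two paths
f_img = torch.randn(4, 16)
f_seq = torch.randn(4, 32)
merged = {mode: Merge(16, 32, 8, mode)(f_img, f_seq) for mode in ("concat", "sum", "mul")}
for mode, out in merged.items():
    print(mode, tuple(out.shape))
```

Swapping `mode` while keeping everything else fixed is one way to check whether the collapse depends on the merge operation at all.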
Yes, that is exactly my case: I have a convolutional path and then an embedding. When I add the embedding, the network collapses. In your example, what do you think the problem was, and how did you end up fixing it? Does AdaCos have anything to do with the solution?
No, the AdaCos sub-network is just a special implementation add-on in this case, so it is not needed.
You can find the concatenation here (see the comments):
    import torch
    import torch.nn as nn

    class Net(nn.Module):  # class header was missing from the snippet; the name is a placeholder
        # Here just the sub-components are set up
        def __init__(self, body1, body2, head):
            super().__init__()  # required so nn.Module registers the sub-modules
            self.body1 = body1
            self.body2 = body2
            self.head = head
            self.adacos = AdaCos(512, 1108)  # AdaCos is a custom module defined elsewhere

        def forward(self, xb, yb=None):
            xb_img, xb_ctint, xb_pgint, xb_expint = xb
            # first get the 3 different features:
            img_feats1 = self.body1(xb_img[:, :6, ...])
            img_feats2 = self.body1(xb_img[:, 6:, ...])
            int_feats = self.body2(xb_ctint, xb_pgint, xb_expint)
            # concatenation happens then here:
            feats = torch.cat((img_feats1, img_feats2, int_feats), dim=-1)
            # after concat the new feature is run through the network head:
            out = self.head(feats)
            out = self.adacos(out, yb)
            return out
It is best to create instances of your feature extractors and test them on their own before you put them into a new network class. That makes debugging much easier. Pay attention to the tensor dimensions (e.g., which dim you concatenate along).
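A minimal sketch of such a standalone test might look like the following. The two small `nn.Sequential` models, the input sizes, and the `check_branch` helper are hypothetical stand-ins for illustration only; swap in your real feature extractors and a real batch.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the two feature extractors; replace with the real ones.
body1 = nn.Sequential(
    nn.Conv2d(6, 16, kernel_size=3, padding=2, dilation=2),  # dilated conv
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 16)
)
body2 = nn.Sequential(nn.Linear(10, 32), nn.ReLU())  # -> (batch, 32)


def check_branch(name, module, x):
    """Run one branch alone and report its output shape and basic statistics."""
    module.eval()
    with torch.no_grad():
        out = module(x)
    assert torch.isfinite(out).all(), f"{name} produced NaN/Inf"
    print(f"{name}: shape={tuple(out.shape)} "
          f"mean={out.mean().item():.3f} std={out.std().item():.3f}")
    return out


img = torch.randn(4, 6, 64, 64)  # dummy image batch with 6 channels
vec = torch.randn(4, 10)         # dummy input for the second path

f1 = check_branch("body1", body1, img)
f2 = check_branch("body2", body2, vec)

# Both outputs must be 2-D with the same batch size before concatenating on the last dim
assert f1.dim() == 2 and f2.dim() == 2 and f1.shape[0] == f2.shape[0]
feats = torch.cat((f1, f2), dim=-1)
print("merged:", tuple(feats.shape))
```

If the printed statistics of the two branches differ by orders of magnitude, that could be worth looking into before blaming the merge itself, although this is only a guess and not a confirmed cause of the collapse.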