That simple solution helped a lot. Thanks!
All this got me quite far… to the point where no errors are raised anymore, yet the network doesn't train (even on a small dataset).
Here's my progress since @florobax's last hint:
- After setting `learn.model = model`, I got an error along the lines of "x.orig doesn't exist… in Merge layer". So, after a glimpse inside the `SequentialEx` class (discussed here too), I manually set that `orig` attribute on the input of each layer I pass through.
- The loss function correctly accepts both types of predictions and targets (segmentation + classification). I've checked that the images and masks match correctly, and that the loss calculation is correct.
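For anyone hitting the same `x.orig` error, here is a minimal sketch of the mechanism in plain PyTorch. `SeqWithOrig` and `AddOrig` are hypothetical names, loosely modeled on fastai v1's `SequentialEx` and `MergeLayer`: the container attaches the block input to each intermediate tensor as `.orig`, and the merge layer reads it back, so such a layer only works while its container has set that attribute.

```python
import torch
import torch.nn as nn

class SeqWithOrig(nn.Module):
    "Hypothetical stand-in for a SequentialEx-style container: exposes the block input as `.orig`."
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for layer in self.layers:
            res.orig = x      # every layer can reach the original block input
            nres = layer(res)
            res.orig = None   # drop the reference so the input can be freed
            res = nres
        return res

class AddOrig(nn.Module):
    "Hypothetical merge layer: adds the block input back (a residual connection)."
    def forward(self, x):
        return x + x.orig     # fails with AttributeError if no container set `.orig`

block = SeqWithOrig(nn.Conv2d(3, 3, 3, padding=1), nn.ReLU(), AddOrig())
out = block(torch.randn(2, 3, 8, 8))
print(out.shape)
```

Calling `AddOrig()` directly on a plain tensor raises `AttributeError`, the same kind of failure as a merge layer running outside the container that normally sets `.orig`.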
But…
- Since the loss values don't go down when I try to train, I've tried monitoring the mean and std of the layer activations along the way, as demonstrated in this fastai course notebook.
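A minimal sketch of that kind of monitoring with PyTorch forward hooks, assuming a plain `nn.Module`; `register_stat_hooks` is a hypothetical helper, not a fastai API. It records the mean and std of every leaf layer's output for each batch, roughly in the spirit of the hooks used in that notebook:

```python
import torch
import torch.nn as nn

def register_stat_hooks(model):
    "Hypothetical helper: record mean/std of every leaf module's output on each forward pass."
    stats, handles = {}, []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # only hook leaf layers (conv, batchnorm, relu, ...)
        stats[name] = {"means": [], "stds": []}
        def hook(mod, inp, out, name=name):  # default arg binds the current name
            out = out.detach().float()
            stats[name]["means"].append(out.mean().item())
            stats[name]["stds"].append(out.std().item())
        handles.append(module.register_forward_hook(hook))
    return stats, handles

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
stats, handles = register_stat_hooks(model)
for _ in range(3):                     # three "batches"
    model(torch.randn(4, 3, 16, 16))
for h in handles:
    h.remove()                         # detach the hooks when done
print(stats["1"]["means"], stats["1"]["stds"])  # BatchNorm output, one value per batch
```

Removing the handles afterwards keeps the hooks from slowing down (or leaking memory into) later training runs.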
What I found was quite interesting: the means and stds of the layer activations were fairly regular (though somewhat static, see left image below), but the batchnorm layer went crazy, with a value of 3.0 in the 1st batch and 0 in the 2nd and 3rd batches, cycling like that (see the blue line in the right image below). Is anyone familiar with such behavior?
All this raises quite a new question, one that probably doesn't belong under this topic. But there is a chance it is directly involved in the "branching out from the middle of a unet…" effort, so I'll relocate this reply depending on the answer I find.
Anyway, here's what the updated model looks like:
a. Architecture as printed by `learn.model`:
```
MultiTaskUnetResnet(
  (unet_resnet34): ... <identical to the original unet_res34 model>
  (clf_layer): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=(1, 2))
    (mp): AdaptiveMaxPool2d(output_size=(1, 2))
  )
  (fc): Linear(in_features=2048, out_features=3, bias=True)
)
```
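As a sanity check on the `in_features=2048` above, here is a sketch of the shape arithmetic, assuming the resnet34 encoder outputs 512 channels. `ConcatPool` is a hypothetical plain-PyTorch stand-in for fastai's `AdaptiveConcatPool2d` (max- and average-pooling concatenated along the channel axis), so each pooling gives 512 × (1·2) values and the two together give 2048:

```python
import torch
import torch.nn as nn

class ConcatPool(nn.Module):
    "Hypothetical stand-in for AdaptiveConcatPool2d: concat of adaptive max- and avg-pooling."
    def __init__(self, size=(1, 2)):
        super().__init__()
        self.mp = nn.AdaptiveMaxPool2d(size)
        self.ap = nn.AdaptiveAvgPool2d(size)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], dim=1)  # channels: 512 + 512

# Assumption: encoder output of 512 channels; the 8x8 spatial size is arbitrary here.
enc_out = torch.randn(4, 512, 8, 8)
pooled = ConcatPool((1, 2))(enc_out)        # (4, 1024, 1, 2)
flat = pooled.view(pooled.shape[0], -1)     # (4, 2048) = 2 * 512 * 1 * 2
logits = nn.Linear(2 * 1024, 3)(flat)       # (4, 3): one score per class
print(pooled.shape, flat.shape, logits.shape)
```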
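b. My code for the architecture:
```python
import torch.nn as nn
from fastai.layers import AdaptiveConcatPool2d  # fastai v1

class MultiTaskUnetResnet(nn.Module):
    "Pretrained U-Net (resnet34) with an extra classification head on the encoder output."
    def __init__(self, pretrained_unet_resnet34):
        super(MultiTaskUnetResnet, self).__init__()
        self.unet_resnet34 = pretrained_unet_resnet34
        # Concatenated max- and avg-pooling, each down to a 1x2 grid
        self.clf_layer = AdaptiveConcatPool2d((1, 2))
        # 2 poolings x 512 encoder channels x (1*2) cells = 2048 features -> 3 classes
        self.fc = nn.Linear(2*1024, 3)

    def forward(self, x):
        # Mimic SequentialEx: expose the original input as `.orig` while each layer runs
        res = x
        res.orig = x
        nres = self.unet_resnet34[0](res)  # encoder
        res.orig = None  # drop the reference to avoid memory leaks
        res = nres
        # Classification branch off the encoder output
        res_clf = res.clone()
        res_clf = self.clf_layer(res_clf)
        res_clf = self.fc(res_clf.view(res_clf.shape[0], -1))
        # Remaining U-Net layers produce the segmentation output
        for l in self.unet_resnet34[1:]:
            res.orig = x
            nres = l(res)
            res.orig = None
            res = nres
        return res, res_clf

learn.model = MultiTaskUnetResnet(learner.model).cuda()
```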
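Since `forward` returns a `(segmentation, classification)` tuple, the loss has to unpack it. Here is a minimal sketch of such a combined loss, assuming the segmentation target is a mask of class indices, the classification target is one class index per image, and the loss is called as `loss(preds, seg_target, clf_target)`; `MultiTaskLoss` and its `clf_weight` are hypothetical. The usage lines also check that gradients reach both outputs, a cheap test when a network refuses to train:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskLoss(nn.Module):
    "Hypothetical combined loss: pixel-wise cross-entropy for the mask + cross-entropy for the label."
    def __init__(self, clf_weight=1.0):
        super().__init__()
        self.clf_weight = clf_weight  # assumed weighting between the two tasks

    def forward(self, preds, seg_target, clf_target):
        seg_logits, clf_logits = preds  # matches `return res, res_clf`
        if seg_target.dim() == seg_logits.dim():
            seg_target = seg_target.squeeze(1)  # (N, 1, H, W) -> (N, H, W)
        seg_loss = F.cross_entropy(seg_logits, seg_target)
        clf_loss = F.cross_entropy(clf_logits, clf_target)
        return seg_loss + self.clf_weight * clf_loss

# Assumed shapes: 4 mask classes, 3 image classes, 16x16 masks.
seg_logits = torch.randn(2, 4, 16, 16, requires_grad=True)
clf_logits = torch.randn(2, 3, requires_grad=True)
seg_target = torch.randint(0, 4, (2, 16, 16))
clf_target = torch.randint(0, 3, (2,))

loss = MultiTaskLoss()((seg_logits, clf_logits), seg_target, clf_target)
loss.backward()
print(loss.item(), seg_logits.grad is not None, clf_logits.grad is not None)
```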