Thanks for the advice. I have made some progress with regard to custom head modifications.
I simply added a sigmoid to the existing head. I first determined which layers the head contained by inspecting:
learn.model[1]
So my head for a densenet121 model with sigmoid added looks like:
head = nn.Sequential(
    AdaptiveConcatPool2d(),
    Flatten(),
    nn.BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    nn.Dropout(p=0.25),
    nn.Linear(in_features=2048, out_features=512, bias=True),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=512, out_features=1, bias=True),
    nn.Sigmoid()  # my addition: bounds the output to (0, 1)
)
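A head like this can be passed in when building the learner; a minimal sketch, assuming fastai v1's cnn_learner and its custom_head argument (data standing in for an ImageDataBunch):

from fastai.vision import cnn_learner, models

learn = cnn_learner(data, models.densenet121, custom_head=head)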
The model trains much more quickly with this addition and no longer predicts values below zero, though the final performance isn't much better than before.
I did find a Kaggle example where someone implemented the scaled sigmoid in the forward method: https://www.kaggle.com/rasmus01610/nih-chest-x-ray-age
From their notebook:
class AgeModel(nn.Module):
    def __init__(self):
        super().__init__()
        layers = list(models.resnet34().children())[:-2]
        layers += [AdaptiveConcatPool2d(), Flatten()]
        layers += [nn.Linear(1024, 16), nn.ReLU(), nn.Linear(16, 1)]
        self.agemodel = nn.Sequential(*layers)

    def forward(self, x):
        x = self.agemodel(x).squeeze()
        return torch.sigmoid(x) * (max_age - min_age) + min_age
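The last line is the interesting part: the sigmoid squashes the raw output into (0, 1), and the scaling stretches that to [min_age, max_age], so predictions can never leave the target range. A quick sanity check with made-up bounds:

import torch

min_age, max_age = 0.0, 100.0  # made-up bounds, just for illustration

raw = torch.tensor([-10.0, 0.0, 10.0])  # unbounded network outputs
print(torch.sigmoid(raw) * (max_age - min_age) + min_age)
# tensor([4.5398e-03, 5.0000e+01, 9.9995e+01]) -- always inside [0, 100]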
I suspect that

layers = list(models.resnet34().children())[:-2]

removes the last two layers of the resnet (the average pooling and fully connected layers that make up its stock head), and then they add a few of their own in their place. I imagine those additions are specific to their application, so I tried to reproduce the idea without removing any layers, simply keeping the densenet layers as-is and modifying the forward as follows:
class CoverModel(nn.Module):
    def __init__(self):
        super().__init__()
        layers = list(models.densenet121().children())
        self.covermodel = nn.Sequential(*layers)

    def forward(self, x):
        x = self.covermodel(x).squeeze()
        return torch.sigmoid(x) * (max_cover - min_cover) + min_cover
Setting the arch to CoverModel() throws an error, however, so there's still some work to be done.
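My guess at the cause: densenet121's children are just [features, classifier], and the stock forward applies a ReLU, pooling, and flattening between them, so stacking the children in a bare nn.Sequential hands the Linear classifier a 4-D feature map. If that's right, a working version would need to drop the classifier and supply its own pooling and head, something like this sketch (the 2048 comes from AdaptiveConcatPool2d doubling densenet121's 1024 feature channels; min_cover and max_cover are my own bounds):

import torch
import torch.nn as nn
from torchvision import models
from fastai.layers import AdaptiveConcatPool2d, Flatten

min_cover, max_cover = 0.0, 1.0  # placeholder bounds for my targets

class CoverModel(nn.Module):
    def __init__(self):
        super().__init__()
        # drop the stock Linear classifier, keeping only the conv features
        layers = list(models.densenet121().children())[:-1]
        # pool and flatten ourselves: AdaptiveConcatPool2d concatenates avg
        # and max pooling, turning 1024 channels into 2048 features
        layers += [AdaptiveConcatPool2d(), Flatten(), nn.Linear(2048, 1)]
        self.covermodel = nn.Sequential(*layers)

    def forward(self, x):
        x = self.covermodel(x).squeeze()
        return torch.sigmoid(x) * (max_cover - min_cover) + min_cover

Thanks for your help.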