I am trying to figure out the best way to use fastai's model-splitting functionality. I have a model that merges image data, tabular data and variable-length text data. I use a pre-trained CNN for the images and a pre-trained language model for the text. I want to split my model into layer groups so that I can use freeze and discriminative learning rates…
Here is my model:
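from fastai.vision import *   # assumes fastai v1
from fastai.tabular import *
from fastai.text import *

class ImageTabularModel(nn.Module):
    "Model merging image, tabular and text data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, layers:Collection[int], vocab_sz:int, encoder):
        super().__init__()
        # Image branch: pre-trained resnet34 body
        self.cnn = create_body(models.resnet34)
        # NOTE: this rebinds the `layers` argument, so TabularModel below also receives [1200, 512]
        layers = [400 * 3] + [512]
        ps = [.4]
        # Text branch: pre-trained language-model encoder followed by a pooling head
        self.lm_encoder = SequentialRNN(encoder[0], PoolingLinearClassifier(layers, ps))
        # Tabular branch with a 512-dimensional output
        self.tab = TabularModel(emb_szs, n_cont, 512, layers)
        # Flatten the CNN features to 512 dims (512*7*7 assumes 224x224 input images)
        self.reduce = nn.Sequential(*([Flatten()] + bn_drop_lin((512*7*7), 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True))))
        # Merge the three 512-dimensional latents, then classify into 2 outputs
        self.merge = nn.Sequential(*bn_drop_lin(512 + 512 + 512, 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True)))
        self.final = nn.Sequential(*bn_drop_lin(512, 2, bn=True, p=0., actn=nn.ReLU(inplace=True)))

    def forward(self, img:Tensor, x:Tensor, text:Tensor) -> Tensor:
        imgLatent = self.reduce(self.cnn(img))
        tabLatent = self.tab(x[0], x[1])   # x[0]: categorical inputs, x[1]: continuous inputs
        textLatent = self.lm_encoder(text)[0]
        cat = torch.cat([imgLatent, tabLatent, textLatent], dim=1)
        return self.final(self.merge(cat))

    def reset(self):
        # Reset the hidden state of any child that supports it (the RNN encoder)
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()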
And here is my code to split this into multiple layer groups:
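def split_layers(model:nn.Module) -> List[nn.Module]:
    # Group 1: the pre-trained parts (CNN body and language-model encoder)
    groups = [[model.cnn, model.lm_encoder]]
    # Group 2: the tabular model and the newly added layers
    groups += [[model.tab, model.reduce, model.merge, model.final]]
    return groups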
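One thing I want to be sure of is that every child module lands in exactly one group. As far as I understand, fastai v1 builds the optimizer's parameter groups from the layer groups, so a module left out of every group would not be trained. Here is a minimal, stdlib-only sketch of that coverage check. It works on attribute names rather than real modules, and `check_split` is a hypothetical helper of mine, not a fastai function.

```python
from typing import Dict, List

def check_split(children: List[str], groups: List[List[str]]) -> Dict[str, List[str]]:
    """Report children that are missing from, or repeated across, the layer groups."""
    flat = [name for group in groups for name in group]
    missing = [c for c in children if c not in flat]
    duplicated = sorted({name for name in flat if flat.count(name) > 1})
    return {"missing": missing, "duplicated": duplicated}

# Attribute names of the model's direct children, in definition order
children = ["cnn", "lm_encoder", "tab", "reduce", "merge", "final"]
# The same grouping that split_layers returns, written with names
groups = [["cnn", "lm_encoder"], ["tab", "reduce", "merge", "final"]]

report = check_split(children, groups)
print(report)
```

On the real model, the same idea could be applied to the names yielded by `model.named_children()`.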
I pass this split function to my learner. I tried to follow the same logic as the split examples in the fastai code, but I am not sure I am doing it right.
From what I understand, Learner.freeze freezes all layer groups except the last one. I just want to make sure I take full advantage of this.
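To make my understanding concrete, here is a small stdlib-only sketch of the two behaviours I am relying on. It is not fastai's code: `trainable_after_freeze_to` and `spread_lrs` are hypothetical stand-ins that assume freezing applies to groups `[:n]` (with `freeze` meaning `n = -1`) and that a learning rate given as `slice(start, stop)` is spread geometrically across the groups. I believe BatchNorm layers in frozen groups stay trainable by default (`train_bn=True`), which this sketch ignores.

```python
from typing import List

def trainable_after_freeze_to(n_groups: int, n: int) -> List[bool]:
    """Return True for each layer group that stays trainable when groups [:n] are frozen."""
    idx = list(range(n_groups))
    frozen = set(idx[:n])
    return [i not in frozen for i in idx]

def spread_lrs(start: float, stop: float, n_groups: int) -> List[float]:
    """Spread learning rates geometrically from start (first group) to stop (last group)."""
    if n_groups == 1:
        return [stop]
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step ** i for i in range(n_groups)]

# With my two groups, freezing leaves only the second group trainable
two_groups = trainable_after_freeze_to(2, -1)
# With three groups, only the last one would stay trainable
three_groups = trainable_after_freeze_to(3, -1)
# Discriminative learning rates for a three-group split
lrs = spread_lrs(1e-5, 1e-3, 3)
print(two_groups, three_groups, lrs)
```

If this is right, then with my current split, freeze trains only the tabular model and the new layers, while the CNN body and the language-model encoder always share one learning rate. Splitting the pre-trained parts into separate groups would let them be unfrozen and tuned independently.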
Thanks!