Your reply gave me the idea to use language_model_learner instead, passing it the DataBunch for the language model. But my code is not pretty… Looks like Frankenstein…
First I load the DataBunch I created earlier for the language model:
data_lm = TextLMDataBunch.load(path, 'tmp_lm', bs=bs)
vocab = data_lm.vocab
Then I create a new language model learner, and when creating my own model (this model merges image data, tabular data and text data, but that is a post for another time), I simply pass it learn.model:
learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.3)
model = CustomModel(learn.model)
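For reference, printing the learner's model shows where that index comes from (just a sanity check, nothing the custom model needs):

print(learn.model)    # SequentialRNN(AWD_LSTM(...), LinearDecoder(...))
enc = learn.model[0]  # the AWD_LSTM text encoder -- this is what I call encoder[0] below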
Then inside CustomModel, I extract the bits I care about from this model and adapt them for my purposes. encoder[0] contains the text encoder. Then I use PoolingLinearClassifier to bring everything down to a size of 512, so that I can concatenate it with the other types of information (image and tabular) later on:
class CustomModel(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        # 400 is the output size of the AWD_LSTM encoder; PoolingLinearClassifier
        # concatenates the last, max-pooled and mean-pooled outputs, hence 400 * 3
        layers = [400 * 3] + [512]
        ps = [.4]
        self.lm_encoder = SequentialRNN(encoder[0], PoolingLinearClassifier(layers, ps))
Here is a more complete model merging image, tabular and text… But it doesn't take long before the loss becomes NaN… I think this might have to do with the RNNTrainer callback that RNNLearner uses… But here is my complete model for anyone interested:
class ImageTabularModel(nn.Module):
    "Model merging image, tabular and text data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, layers:Collection[int], vocab_sz:int, encoder):
        super().__init__()
        self.cnn = create_body(models.resnet34)
        self.tab = TabularModel(emb_szs, n_cont, 512, layers)

        # text branch: AWD_LSTM encoder + pooling classifier down to 512
        lm_layers = [400 * 3] + [512]
        lm_ps = [.4]
        self.lm_encoder = SequentialRNN(encoder[0], PoolingLinearClassifier(lm_layers, lm_ps))

        self.reduce = nn.Sequential(*([Flatten()] + bn_drop_lin((512*7*7), 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True))))
        self.merge = nn.Sequential(*bn_drop_lin(512 + 512 + 512, 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True)))
        self.final = nn.Sequential(*bn_drop_lin(512, 2, bn=True, p=0., actn=nn.ReLU(inplace=True)))

    def forward(self, img:Tensor, x, text:Tensor) -> Tensor:
        imgLatent = self.reduce(self.cnn(img))
        tabLatent = self.tab(x[0], x[1])       # x = (x_cat, x_cont) for the tabular part
        textLatent = self.lm_encoder(text)[0]  # PoolingLinearClassifier returns (out, raw_outputs, outputs)
        cat = torch.cat([imgLatent, tabLatent, textLatent], dim=1)
        return self.final(self.merge(cat))
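To chase the NaN, the first thing I would try is training this with a plain Learner (so RNNTrainer never gets attached) plus gradient clipping, and resetting the AWD_LSTM hidden state myself. This is a rough, untested sketch: data_merged, emb_szs and n_cont are placeholders for your own merged DataBunch and tabular sizes, and [512] is just an example for the tabular layers:

from functools import partial
from fastai.train import GradientClipping

model = ImageTabularModel(emb_szs, n_cont, [512], len(data_lm.vocab.itos), learn.model)
merged_learn = Learner(data_merged, model, loss_func=CrossEntropyFlat(),
                       callback_fns=[partial(GradientClipping, clip=0.1)])
merged_learn.model.lm_encoder.reset()  # RNNTrainer would normally reset the hidden state each epoch
merged_learn.fit_one_cycle(1, 1e-4)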