Pre-trained text encoder

If I just want to grab a pre-trained text-encoder to re-use in a different Module, what would be the best way of doing that?

For vision it is pretty straightforward, you just call and you have a pre-trained cnn you can use in your Module.

But for text, there’s get_text_classifier, which returns a model without the pre-trained weights loaded. There’s also text_classifier_learner, which uses get_text_classifier and loads the pre-trained weights, but it needs a DataBunch to know the vocab and the final layer size, and it returns a Learner, not a torch Module.

I would need something in between where I can just get a pre-trained language encoder I can re-use in a custom module.


Also, RNNLearner has some callbacks like RNNTrainer. How could I integrate those into a custom module that uses a pre-trained language encoder?

The wikitext103 vocabulary is actually stored at ~/.fastai/models/wt103-1/itos_wt103.pkl and is used by RNNLearner's load_pretrained method.
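To make the vocab point concrete, here is a minimal sketch of how an itos ("index to string") list can be used to line up a new vocabulary with the rows of a pre-trained embedding. It assumes the pickle holds a plain list of tokens; the helper names (load_itos, build_stoi, remap) and the toy file are made up for illustration and are not fastai functions. Only unpickle files you trust, since pickle can execute code on load.

```python
import pickle
import tempfile
from pathlib import Path

def load_itos(path):
    """Load an index-to-string vocabulary (assumed to be a list of tokens)."""
    with open(path, "rb") as f:
        return pickle.load(f)

def build_stoi(itos):
    """Invert the list into a token -> index lookup."""
    return {tok: idx for idx, tok in enumerate(itos)}

def remap(new_itos, pretrained_stoi, unk_idx=0):
    """For each token of the new vocab, find its row in the pre-trained
    embedding, falling back to the unknown-token row."""
    return [pretrained_stoi.get(tok, unk_idx) for tok in new_itos]

# Toy stand-in for itos_wt103.pkl, written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    toy_path = Path(tmp) / "itos_toy.pkl"
    with open(toy_path, "wb") as f:
        pickle.dump(["xxunk", "xxpad", "the", "cat", "sat"], f)
    pretrained_itos = load_itos(toy_path)

pretrained_stoi = build_stoi(pretrained_itos)
rows = remap(["the", "dog", "sat"], pretrained_stoi)
print(rows)  # [2, 0, 4]  ("dog" is unknown, so it maps to row 0)
```

With the real file you would point load_itos at the path above instead of the toy pickle.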

What happens if you pass that to load_learner to pretend your databunch came from the wikitext103 dataset?

The Learner has the torch nn.Module inside, at learn.model.


Your reply gave me the idea to use language_model_learner instead, passing it the DataBunch for the language model. But my code is not pretty… Looks like Frankenstein…

First I load the DataBunch from a DataBunch I created earlier for a language model:

data_lm = TextLMDataBunch.load(path, 'tmp_lm', bs=bs)
vocab = data_lm.vocab

Then I create a new language model learner, and when creating my own model (this model merges image data, tabular data and text data, but this is a post for another time), I simply pass it learn.model:

learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.3)
model = CustomModel(learn.model)

Then, inside of CustomModel, I extract the bits I care about from this model and adapt them for my purposes. encoder[0] contains the encoding part for the text. Then I use PoolingLinearClassifier to bring everything down to a size of 512, which I can later concatenate with the other types of information (image and tabular):

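from fastai.text import *  # nn, SequentialRNN, PoolingLinearClassifier

class CustomModel(nn.Module):
    def __init__(self, encoder):
        super().__init__()  # required before assigning sub-modules
        # 400 is the output size of the AWD_LSTM encoder; the pooling head
        # expects 3x that (last state, max pool and average pool concatenated)
        layers = [400 * 3] + [512]
        ps = [.4]
        # encoder[0] is the text encoder; the language model decoder is dropped
        self.lm_encoder = SequentialRNN(encoder[0], PoolingLinearClassifier(layers, ps))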
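For anyone wondering where the `400 * 3` comes from: as far as I understand it, PoolingLinearClassifier concatenates three views of the encoder output (the last time step, a max pool and an average pool over the sequence) before its linear layers. Here is a rough sketch of that pooling in plain PyTorch, assuming the encoder output has shape (batch, seq_len, 400). concat_pool is a name I made up, not the fastai implementation.

```python
import torch

def concat_pool(hidden: torch.Tensor) -> torch.Tensor:
    """Turn (batch, seq_len, features) into (batch, 3 * features)."""
    last = hidden[:, -1]           # output at the final time step
    max_pool = hidden.max(dim=1)[0]  # max over the sequence dimension
    avg_pool = hidden.mean(dim=1)    # mean over the sequence dimension
    return torch.cat([last, max_pool, avg_pool], dim=1)

# Fake encoder output: batch of 2 texts, 7 tokens each, 400 features per token.
hidden = torch.randn(2, 7, 400)
pooled = concat_pool(hidden)
print(tuple(pooled.shape))  # (2, 1200)
```

So the first entry of `layers` has to be three times the encoder output size, whatever the sequence length is.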

Here is a more complete model merging image, tabular and text data… But it doesn’t take long before the loss becomes NaN… I think this might have to do with the RNNTrainer that RNNLearner uses in its callbacks… But here is my complete model for anyone interested:

class ImageTabularModel(nn.Module):
    "Model merging image, tabular and text data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, layers:Collection[int], vocab_sz:int, encoder):
        super().__init__()
        self.cnn = create_body(models.resnet34)
        # attribute name was lost in the post; `tabular` is a placeholder
        self.tabular = TabularModel(emb_szs, n_cont, 512, layers)

        layers = [400 * 3] + [512]
        ps = [.4]
        self.lm_encoder = SequentialRNN(encoder[0], PoolingLinearClassifier(layers, ps))

        self.reduce = nn.Sequential(*([Flatten()] + bn_drop_lin((512*7*7), 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True))))
        self.merge = nn.Sequential(*bn_drop_lin(512 + 512 + 512, 512, bn=True, p=0.5, actn=nn.ReLU(inplace=True)))
        # attribute name was lost in the post; `final` is a placeholder
        self.final = nn.Sequential(*bn_drop_lin(512, 2, bn=True, p=0., actn=nn.ReLU(inplace=True)))

    def forward(self, img:Tensor, x:Tensor, text:Tensor) -> Tensor:
        imgLatent = self.reduce(self.cnn(img))
        tabLatent = self.tabular(x[0], x[1])
        textLatent = self.lm_encoder(text)[0]

        cat = torch.cat([imgLatent, tabLatent, textLatent], dim=1)
        # the end of forward was cut off in the post; assumed to chain the two remaining layers
        return self.final(self.merge(cat))


Note that if you want the model to read all your text, you should use text_classifier_learner to get an encoder that goes through the whole texts.

Hum indeed! Would be nice to have a simpler way than having to create a learner to then extract the interesting part of the model, the encoder.

I modified my code to use a different DataBunch that has a .c property, because TextLMDataBunch doesn’t have one. So for anyone interested, I used:

textList = TextList.from_df(pictures, cols='NameDefault', path=path, vocab=vocab)
data_text = textList.random_split_by_pct(.2).label_from_df(cols='Passed').databunch(bs=bs, num_workers=0) # used num_workers=0 because in windows, multiprocessing in jupyter is causing problems...

Also, for this to train at all, I had to use RNNTrainer with my learner… Be advised that I don’t really understand what RNNTrainer does… But the normal RNNTrainer used in RNNLearner seemed to cause problems in the loss calculation with my custom model… So for now I created a really simple replacement just to see if I could get this thing to train at all…

class RNNTrainerSimple(LearnerCallback):
    "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
    def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):
        super().__init__(learn)
        self.not_min += ['raw_out', 'out']
        self.alpha,self.beta = alpha,beta

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model."
        self.learn.model.reset()

Then I had to make a custom Learner class to use this RNNTrainerSimple class:

class ImageTabularTextLearner(Learner):
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        super().__init__(data, model, metrics=metrics, **learn_kwargs)  # forward metrics, otherwise they are silently dropped
        self.callbacks.append(RNNTrainerSimple(self, alpha=alpha, beta=beta))
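Note that the `clip` argument in the signature above is accepted but never used. Since the earlier version of this model ran into NaN losses, gradient clipping is one common thing to try (I am not claiming it is the cause here). Below is a minimal sketch of what clipping does, in plain PyTorch with a toy model and deliberately large inputs. I believe fastai has a GradientClipping callback that RNNLearner wires to `clip`, but check the source before relying on that.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and oversized inputs (hypothetical) to provoke large gradients.
model = nn.Linear(4, 1)
x = torch.randn(8, 4) * 1e3
y = torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

max_norm = 0.1
# Rescales all gradients in place so their global L2 norm is at most max_norm;
# returns the norm measured before clipping.
norm_before = float(torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm))

grads = [p.grad for p in model.parameters()]
norm_after = float(torch.sqrt(sum((g ** 2).sum() for g in grads)))

print(norm_before > max_norm, norm_after <= max_norm + 1e-6)
```

In a training loop the clipping call goes between loss.backward() and optimizer.step().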

Also I had to add a reset method in ImageTabularModel.

def reset(self):
    for c in self.children():
        if hasattr(c, 'reset'): c.reset()

This is training so far, the loss hasn’t been NaN yet…

Now my code is in dire need of a refactoring. I also need to work on a way to handle freeze/unfreeze, group layers to handle discriminative learning rates etc.
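As a starting point for that, here is a rough sketch of the two ideas in plain PyTorch: freezing a layer group by switching off requires_grad, and giving each group its own learning rate through optimizer parameter groups. The tiny encoder and head are stand-ins, and set_trainable is a helper I made up. fastai's Learner has its own machinery for this (split, freeze, unfreeze), which this does not replace.

```python
import torch
import torch.nn as nn

# Toy stand-ins (hypothetical): a "pre-trained" encoder and a fresh head.
encoder = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
head = nn.Linear(8, 2)
model = nn.Sequential(encoder, head)

def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter of a module."""
    for p in module.parameters():
        p.requires_grad_(trainable)

# One optimizer group per layer group, each with its own learning rate.
optimizer = torch.optim.SGD(
    [
        {"params": encoder.parameters(), "lr": 1e-4},  # small lr for pre-trained weights
        {"params": head.parameters(), "lr": 1e-2},     # larger lr for the new head
    ],
    lr=1e-3,  # default, overridden by both groups above
)

# Stage 1: train only the head.
set_trainable(encoder, False)
frozen = [p.requires_grad for p in encoder.parameters()]

# Stage 2: unfreeze everything; the per-group learning rates still apply.
set_trainable(encoder, True)
unfrozen = [p.requires_grad for p in encoder.parameters()]

group_lrs = [g["lr"] for g in optimizer.param_groups]
print(frozen, unfrozen, group_lrs)
```

For the merged model, the natural groups would be the CNN body, the text encoder, and the layers that come after the concatenation.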


Thanks, this is exactly what I am looking for! I am also coming from fastai vision, where I understand the encoder is the body and we can just attach whatever decoder (head) we need to the model.

Also I am trying to figure out how to load different encoders, such as BERT/GPT-2 into fastai. Any input is highly appreciated!


I made a new thread on this very topic and linked in a good article showing how to do it.



You just emptied my afternoon schedule. I will start reading and playing with it.

