Pre-trained text encoder

Hum, indeed! It would be nice to have a simpler way to get the pre-trained encoder than creating a learner and then extracting the interesting part of the model, the encoder.
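One way to skip the learner might be to pull the encoder weights straight out of a saved language-model state dict. This is a minimal sketch that assumes the fastai v1 layout, where the LM is a `SequentialRNN(encoder, decoder)`, so encoder keys start with `0.` and decoder keys with `1.`. A plain dict stands in for what `torch.load` would return, and the key names are illustrative only. If the file was saved together with the optimizer state, the weights may be nested under a `'model'` key first.

```python
def extract_encoder_state(lm_state, prefix="0."):
    """Keep only the entries under `prefix` (assumed: the encoder at index 0
    of the saved LM) and strip the prefix so keys match a bare encoder."""
    return {k[len(prefix):]: v for k, v in lm_state.items() if k.startswith(prefix)}

# Illustrative key names only; a real state dict maps them to tensors.
lm_state = {
    "0.encoder.weight": "emb",
    "0.rnns.0.weight_hh_l0_raw": "lstm0",
    "1.decoder.weight": "dec",
    "1.decoder.bias": "dec_b",
}
enc_state = extract_encoder_state(lm_state)
print(sorted(enc_state))  # ['encoder.weight', 'rnns.0.weight_hh_l0_raw']
```

With real tensors, the resulting dict could then be passed to the encoder module's `load_state_dict`. Matching on `"0."` including the dot avoids accidentally picking up keys such as `10.…`.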

I modified my code to use a different DataBunch, one that has a .c property (the number of classes the model has to output), because TextLMDataBunch doesn't have one. For anyone interested, this is what I used:

from fastai.text import *  # fastai v1

# vocab: presumably the language model's vocab, so token ids match the pre-trained encoder
textList = TextList.from_df(pictures, cols='NameDefault', path=path, vocab=vocab)
# num_workers=0 because on Windows, multiprocessing inside Jupyter was causing problems
data_text = (textList.random_split_by_pct(.2)
             .label_from_df(cols='Passed')
             .databunch(bs=bs, num_workers=0))
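Roughly, the split-and-label chain above shuffles the rows, holds out 20% for validation, and builds the class list from the label column; `.c` is then the length of that list. The plain-Python sketch below only illustrates that idea and is not fastai's code; `split_and_label` is a hypothetical helper and the toy rows are made up.

```python
import random

def split_and_label(rows, label_key, valid_pct=0.2, seed=42):
    """Toy version of random_split_by_pct(valid_pct).label_from_df(cols=label_key):
    shuffle the rows, hold out valid_pct of them, and derive the sorted class list
    from the training labels."""
    rng = random.Random(seed)
    idx = list(range(len(rows)))
    rng.shuffle(idx)
    cut = int(valid_pct * len(rows))
    valid = [rows[i] for i in idx[:cut]]
    train = [rows[i] for i in idx[cut:]]
    classes = sorted({r[label_key] for r in train})
    return train, valid, classes

rows = [{"NameDefault": f"item {i}", "Passed": i % 2 == 0} for i in range(10)]
train, valid, classes = split_and_label(rows, "Passed")
c = len(classes)  # what a classification DataBunch exposes as .c
print(len(train), len(valid), c)  # 8 2 2
```

A language-model DataBunch has no such class list (its targets are the next tokens), which fits with it not offering the .c a classifier head needs.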

Also, for this to train at all I needed an RNNTrainer-style callback on my learner. Be advised that I don't really understand what RNNTrainer does… The standard RNNTrainer used by RNNLearner seemed to cause problems in the loss calculation with my custom model, so for now I wrote a really simple replacement, just to see whether I could get this thing to train at all…

from fastai.basic_train import Learner, LearnerCallback  # fastai v1

class RNNTrainerSimple(LearnerCallback):
    "Stripped-down `RNNTrainer`: only resets the hidden state each epoch (no AR/TAR, so alpha and beta are unused)."
    def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):
        super().__init__(learn)
        self.not_min += ['raw_out', 'out']  # kept from RNNTrainer; harmless here
        self.alpha,self.beta = alpha,beta

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model."
        self.learn.model.reset()
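For comparison, as far as I can tell from the fastai v1 source, the full RNNTrainer does more than reset: before the loss it expects the model to return a tuple `(output, raw_outputs, outputs)` and keeps the last two, which a custom model returning a single tensor would not satisfy (possibly the loss problem mentioned above). It then adds two penalties: AR (scaled by `alpha`) on the dropped-out activations of the last LSTM layer, and TAR (scaled by `beta`) on the change between consecutive timesteps of the raw activations. A plain-Python sketch for one sequence, with each timestep's activations as a list of floats; `ar_tar_penalty` is a hypothetical name.

```python
def mean(xs):
    return sum(xs) / len(xs)

def ar_tar_penalty(dropped_out, raw_out, alpha=2.0, beta=1.0):
    """Sketch of the extra loss terms RNNTrainer adds for one sequence.
    dropped_out / raw_out: last-layer LSTM activations, one list of floats per timestep."""
    # AR: penalise large (dropped-out) activations
    ar = alpha * mean([v * v for step in dropped_out for v in step])
    # TAR: penalise big jumps between consecutive timesteps of the raw activations
    diffs = [b - a for prev, nxt in zip(raw_out, raw_out[1:]) for a, b in zip(prev, nxt)]
    tar = beta * mean([d * d for d in diffs]) if diffs else 0.0
    return ar + tar

print(ar_tar_penalty([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))  # 2.0
```

RNNTrainerSimple skips both terms whatever alpha and beta are set to, so training with it is unregularized in this respect.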

Then I had to make a custom Learner class to use this RNNTrainerSimple class:

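from functools import partial
from fastai.text import *  # fastai v1: Learner, DataBunch, nn, OptSplitFunc, GradientClipping
class ImageTabularTextLearner(Learner):
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        super().__init__(data, model, metrics=metrics, **learn_kwargs)  # forward metrics instead of dropping them
        self.callbacks.append(RNNTrainerSimple(self, alpha=alpha, beta=beta))
        # honour clip and split_func the same way RNNLearner does
        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))
        if split_func: self.split(split_func)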
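Appending the callback to the learner is what makes the reset happen: the training loop calls each callback's `on_epoch_begin` before every epoch. The toy below mimics only that dispatch; `ToyLearner`, `ResetOnEpochBegin` and `CountingModel` are hypothetical stand-ins, not fastai's CallbackHandler.

```python
class ToyLearner:
    """Minimal stand-in for a Learner: it only calls `on_epoch_begin`
    on its callbacks before each epoch."""
    def __init__(self, model):
        self.model, self.callbacks = model, []

    def fit(self, epochs):
        for epoch in range(epochs):
            for cb in self.callbacks:
                cb.on_epoch_begin(epoch=epoch)
            # ... the batches of this epoch would be processed here ...

class ResetOnEpochBegin:
    "Same idea as RNNTrainerSimple: reset the model at the start of every epoch."
    def __init__(self, learn):
        self.learn = learn

    def on_epoch_begin(self, **kwargs):
        self.learn.model.reset()

class CountingModel:
    "Counts how often reset() is called."
    def __init__(self):
        self.resets = 0

    def reset(self):
        self.resets += 1

learn = ToyLearner(CountingModel())
learn.callbacks.append(ResetOnEpochBegin(learn))
learn.fit(3)
print(learn.model.resets)  # 3
```

This is also why the model itself needs a `reset` method: the callback calls it unconditionally.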

Also, I had to add a reset method to ImageTabularModel:

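    # method of ImageTabularModel (an nn.Module)
    def reset(self):
        "Reset every direct child that has a `reset` method (e.g. the text encoder's hidden state)."
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()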
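The point of this method is that a recurrent encoder carries its hidden state from one batch to the next, and that state needs clearing between epochs. A plain-Python sketch with hypothetical toy classes; `children()` here imitates `nn.Module.children()`, which yields only direct submodules, so an encoder nested deeper is reached only if its parent also defines `reset`.

```python
class StatefulEncoder:
    "Stands in for the LSTM text encoder: keeps a hidden state between batches."
    def __init__(self):
        self.hidden = None

    def forward(self, x):
        # toy recurrence: the new state depends on the previous one
        self.hidden = x if self.hidden is None else self.hidden + x
        return self.hidden

    def reset(self):
        self.hidden = None

class Head:
    "A layer with no state and no reset method."
    def forward(self, x):
        return 2 * x

class ToyImageTabularModel:
    def __init__(self):
        self._children = [StatefulEncoder(), Head()]

    def children(self):  # mimics nn.Module.children(): direct submodules only
        return iter(self._children)

    def reset(self):
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()

m = ToyImageTabularModel()
enc = m._children[0]
enc.forward(1); enc.forward(2)
print(enc.hidden)  # 3
m.reset()
print(enc.hidden)  # None
```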

It's training so far, and the loss hasn't gone NaN yet…

Now my code is in dire need of refactoring. I also need to work out how to handle freeze/unfreeze and how to group the layers for discriminative learning rates, etc.
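For reference, in fastai v1 a learning rate given as `slice(lo, hi)` is spread geometrically across the layer groups, and `freeze_to(n)` freezes groups before index `n` (`freeze()` is `freeze_to(-1)`, `unfreeze()` is `freeze_to(0)`). A plain-Python sketch of the resulting per-group settings; `even_mults` here is my own version of the helper of that name, and `group_settings` is hypothetical. It ignores fastai's option of keeping BatchNorm layers trainable in frozen groups.

```python
def even_mults(start, stop, n):
    "Geometric spacing from start to stop over n values (similar to fastai v1's even_mults)."
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

def group_settings(n_groups, lr_lo, lr_hi, frozen_up_to):
    """One (lr, trainable) pair per layer group.
    frozen_up_to=k freezes groups 0..k-1, like learn.freeze_to(k); negative k counts from the end."""
    if frozen_up_to < 0:
        frozen_up_to += n_groups
    lrs = even_mults(lr_lo, lr_hi, n_groups)
    return [(lr, i >= frozen_up_to) for i, lr in enumerate(lrs)]

# e.g. three groups (say text encoder, image body, head), frozen except the head
for lr, trainable in group_settings(3, 1e-5, 1e-3, frozen_up_to=-1):
    print(f"lr={lr:.0e} trainable={trainable}")
# lr=1e-05 trainable=False
# lr=1e-04 trainable=False
# lr=1e-03 trainable=True
```

With the real learner, the groups would presumably come from `learn.split(...)` over the custom model's submodules, after which something like `learn.fit_one_cycle(n, slice(1e-5, 1e-3))` applies the spread.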
