Hum indeed! It would be nice to have a simpler way than having to create a learner just to extract the interesting part of the model, the encoder.
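For reference, this is the kind of detour I mean, just a rough sketch of the current workaround (the `data_lm` DataBunch and the saved encoder name 'fine_tuned_enc' are placeholders for whatever you already have, and the learner signature depends on your fastai version):

learn_lm = language_model_learner(data_lm, drop_mult=0.5)  # build a full LM learner only to get at its encoder
learn_lm.load_encoder('fine_tuned_enc')                    # load a previously fine-tuned encoder, if any
text_encoder = learn_lm.model[0]                           # the encoder is the first module of the LM model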
I modified my code to use a different DataBunch that has a .c property, since TextLMDataBunch doesn't have one. So for anyone interested, I used:
textList = TextList.from_df(pictures, cols='NameDefault', path=path, vocab=vocab)
data_text = textList.random_split_by_pct(.2).label_from_df(cols='Passed').databunch(bs=bs, num_workers=0)  # num_workers=0 because multiprocessing in Jupyter on Windows causes problems...
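A quick sanity check that the resulting DataBunch exposes what the mixed model needs (just a sketch, using the variables from the snippet above):

print(data_text.c)                  # number of classes inferred from the 'Passed' column
print(len(data_text.vocab.itos))    # vocabulary size, shared with the language model through `vocab`
x, y = data_text.one_batch()        # numericalized text batch and its labels
print(x.shape, y.shape)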
Also, for this to train at all, I had to use RNNTrainer with my learner… Be advised that I don't really understand what RNNTrainer does… But the normal RNNTrainer used in RNNLearner seemed to cause problems in the loss calculation with my custom model… So for now I created a really simple replacement just to see if I could get this thing to train at all…
class RNNTrainerSimple(LearnerCallback):
    "`Callback` that only resets the hidden state of the model at the start of each epoch."
    def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):
        super().__init__(learn)
        self.not_min += ['raw_out', 'out']
        self.alpha,self.beta = alpha,beta

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model."
        self.learn.model.reset()
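For context, what the stock RNNTrainer does on top of this is the AR/TAR regularization: it expects the model's forward to return a tuple (output, raw_outputs, outputs) so it can penalize large and fast-changing LSTM activations, which is probably exactly why it blew up in the loss calculation with a custom forward that only returns the final output. If the text branch is later changed to also return its raw and dropped-out activations, something like the following two methods, adapted from fastai v1's RNNTrainer, could be added to RNNTrainerSimple to restore that behaviour (untested against this model):

    def on_loss_begin(self, last_output, **kwargs):
        "Keep the extra LSTM outputs for the regularization terms and pass only the real output on."
        self.raw_out, self.out = last_output[1], last_output[2]
        return {'last_output': last_output[0]}   # older fastai v1 versions return last_output[0] directly instead of a dict

    def on_backward_begin(self, last_loss, **kwargs):
        "Add activation regularization (AR) and temporal activation regularization (TAR) to the loss."
        if self.alpha != 0.: last_loss += self.alpha * self.out[-1].float().pow(2).mean()
        if self.beta != 0.:
            h = self.raw_out[-1]
            if len(h) > 1: last_loss += self.beta * (h[:, 1:] - h[:, :-1]).float().pow(2).mean()
        return {'last_loss': last_loss}          # same caveat as above for older versions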
Then I had to make a custom Learner class to use this RNNTrainerSimple callback:
class ImageTabularTextLearner(Learner):
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        super().__init__(data, model, **learn_kwargs)
        self.callbacks.append(RNNTrainerSimple(self, alpha=alpha, beta=beta))
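Using it then looks roughly like this (a sketch; `data` is the combined DataBunch and `model` the ImageTabularModel instance, whatever you actually named them):

model = ImageTabularModel(...)   # the combined image/tabular/text model
learn = ImageTabularTextLearner(data, model, metrics=accuracy)
learn.fit_one_cycle(1, 1e-3)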
Also I had to add a reset method to ImageTabularModel:
def reset(self):
    for c in self.children():
        if hasattr(c, 'reset'): c.reset()
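One thing to keep in mind: children() only walks the immediate submodules, so if the stateful text encoder ends up nested deeper (inside an nn.Sequential head, for example) it would be missed. A recursive variant would be safer; a sketch of that alternative, not what I'm actually running:

def reset(self):
    for m in self.modules():
        # modules() includes self, so guard against infinite recursion
        if m is not self and hasattr(m, 'reset'): m.reset()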
This is training so far; the loss hasn't gone to NaN yet…
Now my code is in dire need of refactoring. I also need to work on a way to handle freeze/unfreeze, group layers for discriminative learning rates, etc.
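For the layer groups I'm thinking of something along these lines (just a sketch; the attribute names cnn_body, text_encoder, tabular_model and head are hypothetical stand-ins for whatever ImageTabularModel actually calls its parts):

def image_tabular_text_split(model):
    # coarsest possible split: one group per trunk plus one for the new head
    return [[model.cnn_body], [model.text_encoder, model.tabular_model], [model.head]]

learn.split(image_tabular_text_split)
learn.freeze_to(-1)                          # start by training only the head
learn.fit_one_cycle(1, slice(1e-5, 1e-3))    # discriminative learning rates across the groups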