Transfer Learning with Previously Trained Language Model

Hello, I am trying to do text classification. I have a reasonably large amount of unlabeled data and a considerably smaller amount of labeled data.

I have trained a language model on the larger data, and would like to use it to train a text classifier on the labeled data.

(Note: I saved the language model I want to use with

However, when I run the following code:

from fastai.text import *  # fastai v1

data_lm = TextLMDataBunch.from_folder(DATA_PATH)
data_clas = TextClasDataBunch.from_folder(DATA_PATH, vocab=data_lm.train_ds.vocab, bs=45)
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, wd=1e-1)
learn.load_encoder('large_lm1_enc')  # previously trained language model

I get the following error:

RuntimeError: Error(s) in loading state_dict for AWD_LSTM:
	Missing key(s) in state_dict: "encoder.weight", "encoder_dp.emb.weight", "rnns.0.weight_hh_l0_raw", "rnns.0.module.weight_ih_l0", "rnns.0.module.weight_hh_l0", "rnns.0.module.bias_ih_l0", "rnns.0.module.bias_hh_l0", "rnns.1.weight_hh_l0_raw", "rnns.1.module.weight_ih_l0", "rnns.1.module.weight_hh_l0", "rnns.1.module.bias_ih_l0", "rnns.1.module.bias_hh_l0", "rnns.2.weight_hh_l0_raw", "rnns.2.module.weight_ih_l0", "rnns.2.module.weight_hh_l0", "rnns.2.module.bias_ih_l0", "rnns.2.module.bias_hh_l0". 
	Unexpected key(s) in state_dict: "model", "opt". 

I am not entirely sure what the problem is. If I am going about this the wrong way entirely, how should I do this (ideally sticking within the library)?

Hi @acombandrew

I’m working on something similar. Below is the link to my solution on Google Colab, which runs without errors. Are you doing something differently?




Hello @darek.kleczek, thank you for the response. I looked through your code. Is the data you trained the language model on the same as the data you trained the classifier on?
I think the problem I’m getting is related to the fact that I’m training the language model on different data.

Alternatively it could be me using instead of learn.save_encoder. I will try this fix and note if it works or not.

I figured out the part of my problem that was relevant to this post: I was using save() instead of save_encoder(), which I wouldn’t have noticed without code to compare against. Thanks!
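
For anyone landing here with the same error, here is a minimal sketch of the fix, assuming fastai v1 and the same DATA_PATH folder layout as in the original post. The function names (looks_like_full_checkpoint, save_lm_encoder, load_classifier) and the one-epoch training schedule are made up for illustration and are not part of fastai. As far as I understand it, learn.save() writes a dict with "model" and "opt" entries (the two unexpected keys in the error above), while learn.save_encoder() writes only the encoder weights, which is what load_encoder() expects.

```python
from pathlib import Path


def looks_like_full_checkpoint(state: dict) -> bool:
    """Hypothetical helper: True if `state` has the {'model': ..., 'opt': ...}
    layout that the error message reports as unexpected keys."""
    return {"model", "opt"} <= set(state.keys())


def save_lm_encoder(data_path: Path, name: str = "large_lm1_enc") -> None:
    """Train a language model and save only its encoder (fastai v1 assumed)."""
    # Imports are local so the helper above can be used without fastai installed.
    from fastai.text import AWD_LSTM, TextLMDataBunch, language_model_learner

    data_lm = TextLMDataBunch.from_folder(data_path)
    learn_lm = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.5)
    learn_lm.fit_one_cycle(1, 1e-2)  # placeholder schedule, not from the thread
    learn_lm.save_encoder(name)      # encoder weights only, NOT learn_lm.save(name)


def load_classifier(data_path: Path, name: str = "large_lm1_enc"):
    """Build the classifier from the original post and load the saved encoder."""
    from fastai.text import (AWD_LSTM, TextClasDataBunch, TextLMDataBunch,
                             text_classifier_learner)

    data_lm = TextLMDataBunch.from_folder(data_path)
    data_clas = TextClasDataBunch.from_folder(
        data_path, vocab=data_lm.train_ds.vocab, bs=45)
    learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, wd=1e-1)
    learn.load_encoder(name)  # expects a file written by save_encoder()
    return learn
```

If you are not sure how an existing .pth file was written, loading it with torch.load and passing the result to looks_like_full_checkpoint should tell you whether it came from save() rather than save_encoder().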
