How to fix 'Error(s) in loading state_dict for AWD_LSTM' when using fast-ai

(George K.) #1

I am using the fast-ai library to train a model on a sample of the IMDB reviews dataset. My goal is sentiment analysis, and I trained the model in a VM by following this tutorial.

I saved the data_lm and data_clas data bunches, then the encoder ft_enc, and finally the classifier learner sentiment_model. I then copied those four files from the VM to my machine, intending to use the pretrained models to classify sentiment.

This is what I did:

# Use the IMDB_SAMPLE file
path = untar_data(URLs.IMDB_SAMPLE)

# Language model data
data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')

# Sentiment classifier model data
data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', 
                                       vocab=data_lm.train_ds.vocab, bs=32)

# Build a classifier using the tuned encoder (tuned in the VM)
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)

# Load the trained model

After that, I wanted to use the model to predict the sentiment of a sentence. When executing this code, I ran into the following error:

RuntimeError: Error(s) in loading state_dict for AWD_LSTM:
   size mismatch for encoder.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]).
   size mismatch for encoder_dp.emb.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]). 

And the Traceback is:

Traceback (most recent call last):
  File "C:/Users/user/PycharmProjects/SentAn/", line 51, in <module>
    learn = load_models()
  File "C:/Users/user/PycharmProjects/SentAn/", line 32, in load_models
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\fastai\text\", line 68, in load_encoder
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\torch\nn\modules\", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))

So the error occurs when loading the encoder. I also tried removing the load_encoder line, but the same error then occurred on the next line, learn.load('sentiment_model').

I searched the fast-ai forum and saw that others had hit this issue too, but I found no solution. In this post the user says it might be caused by different preprocessing, though I couldn’t understand why that would happen.

Could anyone please help me with this issue? I have tried many different things but cannot find a way past it. Any help would be appreciated.


(Max Tian) #2

I’ve encountered the same issue. I suspect it’s due to the validation split: when different data ends up in the training set, tokenization produces a different vocab. You can try keeping the training and validation data the same and see if that works.

However, this raises a problem with using the fastai library. We can’t realistically assume that the vocab for the language model and the classifier will always be the same, can we? Or do we always have to use the exact same training/validation data if we want to load a network we’ve trained?
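The vocabulary drift described above can be reproduced without fastai. This is only a toy sketch: `train_vocab` is a hypothetical helper, the five texts are made up, and whitespace tokenization stands in for whatever the library actually does. It shows that holding out a different document changes the training vocabulary, and with it the number of embedding rows a model would be built with.

```python
def train_vocab(texts, held_out):
    """Build a sorted vocabulary from every text whose index is not held out."""
    tokens = set()
    for i, text in enumerate(texts):
        if i in held_out:
            continue  # this document went to the validation set
        tokens.update(text.lower().split())
    return sorted(tokens)


# Made-up corpus, for illustration only.
texts = [
    "the movie was great",
    "the plot was dull",
    "great acting and great score",
    "dull script",
    "the score was haunting",
]

# Same corpus, two different validation splits.
vocab_a = train_vocab(texts, held_out={4})
vocab_b = train_vocab(texts, held_out={2})

# The vocabulary size is what determines the embedding's first dimension.
print(len(vocab_a), len(vocab_b))
print(sorted(set(vocab_a) ^ set(vocab_b)))  # tokens seen by only one split
```

The same effect at a larger scale would be one plausible explanation for 8731 rows in the checkpoint versus 8888 in the freshly built model.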



I had exactly the same problem. It is extremely frustrating. What is the purpose of the method at all if learner.load fails because of some discrepancy outside of the learner? It seems like export is the only way to have model persistence from one session to the next.

Isn’t there any way to simply lock down the vocabulary that was used when training the model, or even just to lock down the exact train/validation split that was used in the original training data bunch? I’m a newcomer to the API… Because so many of the implementation details are abstracted away from the user, I don’t even know how to begin solving this problem.
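One way to think about "locking down" the vocabulary is to persist the index-to-token list next to the model weights and reload it instead of rebuilding it from the data. The sketch below uses only the standard library; `save_itos` and `load_itos` are hypothetical helpers, not fastai functions, and the token list is a toy example. fastai may well offer its own helpers for this, so treat it as an illustration of the idea rather than the library's mechanism.

```python
import json
import tempfile
from pathlib import Path


def save_itos(itos, path):
    """Write the index-to-token list to disk as JSON."""
    Path(path).write_text(json.dumps(itos), encoding="utf-8")


def load_itos(path):
    """Read the index-to-token list back, preserving its order."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


# Toy vocabulary; a real one would come from the data bunch used in training.
itos = ["xxunk", "xxpad", "the", "movie", "was", "great"]

with tempfile.TemporaryDirectory() as tmp:
    vocab_file = Path(tmp) / "itos.json"
    save_itos(itos, vocab_file)        # done once, on the training machine
    restored = load_itos(vocab_file)   # done later, on the inference machine

# Rebuild the token-to-index mapping from the restored list.
stoi = {tok: i for i, tok in enumerate(restored)}
print(len(restored), stoi["great"])
```

Because the list is reloaded rather than recomputed, its length (and therefore the embedding shape) no longer depends on how the data happens to be split in the new session.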


(Md Abul Bashar) #4

It seems the vocabulary sizes of data_clas and data_lm are different. I guess the problem is caused by different preprocessing being used for data_clas and data_lm. To check my guess, I simply added

data_clas.vocab.itos = data_lm.vocab.itos

before the following line:

learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.3)

This fixed the error.
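Putting the workaround together with the code from the first post, the order of operations might look like the sketch below. It assumes fastai v1 and that the data_lm used here carries the same vocabulary the checkpoint was trained with; if data_lm is rebuilt from a different split, the sizes can still disagree. `check_vocab_matches` and `build_classifier` are hypothetical helpers, and `checkpoint_rows` is a value you would read off the error message (8731 in the first post).

```python
def check_vocab_matches(itos, checkpoint_rows):
    """Raise early if the vocab size differs from the checkpoint's embedding rows."""
    if len(itos) != checkpoint_rows:
        raise ValueError(
            f"vocab has {len(itos)} tokens but the checkpoint expects {checkpoint_rows}"
        )


def build_classifier(path, checkpoint_rows):
    # fastai v1 is imported lazily so the helper above works without it installed.
    from fastai.text import (AWD_LSTM, TextClasDataBunch, TextLMDataBunch,
                             text_classifier_learner)

    data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')
    data_clas = TextClasDataBunch.from_csv(path, 'texts.csv',
                                           vocab=data_lm.train_ds.vocab, bs=32)

    # Workaround from this post: share one token list between both data bunches.
    data_clas.vocab.itos = data_lm.vocab.itos

    # Assumption: the checkpoint was trained with a vocabulary of this size.
    check_vocab_matches(data_clas.vocab.itos, checkpoint_rows)

    learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.3)
    learn.load_encoder('ft_enc')  # encoder name taken from the first post
    return learn
```

Failing on the size check before calling load_encoder gives a shorter, more direct message than the state_dict traceback, which can make it easier to see that the vocabulary is the culprit.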