Load pre-trained weights (into an AWD_LSTM) from another model

I first trained a language model, then trained an encoder+classifier. Now I want to load those encoder weights into an 'encoder-only' model.

The encoder of the pre-trained model is saved as hi_enc, and I create the new encoder with:

arch = AWD_LSTM
encoder = arch(vocab_sz=len(data_lm.vocab.itos), **config_en)

I tried encoder.load('hi_enc'),
which obviously doesn't work, because AWD_LSTM has no 'load' method. Can someone please point me to the function that lets me load the pre-trained weights? TIA

[EDIT] Alternatively: can I 'split' my original model so that I get only the encoder layers (no attached classifier or decoder)?

How did you save your weights? If you used the torch functions, you can load them back with encoder.load_state_dict(torch.load(fname)).
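For the setup above, a minimal sketch of the whole round trip (assuming fastai v1, where learn.save_encoder saves the encoder's state_dict with torch.save into the learner's models/ directory; the exact path and the contents of config_en are assumptions, not verified against your code):

import torch

# config_en should contain only the AWD_LSTM constructor kwargs
# (emb_sz, n_hid, n_layers, dropouts, ...); decoder-only keys such as
# tie_weights / output_p / out_bias are not accepted by AWD_LSTM
encoder = AWD_LSTM(vocab_sz=len(data_lm.vocab.itos), **config_en)

# learn.save_encoder('hi_enc') wrote the encoder's state_dict with torch.save,
# so load_state_dict restores it into the bare encoder
state = torch.load('models/hi_enc.pth', map_location='cpu')
encoder.load_state_dict(state)
encoder.eval()  # disable dropout if you only run inference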


Thanks, this is working!
I used the usual learn.save_encoder('fine_tuned_enc') to save the encoder of the language model.

Instead of creating an encoder via AWD_LSTM, I am now using

encoder_mul = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config_mul), pad_idx=pad_idx)
encoder_mul.load_state_dict(torch.load('fine_tuned_enc.pth'))

When I load the pre-trained weights into a MultiBatchEncoder, I get the following error:

Error(s) in loading state_dict for MultiBatchEncoder: Missing key(s) in state_dict: "module.encoder.weight", "module.encoder_dp.emb.weight", "module.rnns.0.weight_hh_l0_raw", "module.rnns.0.module.weight_ih_l0", "module.rnns.0.module.weight_hh_l0", "module.rnns.0.module.bias_ih_l0", "module.rnns.0.module.bias_hh_l0", "module.rnns.1.weight_hh_l0_raw", "module.rnns.1.module.weight_ih_l0", "module.rnns.1.module.weight_hh_l0", "module.rnns.1.module.bias_ih_l0", "module.rnns.1.module.bias_hh_l0", "module.rnns.2.weight_hh_l0_raw", "module.rnns.2.module.weight_ih_l0", "module.rnns.2.module.weight_hh_l0", "module.rnns.2.module.bias_ih_l0", "module.rnns.2.module.bias_hh_l0". Unexpected key(s) in state_dict: "encoder.weight", "encoder_dp.emb.weight", "rnns.0.weight_hh_l0_raw", "rnns.0.module.weight_ih_l0", "rnns.0.module.weight_hh_l0", "rnns.0.module.bias_ih_l0", "rnns.0.module.bias_hh_l0", "rnns.1.weight_hh_l0_raw", "rnns.1.module.weight_ih_l0", "rnns.1.module.weight_hh_l0", "rnns.1.module.bias_ih_l0", "rnns.1.module.bias_hh_l0", "rnns.2.weight_hh_l0_raw", "rnns.2.module.weight_ih_l0", "rnns.2.module.weight_hh_l0", "rnns.2.module.bias_ih_l0", "rnns.2.module.bias_hh_l0".

PyTorch requires the state_dict keys to match exactly; your weights line up, they just have different names (the MultiBatchEncoder wraps the AWD_LSTM in a .module attribute, so every key gains a module. prefix). A naive approach I once took to solve this problem was:

wgts = torch.load(fname)  # fname = path to the saved state dict file
params = list(zip(wgts.items(), learn.model.state_dict().items()))

# copy each saved tensor into the model tensor at the same position,
# ignoring the key-name mismatch (relies on both dicts having the same order)
with torch.no_grad():
    for (_, src), (_, dst) in params:
        dst.copy_(src)
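A less fragile alternative, sketched from the key names in your error message (not an official fastai API guarantee): since every missing key is just the saved key with a module. prefix, you can rename the keys so load_state_dict still checks shapes, or load straight into the wrapped module.

import torch

# MultiBatchEncoder holds the AWD_LSTM in a `.module` attribute, so its
# state_dict keys are the saved encoder keys with a "module." prefix
wgts = torch.load('fine_tuned_enc.pth', map_location='cpu')
encoder_mul.load_state_dict({'module.' + k: v for k, v in wgts.items()})

# or, equivalently, load the saved weights into the inner encoder directly
encoder_mul.module.load_state_dict(wgts)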