Saving a model from fastai and loading it in PyTorch


Has anyone got a working example of how to load a fastai model into PyTorch and use it for prediction?

I would like to train my model in fastai, but then load it with plain PyTorch for prediction, without having to install fastai.



Hi, we are wondering the same thing! You can load the .h5 file by calling torch.load, but we have no idea where to go from there.
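A reasonable next step is to look at what torch.load actually returned. The sketch below is a minimal example, assuming the file holds either a plain state_dict (parameter name mapped to tensor) or, as some fastai v1 releases write it when saving with the optimizer, a dict with "model" and "opt" entries. The helper names load_weights and describe are hypothetical. map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only one.

```python
import torch


def load_weights(path):
    """Load a checkpoint (path or file-like object) and return a plain state_dict."""
    # map_location="cpu" remaps GPU tensors so the file loads without CUDA
    ckpt = torch.load(path, map_location="cpu")
    # Assumption: some fastai versions wrap the weights as {"model": ..., "opt": ...};
    # unwrap that case, otherwise treat the loaded object as the state_dict itself.
    if isinstance(ckpt, dict) and isinstance(ckpt.get("model"), dict):
        return ckpt["model"]
    return ckpt


def describe(state_dict):
    """Print every parameter/buffer name with its shape."""
    for name, tensor in state_dict.items():
        print(f"{name:40s} {tuple(tensor.shape)}")
```

Running describe(load_weights("224_all.h5")) on a saved model lists the key names that load_state_dict will expect to find in the architecture you rebuild.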

I can’t believe this crucial step is so overlooked both in the notebooks and in this community.



I have been looking into this myself. I think you'll find that this community is more focused on perfecting and researching the algorithms than on deploying them.

If anyone has tutorials on deploying these models, that would be great.


(Taman) #4

Hi, for images I got it to work quite easily.

Save your model with
Then, in a new file or notebook, recreate your model architecture and create an instance of it, e.g.:
m = Model()

Then do:
state = torch.load('path_to_saved_model.h5')

And finally:
m.load_state_dict(state)

You can then predict with m(input).
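Here is a minimal sketch of that recipe in plain PyTorch. Model is a small hypothetical stand-in; the real class must match the trained architecture layer for layer, and an in-memory buffer stands in for the .h5 file. Note the m.eval() call: it switches BatchNorm and Dropout to inference behaviour, and forgetting it silently changes predictions.

```python
import io

import torch
import torch.nn as nn


class Model(nn.Module):
    """Hypothetical stand-in; must mirror the trained architecture exactly."""

    def __init__(self, n_in=4, n_out=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_in, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.head = nn.Linear(8, n_out)

    def forward(self, x):
        return self.head(self.body(x))


def save_and_reload(trained):
    """Save only the weights, then rebuild the architecture and load them back."""
    buf = io.BytesIO()                      # stands in for 'path_to_saved_model.h5'
    torch.save(trained.state_dict(), buf)   # weights only, no pickled class
    buf.seek(0)

    m = Model()                             # 1. recreate the architecture
    state = torch.load(buf)                 # 2. load the saved tensors
    m.load_state_dict(state)                # 3. copy them into the new instance
    m.eval()                                # 4. inference mode for BatchNorm/Dropout
    return m


def predict(m, x):
    """Class probabilities for a batch x; no autograd bookkeeping is needed."""
    with torch.no_grad():
        return m(x).softmax(dim=1)
```

Saving the state_dict rather than the whole pickled model is what makes this work without the training library: only tensors are stored, and the class definition comes from your own code at load time.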

Now, for the input, you'll need to preprocess it the same way the library does and wrap it in a Variable. It takes a bit of work, but it is entirely doable.
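A minimal preprocessing sketch follows. It assumes the model was trained on square 224 x 224 inputs normalised with the standard ImageNet mean and standard deviation, using a plain bilinear resize with no cropping. Check these against the transforms your fastai data object actually used, since a mismatch degrades predictions without raising errors. The function name preprocess is hypothetical. On PyTorch 0.4 and later, plain tensors can be passed to the model directly; wrapping in Variable is only needed on older versions.

```python
import torch
import torch.nn.functional as F

# Standard ImageNet channel statistics (assumed to match the training normalisation)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def preprocess(img_hwc_uint8, size=224):
    """Turn an H x W x 3 uint8 image tensor into a normalised 1 x 3 x size x size batch.

    Assumptions: bilinear resize straight to size x size (no cropping) and
    ImageNet mean/std normalisation. Adjust both to match your training transforms.
    """
    x = img_hwc_uint8.permute(2, 0, 1).float() / 255.0       # HWC -> CHW, scale to [0, 1]
    x = F.interpolate(x.unsqueeze(0), size=(size, size),
                      mode="bilinear", align_corners=False)  # add batch dim, resize
    return (x - IMAGENET_MEAN) / IMAGENET_STD                 # broadcasts over batch, H, W
```

With PIL and NumPy installed, torch.from_numpy(numpy.array(Image.open(path).convert("RGB"))) produces a tensor in the H x W x 3 uint8 layout this function expects.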



Thanks for the reply. I am trying to apply this to the model saved from the “cats vs dogs” example in lesson 1:

import torch
import torchvision.models as models

resnet34 = models.resnet34()
state = torch.load('224_all.h5')
resnet34.load_state_dict(state)

and it fails in load_state_dict with the error message:

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/ in load_state_dict(self, state_dict, strict)
    488             elif strict:
    489                 raise KeyError('unexpected key "{}" in state_dict'
--> 490                                .format(name))
    491         if strict:
    492             missing = set(own_state.keys()) - set(state_dict.keys())

KeyError: 'unexpected key "0.weight" in state_dict'

Is it not the same kind of model? A different variant of resnet34, perhaps?



It looks like the key names in the state_dict get lost. After some digging, I get this:

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2...
	Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "4.0.conv1.weight", "4.0.bn1.weight", "4.0.bn1.bias", "4.0.bn1.running_mean", "4.0.bn1.running_var", "4.0.conv2.weight", "4.0.bn2.weight", "4.0.bn2.bias", "4.0.bn2.running_mean", "4.0.bn2.running_var", "4.1.conv1.weight", "4.1.bn1.weight", "4.1.bn1.bias", "4.1.bn1.running_mean", "4.1.bn1.running_var", "4.1.conv2.weight", "4.1.bn2.weight", "4.1.bn2.bias", "4.1.bn2.running_mean", "4.1.bn2.running_var", "4.2.conv1.weight", "4.2.bn1.weight", "4.2.bn1.bias", "4.2.bn1.running_mean", "4.2.bn1.running_var", "4.2.conv2.weight", "4.2.bn2.weight", "4.2.bn2.bias", "4.2.bn2.running_mean", "4.2.bn2.running_var", "5.0.conv1.weight", "5.0.bn1.weight", "5.0.bn1.bias", "5.0.bn1.running_mean", "5.0.bn1.running_var", "5.0.conv2.weight", "5.0.bn2.weight", "5.0.bn2.bias", "5.0.bn2.running_mean", "5.0.bn2.running_var", "5.0.downsample.0.weight", "5.0.downsample.1.weight", "5.0.downsample.1.bias", "5.0.downsample.1.running_mean", "5.0.downsample.1.running_var", "5.1.conv1.weight", "5.1.bn1.weight", "5.1.bn1.bias", "5.1.bn1.running_mean", "5.1.bn1.running_var", "5.1.conv2.weight", "5.1.bn2.weight", "5.1.bn2.bias", "5.1.bn2.running_mean", "5.1.bn2.running_var", "5.2.conv1.weight", "5.2.bn1.weight", "5.2.bn1.bias", "5.2.bn1.running_mean", "5.2.bn1.running_var", "5.2.conv2.weight", "5.2.bn2.weight", "5.2.bn2.bias", "5.2.bn2.running_mean", "5.2.bn2.running_var", "5.3.conv1.weight", "5.3.bn1.weight", ...

Why might this be? It's literally the example from lesson 1, so the resnet model has not been modified in any way. Any ideas?

If I dump

in the lesson before saving, it also only has keys like “0.weight”.
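The keys differ because the saved fastai model is an nn.Sequential built from the ResNet's children, so parameters are named by position ("0", "1", "4", ...) rather than by attribute ("conv1", "bn1", "layer1", ...). Comparing the two key lists in the error above, the positions appear to follow torchvision's child order (conv1, bn1, relu, maxpool, layer1 to layer4), i.e. 0 to conv1, 1 to bn1, 4 to layer1, 5 to layer2; positions 6 and 7 are cut off in the output and are assumed here to be layer3 and layer4. A minimal sketch under that assumption, with a hypothetical helper split_and_rename that renames the backbone keys and returns every other key (the fastai head) separately, since it has no counterpart in torchvision's avgpool/fc:

```python
from collections import OrderedDict

# Assumed position of each torchvision ResNet child inside the flat nn.Sequential.
# Positions 2 and 3 (ReLU, MaxPool) have no parameters, so they never appear as keys.
POSITION_TO_NAME = {"0": "conv1", "1": "bn1", "4": "layer1",
                    "5": "layer2", "6": "layer3", "7": "layer4"}


def split_and_rename(state_dict):
    """Return (backbone, rest).

    backbone: keys renamed to torchvision ResNet attribute names.
    rest: every other key (e.g. the fastai head), left untouched.
    """
    backbone, rest = OrderedDict(), OrderedDict()
    for key, value in state_dict.items():
        first, _, remainder = key.partition(".")   # "4.0.conv1.weight" -> "4", "0.conv1.weight"
        if first in POSITION_TO_NAME:
            backbone[POSITION_TO_NAME[first] + "." + remainder] = value
        else:
            rest[key] = value
    return backbone, rest
```

If the assumption holds, resnet34.load_state_dict(backbone, strict=False) should load the backbone and leave only fc.weight and fc.bias missing. The head still has to be rebuilt to match the fastai one before the model gives the trained predictions.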


(Manoj Raman Kondabathula) #8

Hello @chloe2018, did you find a solution? I am facing the same issue.



Sadly not, I still have the fastai dependency :frowning:


(Manoj Raman Kondabathula) #10

I made a new post yesterday, “How to load fastai saved model with Pytorch-cpu”, and also tagged Jeremy. His reply might be useful, but he hasn't replied yet.


(Mario) #11

The problem is that the PyTorch version of ResNet has the same backbone as the fastai version, but a different head.

You can use the following code, taken from the create_cnn method, to create the same network.

A model created with

learn = create_cnn(data, models.resnet50, metrics=[accuracy])
learn.save('final_model')

can afterwards be loaded with

from fastai import *
from fastai.vision import *
import torch
loc = torch.load(base_path/'models/final_model.pth')
body = create_body(models.resnet50, True, None)
data_classes = 2
nf = callbacks.hooks.num_features_model(body) * 2
head = create_head(nf, data_classes, None, ps=0.5, bn_final=False)
model = nn.Sequential(body, head)

creates the same network architecture as the one saved before, and

model.load_state_dict(loc)

loads the saved weights from the initial model.
Hope this helps.
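To drop fastai entirely at inference time, the same nn.Sequential(body, head) could in principle be rebuilt in plain PyTorch. This sketch assumes the layout fastai v1's create_head produced with the arguments above: concat pooling, one hidden layer of 512 units, dropout ps/2 then ps, and bn_final=False. That layout has changed between fastai releases, so print learn.model once and compare state_dict key names before relying on it. AdaptiveConcatPool2d, make_head and build_model here are local re-implementations for illustration, not imports from fastai.

```python
import torch
import torch.nn as nn


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max- and average-pooling (max first; assumed fastai order)."""

    def __init__(self):
        super().__init__()
        self.mp = nn.AdaptiveMaxPool2d(1)
        self.ap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], dim=1)


def make_head(nf, nc, hidden=512, ps=0.5):
    """Assumed layout: pool, flatten, BN, Dropout(ps/2), Linear, ReLU, BN, Dropout(ps), Linear."""
    return nn.Sequential(
        AdaptiveConcatPool2d(), nn.Flatten(),
        nn.BatchNorm1d(nf), nn.Dropout(ps / 2), nn.Linear(nf, hidden), nn.ReLU(inplace=True),
        nn.BatchNorm1d(hidden), nn.Dropout(ps), nn.Linear(hidden, nc),
    )


def build_model(resnet, nc):
    """Body = every child of a torchvision-style ResNet except avgpool and fc (cut at -2)."""
    body = nn.Sequential(*list(resnet.children())[:-2])
    nf = resnet.fc.in_features * 2          # doubled by the concat pooling
    return nn.Sequential(body, make_head(nf, nc))
```

With torchvision installed, build_model(torchvision.models.resnet50(), nc=2) followed by model.load_state_dict(loc) and model.eval() would then be a fastai-free counterpart of the code above, provided the key names line up (the head's BatchNorm and Linear layers land at positions 1.2, 1.4, 1.6 and 1.8).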


(Dhaval Mayatra) #12

This helped me.
Thanks a lot.