How to load fastai saved model with Pytorch-cpu

Hi,
I am working on cloud classification. As a first stage, I trained a ResNet34 model with fastai to detect the presence of clouds and saved the model. I got reasonably good results, and now I want to use it directly on the imaging system, which does not have a GPU. How can I load the model on that computer and run predictions on one image at a time?

I am using the following code:

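# Imports are assumed; the original post did not show them.
from PIL import Image
from torch import load, device
from torchvision.models import resnet34
from torchvision.transforms.functional import to_tensor

loc = '*:/****/****/****/224_nlc_all.h5'
model = resnet34(pretrained=False)
state = load(loc, map_location=device('cpu'))  # map GPU-saved weights to CPU
model.load_state_dict(state)
model.eval()
image = Image.open(image_loc)  # image_loc: path to the input image
img_tensor = to_tensor(image)
print(model(img_tensor))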

At:

model.load_state_dict(state)

I am getting the following error:

RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict:...................

Thank you for your time.

I think “img_tensor” might have the wrong dimensions. Also, does the image need any transformation before being fed to the model? If so, could you please tell me what I should take care of?

Thank you
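
On the transformation question: a model fine-tuned from ImageNet weights usually expects inputs prepared the same way as during training, typically resized to the training size, scaled to [0, 1], and standardized per channel. The sketch below shows only the per-pixel arithmetic. It assumes the commonly used ImageNet mean and standard deviation were applied in training, which the thread does not confirm, and `normalize_pixel` is a hypothetical helper name.

```python
# ImageNet channel statistics in RGB order.
# Assumption: the training transforms used these values.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Scale one 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [channel / 255.0 for channel in rgb_uint8]
    return tuple((c - m) / s for c, m, s in zip(scaled, mean, std))


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
```

With torchvision, the same standardization can be applied to a whole tensor with `transforms.Normalize(mean, std)` after `to_tensor`.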

Hello @jeremy, sorry for the disturbance. Could you please give any suggestions regarding the above issue?

Thank you for your time

Can you show the output of state.keys() and model.state_dict().keys()?

I think it is a simple naming issue; I had a problem like that before as well. Also, make sure to always set precompute to False before saving the model.

Cheers
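
The key comparison suggested here can be done mechanically rather than by eye. The following is a minimal sketch using plain Python sets; `compare_keys` is a hypothetical helper, and the sample keys are a small hand-picked subset for illustration.

```python
def compare_keys(saved_keys, model_keys):
    """Compare the key names in a saved state dict with those a model expects."""
    saved, model = set(saved_keys), set(model_keys)
    return {
        "missing": sorted(model - saved),     # the model expects these, the file lacks them
        "unexpected": sorted(saved - model),  # the file has these, the model does not
        "shared": sorted(saved & model),      # only these can be loaded by name
    }


report = compare_keys(
    ["0.weight", "1.weight", "16.bias"],        # sample keys from a saved file
    ["conv1.weight", "bn1.weight", "fc.bias"],  # sample keys from a model
)
print(report["shared"])  # prints []
```

In practice the two arguments would be `state.keys()` and `model.state_dict().keys()`. An empty "shared" list means no parameter can be matched by name.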

Dear @j.laute, thank you for your response. Sorry for my late reply; I was out of the country over the weekend.

Here is the output you asked for:

state.keys()
odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked', '4.0.conv1.weight', '4.0.bn1.weight', '4.0.bn1.bias', '4.0.bn1.running_mean', '4.0.bn1.running_var', '4.0.bn1.num_batches_tracked', '4.0.conv2.weight', '4.0.bn2.weight', '4.0.bn2.bias', '4.0.bn2.running_mean', '4.0.bn2.running_var', '4.0.bn2.num_batches_tracked', '4.1.conv1.weight', '4.1.bn1.weight', '4.1.bn1.bias', '4.1.bn1.running_mean', '4.1.bn1.running_var', '4.1.bn1.num_batches_tracked', '4.1.conv2.weight', '4.1.bn2.weight', '4.1.bn2.bias', '4.1.bn2.running_mean', '4.1.bn2.running_var', '4.1.bn2.num_batches_tracked', '4.2.conv1.weight', '4.2.bn1.weight', '4.2.bn1.bias', '4.2.bn1.running_mean', '4.2.bn1.running_var', '4.2.bn1.num_batches_tracked', '4.2.conv2.weight', '4.2.bn2.weight', '4.2.bn2.bias', '4.2.bn2.running_mean', '4.2.bn2.running_var', '4.2.bn2.num_batches_tracked', '5.0.conv1.weight', '5.0.bn1.weight', '5.0.bn1.bias', '5.0.bn1.running_mean', '5.0.bn1.running_var', '5.0.bn1.num_batches_tracked', '5.0.conv2.weight', '5.0.bn2.weight', '5.0.bn2.bias', '5.0.bn2.running_mean', '5.0.bn2.running_var', '5.0.bn2.num_batches_tracked', '5.0.downsample.0.weight', '5.0.downsample.1.weight', '5.0.downsample.1.bias', '5.0.downsample.1.running_mean', '5.0.downsample.1.running_var', '5.0.downsample.1.num_batches_tracked', '5.1.conv1.weight', '5.1.bn1.weight', '5.1.bn1.bias', '5.1.bn1.running_mean', '5.1.bn1.running_var', '5.1.bn1.num_batches_tracked', '5.1.conv2.weight', '5.1.bn2.weight', '5.1.bn2.bias', '5.1.bn2.running_mean', '5.1.bn2.running_var', '5.1.bn2.num_batches_tracked', '5.2.conv1.weight', '5.2.bn1.weight', '5.2.bn1.bias', '5.2.bn1.running_mean', '5.2.bn1.running_var', '5.2.bn1.num_batches_tracked', '5.2.conv2.weight', '5.2.bn2.weight', '5.2.bn2.bias', '5.2.bn2.running_mean', '5.2.bn2.running_var', '5.2.bn2.num_batches_tracked', '5.3.conv1.weight', '5.3.bn1.weight', '5.3.bn1.bias', '5.3.bn1.running_mean', '5.3.bn1.running_var', 
'5.3.bn1.num_batches_tracked', '5.3.conv2.weight', '5.3.bn2.weight', '5.3.bn2.bias', '5.3.bn2.running_mean', '5.3.bn2.running_var', '5.3.bn2.num_batches_tracked', '6.0.conv1.weight', '6.0.bn1.weight', '6.0.bn1.bias', '6.0.bn1.running_mean', '6.0.bn1.running_var', '6.0.bn1.num_batches_tracked', '6.0.conv2.weight', '6.0.bn2.weight', '6.0.bn2.bias', '6.0.bn2.running_mean', '6.0.bn2.running_var', '6.0.bn2.num_batches_tracked', '6.0.downsample.0.weight', '6.0.downsample.1.weight', '6.0.downsample.1.bias', '6.0.downsample.1.running_mean', '6.0.downsample.1.running_var', '6.0.downsample.1.num_batches_tracked', '6.1.conv1.weight', '6.1.bn1.weight', '6.1.bn1.bias', '6.1.bn1.running_mean', '6.1.bn1.running_var', '6.1.bn1.num_batches_tracked', '6.1.conv2.weight', '6.1.bn2.weight', '6.1.bn2.bias', '6.1.bn2.running_mean', '6.1.bn2.running_var', '6.1.bn2.num_batches_tracked', '6.2.conv1.weight', '6.2.bn1.weight', '6.2.bn1.bias', '6.2.bn1.running_mean', '6.2.bn1.running_var', '6.2.bn1.num_batches_tracked', '6.2.conv2.weight', '6.2.bn2.weight', '6.2.bn2.bias', '6.2.bn2.running_mean', '6.2.bn2.running_var', '6.2.bn2.num_batches_tracked', '6.3.conv1.weight', '6.3.bn1.weight', '6.3.bn1.bias', '6.3.bn1.running_mean', '6.3.bn1.running_var', '6.3.bn1.num_batches_tracked', '6.3.conv2.weight', '6.3.bn2.weight', '6.3.bn2.bias', '6.3.bn2.running_mean', '6.3.bn2.running_var', '6.3.bn2.num_batches_tracked', '6.4.conv1.weight', '6.4.bn1.weight', '6.4.bn1.bias', '6.4.bn1.running_mean', '6.4.bn1.running_var', '6.4.bn1.num_batches_tracked', '6.4.conv2.weight', '6.4.bn2.weight', '6.4.bn2.bias', '6.4.bn2.running_mean', '6.4.bn2.running_var', '6.4.bn2.num_batches_tracked', '6.5.conv1.weight', '6.5.bn1.weight', '6.5.bn1.bias', '6.5.bn1.running_mean', '6.5.bn1.running_var', '6.5.bn1.num_batches_tracked', '6.5.conv2.weight', '6.5.bn2.weight', '6.5.bn2.bias', '6.5.bn2.running_mean', '6.5.bn2.running_var', '6.5.bn2.num_batches_tracked', '7.0.conv1.weight', '7.0.bn1.weight', '7.0.bn1.bias', 
'7.0.bn1.running_mean', '7.0.bn1.running_var', '7.0.bn1.num_batches_tracked', '7.0.conv2.weight', '7.0.bn2.weight', '7.0.bn2.bias', '7.0.bn2.running_mean', '7.0.bn2.running_var', '7.0.bn2.num_batches_tracked', '7.0.downsample.0.weight', '7.0.downsample.1.weight', '7.0.downsample.1.bias', '7.0.downsample.1.running_mean', '7.0.downsample.1.running_var', '7.0.downsample.1.num_batches_tracked', '7.1.conv1.weight', '7.1.bn1.weight', '7.1.bn1.bias', '7.1.bn1.running_mean', '7.1.bn1.running_var', '7.1.bn1.num_batches_tracked', '7.1.conv2.weight', '7.1.bn2.weight', '7.1.bn2.bias', '7.1.bn2.running_mean', '7.1.bn2.running_var', '7.1.bn2.num_batches_tracked', '7.2.conv1.weight', '7.2.bn1.weight', '7.2.bn1.bias', '7.2.bn1.running_mean', '7.2.bn1.running_var', '7.2.bn1.num_batches_tracked', '7.2.conv2.weight', '7.2.bn2.weight', '7.2.bn2.bias', '7.2.bn2.running_mean', '7.2.bn2.running_var', '7.2.bn2.num_batches_tracked', '10.weight', '10.bias', '10.running_mean', '10.running_var', '10.num_batches_tracked', '12.weight', '12.bias', '14.weight', '14.bias', '14.running_mean', '14.running_var', '14.num_batches_tracked', '16.weight', '16.bias'])

model.state_dict().keys()
odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer1.2.conv1.weight', 'layer1.2.bn1.weight', 'layer1.2.bn1.bias', 'layer1.2.bn1.running_mean', 'layer1.2.bn1.running_var', 'layer1.2.bn1.num_batches_tracked', 'layer1.2.conv2.weight', 'layer1.2.bn2.weight', 'layer1.2.bn2.bias', 'layer1.2.bn2.running_mean', 'layer1.2.bn2.running_var', 'layer1.2.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.conv2.weight', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', 'layer2.0.bn2.running_mean', 'layer2.0.bn2.running_var', 'layer2.0.bn2.num_batches_tracked', 'layer2.0.downsample.0.weight', 'layer2.0.downsample.1.weight', 'layer2.0.downsample.1.bias', 'layer2.0.downsample.1.running_mean', 'layer2.0.downsample.1.running_var', 'layer2.0.downsample.1.num_batches_tracked', 'layer2.1.conv1.weight', 'layer2.1.bn1.weight', 'layer2.1.bn1.bias', 'layer2.1.bn1.running_mean', 'layer2.1.bn1.running_var', 'layer2.1.bn1.num_batches_tracked', 'layer2.1.conv2.weight', 'layer2.1.bn2.weight', 'layer2.1.bn2.bias', 'layer2.1.bn2.running_mean', 'layer2.1.bn2.running_var', 'layer2.1.bn2.num_batches_tracked', 'layer2.2.conv1.weight', 
'layer2.2.bn1.weight', 'layer2.2.bn1.bias', 'layer2.2.bn1.running_mean', 'layer2.2.bn1.running_var', 'layer2.2.bn1.num_batches_tracked', 'layer2.2.conv2.weight', 'layer2.2.bn2.weight', 'layer2.2.bn2.bias', 'layer2.2.bn2.running_mean', 'layer2.2.bn2.running_var', 'layer2.2.bn2.num_batches_tracked', 'layer2.3.conv1.weight', 'layer2.3.bn1.weight', 'layer2.3.bn1.bias', 'layer2.3.bn1.running_mean', 'layer2.3.bn1.running_var', 'layer2.3.bn1.num_batches_tracked', 'layer2.3.conv2.weight', 'layer2.3.bn2.weight', 'layer2.3.bn2.bias', 'layer2.3.bn2.running_mean', 'layer2.3.bn2.running_var', 'layer2.3.bn2.num_batches_tracked', 'layer3.0.conv1.weight', 'layer3.0.bn1.weight', 'layer3.0.bn1.bias', 'layer3.0.bn1.running_mean', 'layer3.0.bn1.running_var', 'layer3.0.bn1.num_batches_tracked', 'layer3.0.conv2.weight', 'layer3.0.bn2.weight', 'layer3.0.bn2.bias', 'layer3.0.bn2.running_mean', 'layer3.0.bn2.running_var', 'layer3.0.bn2.num_batches_tracked', 'layer3.0.downsample.0.weight', 'layer3.0.downsample.1.weight', 'layer3.0.downsample.1.bias', 'layer3.0.downsample.1.running_mean', 'layer3.0.downsample.1.running_var', 'layer3.0.downsample.1.num_batches_tracked', 'layer3.1.conv1.weight', 'layer3.1.bn1.weight', 'layer3.1.bn1.bias', 'layer3.1.bn1.running_mean', 'layer3.1.bn1.running_var', 'layer3.1.bn1.num_batches_tracked', 'layer3.1.conv2.weight', 'layer3.1.bn2.weight', 'layer3.1.bn2.bias', 'layer3.1.bn2.running_mean', 'layer3.1.bn2.running_var', 'layer3.1.bn2.num_batches_tracked', 'layer3.2.conv1.weight', 'layer3.2.bn1.weight', 'layer3.2.bn1.bias', 'layer3.2.bn1.running_mean', 'layer3.2.bn1.running_var', 'layer3.2.bn1.num_batches_tracked', 'layer3.2.conv2.weight', 'layer3.2.bn2.weight', 'layer3.2.bn2.bias', 'layer3.2.bn2.running_mean', 'layer3.2.bn2.running_var', 'layer3.2.bn2.num_batches_tracked', 'layer3.3.conv1.weight', 'layer3.3.bn1.weight', 'layer3.3.bn1.bias', 'layer3.3.bn1.running_mean', 'layer3.3.bn1.running_var', 'layer3.3.bn1.num_batches_tracked', 'layer3.3.conv2.weight', 
'layer3.3.bn2.weight', 'layer3.3.bn2.bias', 'layer3.3.bn2.running_mean', 'layer3.3.bn2.running_var', 'layer3.3.bn2.num_batches_tracked', 'layer3.4.conv1.weight', 'layer3.4.bn1.weight', 'layer3.4.bn1.bias', 'layer3.4.bn1.running_mean', 'layer3.4.bn1.running_var', 'layer3.4.bn1.num_batches_tracked', 'layer3.4.conv2.weight', 'layer3.4.bn2.weight', 'layer3.4.bn2.bias', 'layer3.4.bn2.running_mean', 'layer3.4.bn2.running_var', 'layer3.4.bn2.num_batches_tracked', 'layer3.5.conv1.weight', 'layer3.5.bn1.weight', 'layer3.5.bn1.bias', 'layer3.5.bn1.running_mean', 'layer3.5.bn1.running_var', 'layer3.5.bn1.num_batches_tracked', 'layer3.5.conv2.weight', 'layer3.5.bn2.weight', 'layer3.5.bn2.bias', 'layer3.5.bn2.running_mean', 'layer3.5.bn2.running_var', 'layer3.5.bn2.num_batches_tracked', 'layer4.0.conv1.weight', 'layer4.0.bn1.weight', 'layer4.0.bn1.bias', 'layer4.0.bn1.running_mean', 'layer4.0.bn1.running_var', 'layer4.0.bn1.num_batches_tracked', 'layer4.0.conv2.weight', 'layer4.0.bn2.weight', 'layer4.0.bn2.bias', 'layer4.0.bn2.running_mean', 'layer4.0.bn2.running_var', 'layer4.0.bn2.num_batches_tracked', 'layer4.0.downsample.0.weight', 'layer4.0.downsample.1.weight', 'layer4.0.downsample.1.bias', 'layer4.0.downsample.1.running_mean', 'layer4.0.downsample.1.running_var', 'layer4.0.downsample.1.num_batches_tracked', 'layer4.1.conv1.weight', 'layer4.1.bn1.weight', 'layer4.1.bn1.bias', 'layer4.1.bn1.running_mean', 'layer4.1.bn1.running_var', 'layer4.1.bn1.num_batches_tracked', 'layer4.1.conv2.weight', 'layer4.1.bn2.weight', 'layer4.1.bn2.bias', 'layer4.1.bn2.running_mean', 'layer4.1.bn2.running_var', 'layer4.1.bn2.num_batches_tracked', 'layer4.2.conv1.weight', 'layer4.2.bn1.weight', 'layer4.2.bn1.bias', 'layer4.2.bn1.running_mean', 'layer4.2.bn1.running_var', 'layer4.2.bn1.num_batches_tracked', 'layer4.2.conv2.weight', 'layer4.2.bn2.weight', 'layer4.2.bn2.bias', 'layer4.2.bn2.running_mean', 'layer4.2.bn2.running_var', 'layer4.2.bn2.num_batches_tracked', 'fc.weight', 'fc.bias'])

I can see that the two sets of keys are not the same. Can you please tell me what I can do?
I set precompute to False before fitting the model, but not right before saving. Does that make any difference? In the meantime, I will try that and see.

Thank you for your time
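
Reading the two listings side by side, the saved keys look like positions in a sequential container, while the torchvision model uses attribute names. The block counts line up (3, 4, 6 and 3 blocks under prefixes 4 to 7 versus layer1 to layer4), which suggests the mapping sketched below. This is an inference from the listings, not something confirmed in the thread, and `split_and_rename` is a hypothetical helper. Prefixes 10, 12, 14 and 16 appear to belong to a custom classification head that has no counterpart in torchvision's `resnet34`, so they are returned separately rather than renamed.

```python
# Assumed mapping from sequential indices to torchvision ResNet attribute
# names, inferred from the two key listings. Indices 2 and 3 would be the
# parameter-free ReLU and max-pool layers, so they never show up as keys.
BACKBONE_PREFIXES = {
    "0": "conv1", "1": "bn1",
    "4": "layer1", "5": "layer2", "6": "layer3", "7": "layer4",
}


def split_and_rename(state):
    """Return (backbone, head): backbone keys renamed, head keys untouched."""
    backbone, head = {}, {}
    for key, value in state.items():
        prefix, rest = key.split(".", 1)
        if prefix in BACKBONE_PREFIXES:
            backbone[BACKBONE_PREFIXES[prefix] + "." + rest] = value
        else:
            head[key] = value  # e.g. '10.*', '12.*', '14.*', '16.*'
    return backbone, head


demo = {"0.weight": "w0", "4.0.conv1.weight": "w1", "16.bias": "b"}
backbone, head = split_and_rename(demo)
print(sorted(backbone))  # ['conv1.weight', 'layer1.0.conv1.weight']
print(sorted(head))      # ['16.bias']
```

Loading only the renamed backbone would not reproduce the trained classifier, because the head weights would still have nowhere to go in a plain `resnet34`.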

Hello @j.laute,
The error above goes away when I change the code as follows:

model.load_state_dict(state_dict=state, strict=False)

I learned from the PyTorch documentation that strict defaults to True, in which case all keys are expected to match. Setting it to False made the error go away.
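
A caution about this setting: with `strict=False`, keys that do not match are skipped silently, so the absence of an error does not mean the weights were loaded. Given the two key listings earlier in the thread, none of the saved names match the model's names. The sketch below reports what a non-strict load actually did. It relies on `load_state_dict` returning an object with `missing_keys` and `unexpected_keys`, which recent PyTorch releases do (older releases may return nothing), and `load_report` is a hypothetical helper.

```python
def load_report(model, state):
    """Load a state dict non-strictly and report what was skipped.

    `model` is any object whose `load_state_dict(state, strict=False)`
    returns a result with `missing_keys` and `unexpected_keys`.
    """
    result = model.load_state_dict(state, strict=False)
    unexpected = set(result.unexpected_keys)
    return {
        "loaded": [key for key in state if key not in unexpected],
        "missing": list(result.missing_keys),
        "unexpected": list(result.unexpected_keys),
    }
```

Called as `load_report(model, state)` on the model and state from this thread, an empty "loaded" list would mean the model still holds its initial weights.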

Now I have another issue. After passing the image tensor to the model, I get the following error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got input of size [3, 640, 640] instead

Now I wonder whether this is the correct way to pass an image to a model to predict its class.

I found a solution for this at https://discuss.pytorch.org/t/how-to-classify-single-image-using-loaded-net/1411/2
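
For readers who do not follow the link: the weight shape [64, 3, 7, 7] in the error belongs to the first convolution (64 filters, 3 input channels, 7x7 kernel), which expects input shaped [batch, channels, height, width]. A single image of size [3, 640, 640] lacks the batch axis. A minimal sketch of the usual fix follows; the helper names are hypothetical, and the code only assumes the tensor offers `dim()` and `unsqueeze()`, as PyTorch tensors do.

```python
def ensure_batch_dim(img_tensor):
    """Turn a single (C, H, W) image into a (1, C, H, W) batch of one."""
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)  # prepend the batch axis
    return img_tensor


def predict_single(model, img_tensor):
    """Run the model on one image and return its output for that batch of one."""
    return model(ensure_batch_dim(img_tensor))
```

For inference it is also common to wrap the call in `torch.no_grad()` so that no gradients are tracked.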

I used the fastai library to train a ResNet34 with four outputs in the last layer and saved the model. Next, I tried to load the model using PyTorch (CPU version, on another computer).

model = resnet34()
state = load(loc, device('cpu'))
model.load_state_dict(state_dict=state, strict=False)
model.eval()

When I call model(test_image) (after all required transformations), I get a tensor of size torch.Size([1, 1000]) as predictions. I expected a tensor of size four.

After loading the state dict, the last line printed by model.eval() is:

(fc): Linear(in_features=512, out_features=1000, bias=True)

This is not what I expected: out_features=1000. The model I trained has only four outputs. Does that mean the script did not load my model?
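
One way to check this from the saved file itself: `resnet34()` called without arguments builds the default 1000-class `fc` layer, and `strict=False` skips every key whose name does not match, so the saved head may never have been copied in. A linear layer's weight has shape (out_features, in_features), so the first dimension of the last linear weight in the file gives the number of trained outputs. The sketch below assumes that `16.weight` is that final linear layer, which is how the key listing earlier in the thread reads (it is the last entry with only a weight and a bias). `head_output_size` is a hypothetical helper.

```python
def head_output_size(state, last_linear="16.weight"):
    """Number of outputs of the saved head's final linear layer.

    A linear layer's weight has shape (out_features, in_features), so the
    first dimension is the number of predicted classes.
    """
    return tuple(state[last_linear].shape)[0]
```

If this returns 4 for the saved file, the trained head is in the file but was not loaded into the torchvision model. Getting the trained predictions would then require building a model with the same layer structure as the one that was saved before calling `load_state_dict`.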

Did you manage to find a solution?