I have a use case to normalize an incoming image with pytorch, load a bestmodel.pth trained with fastai and then do a prediction. i got as far as being able to load a standard pretrained model, but get the following error when trying to load a bestmodel.pth.(both have the same exact arch)
Thereafter, how might one predict? using learn.predict(img) gives an error as well.
:
import pretrainedmodels
import io
from PIL import Image
import requests
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
# Random cat img taken from Google
IMG_URL = 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg'
# Class labels used when training VGG as json, courtesy of the 'Example code' link above.
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
# Let's get our class labels.
response = requests.get(LABELS_URL) # Make an HTTP GET request and store the response.
labels = {int(key): value for key, value in response.json().items()}
# Let's get the cat img.
response = requests.get(IMG_URL)
img = Image.open(io.BytesIO(response.content)) # Read bytes and store as an img.
# Let's take a look at this cat!
img.show()
# Now that we have an img, we need to preprocess it.
# We need to:
# * resize the img, it is pretty big (~1200x1200px).
# * normalize it, as noted in the PyTorch pretrained models doc,
# with, mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
# * convert it to a PyTorch Tensor.
#
# We can do all this preprocessing using a transform pipeline.
min_img_size = 224 # The min size, as noted in the PyTorch pretrained models doc, is 224 px.
transform_pipeline = transforms.Compose([transforms.Resize(min_img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
img = transform_pipeline(img)
# PyTorch pretrained models expect the Tensor dims to be (num input imgs, num color channels, height, width).
# Currently however, we have (num color channels, height, width); let's fix this by inserting a new axis.
img = img.unsqueeze(-1) # Insert the new axis at index 0 i.e. in front of the other axes/dims.
# Now that we have preprocessed our img, we need to convert it into a
# Variable; PyTorch models expect inputs to be Variables. A PyTorch Variable is a
# wrapper around a PyTorch Tensor.
img = Variable(img)
#try Fastai
#learn = load_learner('./dir')
#learn.predict(img)
# Now let's load our model and get a prediciton!
#model = pretrainedmodels.se_resnext101_32x4d(pretrained=True) # This may take a few minutes.
model_name = 'se_resnext101_32x4d'
model = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
#load fastai model
state = torch.load('./dir/bestmodel.pth')
model.load_state_dict(state['model'])
Got this following runtime error:
RuntimeError Traceback (most recent call last)
in ()
79 #load fastai model
80 state = torch.load('./models.catsplitv2seresnext101-1/bestmodel.pth')
---> 81 model.load_state_dict(state['model'])
82
83 #prediction = model(img) # Returns a Tensor of shape (batch, num class labels)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
775 if len(error_msgs) > 0:
776 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 777 self.__class__.__name__, "\n\t".join(error_msgs)))
778 return _IncompatibleKeys(missing_keys, unexpected_keys)
779
RuntimeError: Error(s) in loading state_dict for SENet:
Missing key(s) in state_dict: "layer0.conv1.weight", "layer0.bn1.weight", "layer0.bn1.bias", "layer0.bn1.running_mean", "layer0.bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.se_module.fc1.weight", "layer1.0.se_module.fc1.bias", "layer1.0.se_module.fc2.weight", "layer1.0.se_module.fc2.bias", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.se_module.fc1.weight", "layer1.1.se_module.fc1.bias", "layer1.1.se_module.fc2.weight", "layer1.1.se_module.fc2.bias", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weig...
Unexpected key(s) in state_dict: "0.0.conv1.weight", "0.0.bn1.weight", "0.0.bn1.bias", "0.0.bn1.running_mean", "0.0.bn1.running_var", "0.0.bn1.num_batches_tracked", "0.1.0.conv1.weight", "0.1.0.bn1.weight", "0.1.0.bn1.bias", "0.1.0.bn1.running_mean", "0.1.0.bn1.running_var", "0.1.0.bn1.num_batches_tracked", "0.1.0.conv2.weight", "0.1.0.bn2.weight", "0.1.0.bn2.bias", "0.1.0.bn2.running_mean", "0.1.0.bn2.running_var", "0.1.0.bn2.num_batches_tracked", "0.1.0.conv3.weight", "0.1.0.bn3.weight", "0.1.0.bn3.bias", "0.1.0.bn3.running_mean", "0.1.0.bn3.running_var", "0.1.0.bn3.num_batches_tracked", "0.1.0.se_module.fc1.weight", "0.1.0.se_module.fc1.bias", "0.1.0.se_module.fc2.weight", "0.1.0.se_module.fc2.bias", "0.1.0.downsample.0.weight", "0.1.0.downsample.1.weight", "0.1.0.downsample.1.bias", "0.1.0.downsample.1.running_mean", "0.1.0.downsample.1.running_var", "0.1.0.downsample.1.num_batches_tracked", "0.1.1.conv1.weight", "0.1.1.bn1.weight", "0.1.1.bn1.bias", "0.1.1.bn1.running_mean", "0.1.1.bn1.running_var", "0.1.1.bn1.num_batches_tracked", "0.1.1.conv2.weight", "0.1.1.bn2.weight", "0.1.1.bn2.bias", "0.1.1.bn2.running_mean", "0.1.1.bn2.running_var", "0.1.1.bn2.num_batches_tracked", "0.1.1.conv3.weight", "0.1.1.bn3.weight", "0.1.1.bn3.bias", "0.1.1.bn3.running_mean", "0.1.1.bn3.running_var", "0.1.1.bn3.num_batches_tracked", "0.1.1.se_module.fc1.weight", "0.1.1.se_module.fc1.bias", "0.1.1.se_module.fc2.weight", "0.1.1.se_module.fc2.bias", "0.1.2.conv1.weight", "0.1.2.bn1.weight...