Here is my code for importing resnet50 from Cadene's PyTorch pretrained models repo (`pretrainedmodels`). I used resnet50 so I could check for mistakes by comparing the result with fastai's standard resnet50. By simply changing the model's name, the code below gives you access to ~45 models, most of which are still not available in fastai.
Install `pretrainedmodels` with:

```
pip install pretrainedmodels
```
Here is our lesson-1 example, using a model imported from `pretrainedmodels`:
```python
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.vision import *
import pretrainedmodels

# Oxford-IIIT Pets, as in lesson 1
path = untar_data(URLs.PETS); path
path_anno = path/'annotations'
path_img = path/'images'
fnames = get_image_files(path_img)
np.random.seed(2)
# Label = everything between the last '/' and the trailing _<number>.jpg
pat = re.compile(r'/([^/]+)_\d+\.jpg$')
bs = 64
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
                                   size=299, bs=bs//2).normalize(imagenet_stats)

def get_model(pretrained=True, model_name='resnet50', **kwargs):
    # Each architecture is a constructor in the pretrainedmodels namespace;
    # the ImageNet weights expect num_classes=1000
    return pretrainedmodels.__dict__[model_name](
        num_classes=1000, pretrained='imagenet' if pretrained else None)

# nf = 2048*2: resnet50's body outputs 2048 channels and the head's concat
# pooling (max + average) doubles that; nc = 37 pet breeds
custom_head = create_head(nf=2048*2, nc=37, ps=0.5, bn_final=False)

# You can swap in any other model available in `pretrainedmodels`;
# list them with: pretrainedmodels.model_names
# For resnet50, [:-2] drops the model's own pooling layer and final linear layer
fastai_resnet50 = nn.Sequential(*list(children(get_model(model_name='resnet50'))[:-2]), custom_head)

def get_fastai_model(pretrained=True, **kwargs):
    # create_cnn passes `pretrained`; the model above already has ImageNet weights
    return fastai_resnet50

learn = create_cnn(data, get_fastai_model, metrics=error_rate)
learn.fit_one_cycle(5)

learn.unfreeze()
learn.fit_one_cycle(1, max_lr=slice(1e-6,1e-4))
```
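The last call passes `max_lr=slice(1e-6, 1e-4)`, which fastai v1 spreads across the learner's layer groups so that earlier layers train with smaller rates. Below is a hypothetical pure-Python sketch of that spreading, assuming the log-spaced scheme of fastai v1's `even_mults` (and `stop/10` for every group but the last when the slice has no start). `spread_lrs` is my own name, not a fastai function.

```python
import math


def spread_lrs(lr, n_groups):
    """Hypothetical sketch: one learning rate per layer group, assuming fastai v1's scheme.

    slice(start, stop) -> n_groups rates log-spaced from start to stop
    slice(stop)        -> stop/10 for every group except the last
    plain float        -> the same rate for every group
    """
    if not isinstance(lr, slice):
        return [lr] * n_groups
    if lr.start is None:
        return [lr.stop / 10] * (n_groups - 1) + [lr.stop]
    if n_groups == 1:
        return [lr.stop]  # this sketch's own choice: a single group gets the top rate
    step = (lr.stop / lr.start) ** (1 / (n_groups - 1))  # constant ratio between groups
    return [lr.start * step ** i for i in range(n_groups)]


for n in (2, 3):
    print(n, ["%.0e" % r for r in spread_lrs(slice(1e-6, 1e-4), n)])
# 2 ['1e-06', '1e-04']
# 3 ['1e-06', '1e-05', '1e-04']
```

With only two layer groups (for example body and head) the slice yields just its two end points; each extra split point adds an intermediate rate.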
Update-1:
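It may be safer to build `custom_head` this way. Reading the number of features from the model's body (the layers kept by `[:-2]`) avoids hard-coding 2048, in case the body of the model you use outputs a different number. Taking the number of classes from `data.c` avoids hard-coding 37. (`num_features_model` is defined in `fastai.callbacks.hooks`.)

```python
body = nn.Sequential(*list(children(get_model(model_name='resnet50'))[:-2]))
custom_head = create_head(nf=num_features_model(body)*2, nc=data.c, ps=0.5, bn_final=False)
```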
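Why the `*2`: fastai v1's `create_head` starts with a concat pooling layer that concatenates adaptive max pooling and adaptive average pooling of the body's output, so the head receives twice as many features as the body has channels. Here is a toy pure-Python sketch with nested lists instead of tensors. `adaptive_concat_pool` and `head_input_features` are hypothetical names used only for illustration.

```python
def adaptive_concat_pool(feature_map):
    """Toy concat pooling over a C x H x W feature map given as nested lists.

    Returns C max-pooled values followed by C average-pooled values, i.e. a
    vector of length 2*C (max first, as I believe fastai v1's
    AdaptiveConcatPool2d orders them).
    """
    maxes = [max(v for row in channel for v in row) for channel in feature_map]
    means = [sum(v for row in channel for v in row) / sum(len(row) for row in channel)
             for channel in feature_map]
    return maxes + means


def head_input_features(body_channels):
    # Hypothetical helper: the `nf` that create_head needs after concat pooling
    return 2 * body_channels


# Two channels of a 2x2 feature map -> four pooled features
fm = [[[1.0, 2.0], [3.0, 4.0]],
      [[0.0, 0.0], [0.0, 8.0]]]
print(adaptive_concat_pool(fm))   # [4.0, 8.0, 2.5, 2.0]
print(head_input_features(2048))  # resnet50 body -> 4096
print(head_input_features(512))   # resnet34 body -> 1024
```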
Update-2:
Jeremy tweeted an even better implementation for importing architectures from Cadene's repo. What my code above is missing are the split points that define the layer groups for discriminative learning rates.
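To make "split points" concrete: a split index marks where a new layer group starts, and each group then gets its own learning rate. The sketch below uses plain lists of layer names and a hypothetical helper `split_into_groups` (same idea as fastai v1's `split_model_idx`, not its code). The choice to split before `layer3` and before the head is an assumption modelled on how fastai v1 splits its own resnets.

```python
def split_into_groups(layers, split_idxs):
    """Cut a flat, ordered list of layers into consecutive layer groups.

    Each index in split_idxs starts a new group (hypothetical helper).
    """
    idxs = [0] + sorted(split_idxs) + [len(layers)]
    return [layers[start:end] for start, end in zip(idxs[:-1], idxs[1:])]


# Top-level children kept from pretrainedmodels' resnet50 by [:-2], plus the head
body = ["conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4"]
model = body + ["head"]

# Split before layer3 (index 6) and before the head (index 8): three groups
groups = split_into_groups(model, [6, 8])
for i, group in enumerate(groups):
    print(i, group)
# 0 ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2']
# 1 ['layer3', 'layer4']
# 2 ['head']
```

In fastai v1 itself, I believe the equivalent for this model would be something like `learn.split(lambda m: (m[0][6], m[1]))`, where `m[0]` is the body and `m[1]` the head that `create_cnn` builds. Check it against your fastai version before relying on it.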