Lesson 5 Advanced Discussion ✅

Here is my code for importing resnet50 from Cadene's PyTorch pretrained models repo. I chose resnet50 so that I could check for mistakes by comparing the results against the standard fastai resnet50. By simply changing the model's name, the code below gives you access to about 45 models, most of which are not yet available in fastai.

Install pretrainedmodels with:

pip install pretrainedmodels

Here is the lesson 1 example, using a model imported from pretrainedmodels:

%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.vision import *
import pretrainedmodels

path = untar_data(URLs.PETS); path
path_anno = path/'annotations'
path_img = path/'images'
fnames = get_image_files(path_img)
np.random.seed(2)
pat = re.compile(r'/([^/]+)_\d+\.jpg$')  # escape the dot so it only matches a literal '.'

bs = 64

data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(),
                                   size=299, bs=bs//2).normalize(imagenet_stats)

def get_model(pretrained=True, model_name='resnet50', **kwargs):
    # Look up the constructor by name; load ImageNet weights only if requested
    weights = 'imagenet' if pretrained else None
    return pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained=weights)

custom_head = create_head(nf=2048*2, nc=37, ps=0.5, bn_final=False) 

# Change model_name below to use any other model available in `pretrainedmodels`;
# the full list is in pretrainedmodels.model_names.
# The [:-2] cut drops the last two child modules (for resnet50, the final pooling and classifier layers).
fastai_resnet50 = nn.Sequential(*list(children(get_model(model_name='resnet50'))[:-2]), custom_head)

def get_fastai_model(pretrained=True, **kwargs ): 
    return fastai_resnet50

learn = create_cnn(data, get_fastai_model, metrics=error_rate)
learn.fit_one_cycle(5) 
learn.unfreeze()
learn.fit_one_cycle(1, max_lr=slice(1e-6,1e-4)) 
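
For anyone unsure what the `pat` regex in the code above does: it pulls the class label out of each image filename. Here is a minimal standard-library sketch of that step. The helper name `label_from_path` and the example paths are made up for illustration, and the pattern is written with the dot before `jpg` escaped.

```python
import re

# Same idea as `pat` above: capture everything between the last '/' and
# the trailing '_<number>.jpg'.
pat = re.compile(r'/([^/]+)_\d+\.jpg$')

def label_from_path(path):
    """Hypothetical helper: return the class label encoded in a filename."""
    match = pat.search(str(path))
    if match is None:
        raise ValueError(f"no label found in {path!r}")
    return match.group(1)

# Made-up example paths in the same naming style as the pets dataset
print(label_from_path('data/images/great_pyrenees_173.jpg'))
print(label_from_path('data/images/Abyssinian_1.jpg'))
```

Labels that themselves contain underscores still work, because only the final `_<number>.jpg` part is stripped.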

Update-1:
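It is probably safer to create the custom_head as follows:

body = nn.Sequential(*list(children(get_model(model_name='resnet50'))[:-2]))
custom_head = create_head(nf=num_features_model(body)*2, nc=data.c, ps=0.5, bn_final=False)

This covers the case where the body of the model you use does not output 2048 features, and it takes the number of classes from the data instead of hard-coding 37.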
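
In case the `*2` looks odd: as far as I understand it, the fastai head starts with a concat pooling layer that joins a max pool and an average pool over each channel, so the head sees twice as many features as the body outputs. Below is a plain-Python sketch of that idea on a toy feature map. The function `concat_pool` is my own illustration, not the fastai implementation.

```python
def concat_pool(feature_map):
    """Toy concat pooling.

    feature_map is a list of channels, each a 2-D list (H x W) of floats.
    Returns the per-channel maxima followed by the per-channel means.
    """
    maxes = [max(max(row) for row in channel) for channel in feature_map]
    means = [
        sum(sum(row) for row in channel) / sum(len(row) for row in channel)
        for channel in feature_map
    ]
    return maxes + means

# Three channels of a 2x2 feature map (made-up numbers)
fmap = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
    [[5.0, 5.0], [5.0, 5.0]],
]
pooled = concat_pool(fmap)
print(len(fmap), '->', len(pooled))  # channels in -> features out
```

With a 2048-channel body this gives 4096 features, which is where the hard-coded `2048*2` came from.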

Update-2:
Jeremy tweeted an even better implementation for importing architectures from Cadene. What my code above is missing is the split points for discriminative learning rates.
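
To show why the split points matter: a `max_lr=slice(1e-6,1e-4)` only helps if the model is divided into several layer groups, since each group gets its own rate. My understanding is that fastai spaces the rates geometrically between the two ends of the slice. The sketch below is a rough stand-in for that behaviour, assuming geometric spacing; `spread_lrs` is a hypothetical name, not a fastai function.

```python
def spread_lrs(start, stop, n_groups):
    """Geometrically space learning rates from start to stop, one per layer group."""
    if n_groups <= 1:
        # A single group cannot have discriminative rates; it just gets `stop`
        return [stop]
    ratio = (stop / start) ** (1 / (n_groups - 1))
    return [start * ratio ** i for i in range(n_groups)]

# Three layer groups: earliest layers get the smallest rate, the head the largest
lrs = spread_lrs(1e-6, 1e-4, 3)
print(lrs)

# With no split points there is only one group, so the slice has nothing to spread over
print(spread_lrs(1e-6, 1e-4, 1))
```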
