It looks like create_body may not work for all PyTorch models as-is, because it expects the model to have certain characteristics, such as a model-creation function that takes pretrained as its first argument.
Let’s take the densenet model from the pretrainedmodels package as an example and see how to use create_body on it.
First, let’s take a look at the create_body code.
```python
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    "Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function)."
    model = arch(pretrained)
    cut = ifnone(cut, cnn_config(arch)['cut'])
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif isinstance(cut, Callable): return cut(model)
    else: raise NamedError("cut must be either integer or a function")
```
The function calls arch(pretrained): it passes pretrained as the first positional argument to arch, which is where we would typically pass our model-creation function.
Let’s also look at how we can create a densenet model with pretrainedmodels:
```python
import pretrainedmodels

densenet = pretrainedmodels.densenet121(num_classes=1000, pretrained='imagenet')
```
If you pass pretrainedmodels.densenet121 directly to create_body, it fails: create_body passes the pretrained value as the first positional argument, which densenet121 interprets as num_classes.
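To see the mismatch without downloading any weights, here is a small torch-free sketch. The stub `densenet121_stub` is hypothetical; it only copies the parameter order of `pretrainedmodels.densenet121(num_classes=1000, pretrained='imagenet')` shown above. `inspect.signature(...).bind` reports how a positional call such as `arch(True)` would be matched to parameters.

```python
import inspect

# Hypothetical stub with the same parameter order as
# pretrainedmodels.densenet121(num_classes=1000, pretrained='imagenet').
def densenet121_stub(num_classes=1000, pretrained='imagenet'):
    return {"num_classes": num_classes, "pretrained": pretrained}

def bind_like_create_body(arch, pretrained=True):
    # create_body calls arch(pretrained), so pretrained is passed positionally
    # and is bound to whichever parameter comes first.
    return dict(inspect.signature(arch).bind(pretrained).arguments)

print(bind_like_create_body(densenet121_stub))  # {'num_classes': True}
```

So the factory would receive num_classes=True while pretrained silently keeps its default, which is not what create_body intends.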
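Let’s create a custom function that accepts arguments in a way that satisfies both create_body and densenet:
```python
import pretrainedmodels

def densenet121_arch(pretrained=True):  # wrapper name is illustrative
    # create_body passes a bool; translate it into the value pretrainedmodels expects
    return pretrainedmodels.densenet121(pretrained='imagenet') if pretrained else pretrainedmodels.densenet121(pretrained=None)
```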
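The wrapper idea can be checked with a stub in place of the real factory. In this sketch, `densenet121_stub` and `densenet121_body_arch` are hypothetical names; the wrapper translates create_body's boolean `pretrained` into the keyword values the factory expects. It also suggests why a plain `functools.partial` would not be enough: the positional argument from create_body would still collide with `num_classes`.

```python
from functools import partial

# Hypothetical stand-in for pretrainedmodels.densenet121; it just records its arguments.
def densenet121_stub(num_classes=1000, pretrained='imagenet'):
    return {"num_classes": num_classes, "pretrained": pretrained}

# Wrapper with the signature create_body expects: a bool as the first argument.
def densenet121_body_arch(pretrained=True):
    if pretrained:
        return densenet121_stub(pretrained='imagenet')
    return densenet121_stub(pretrained=None)

print(densenet121_body_arch(True))   # {'num_classes': 1000, 'pretrained': 'imagenet'}
print(densenet121_body_arch(False))  # {'num_classes': 1000, 'pretrained': None}

# partial only pre-fills keywords; the positional True still targets num_classes.
fixed = partial(densenet121_stub, num_classes=1000)
try:
    fixed(True)
except TypeError as err:
    print("partial fails:", err)
```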
Let’s try passing our custom function to create_body. It runs successfully, but returns an empty Sequential object.
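It returns an empty object because fastai’s cnn_config knows nothing about this densenet architecture, so cut stays None and create_body falls back to cutting at the last top-level child that contains a pooling layer. For this densenet that is the very first child, the features block, so cut becomes 0 and nothing is kept. Taking a closer look at the densenet model, we realize that we want create_body to cut the model just before the last layer to create the model body. So passing -1 as cut should give us the densenet body.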
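The cut logic from create_body can be traced without PyTorch. This sketch models the top-level children as plain objects; `Layer`, `has_pool` and `body_children` are hypothetical names standing in for `model.children()` and fastai's `has_pool_type`. It assumes the densenet's top level is a `features` block containing pooling layers followed by a `last_linear` classifier, and shows the fallback picking index 0 (an empty body) while `cut=-1` keeps everything except the last layer.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Layer:
    name: str
    is_pool: bool = False
    children: List["Layer"] = field(default_factory=list)

def has_pool(layer: Layer) -> bool:
    # Same idea as fastai's has_pool_type: true if the layer or any descendant pools.
    return layer.is_pool or any(has_pool(c) for c in layer.children)

def body_children(children: List[Layer], cut: Optional[int] = None) -> List[str]:
    if cut is None:
        # Fallback when cnn_config has no entry: last child that contains a pooling layer.
        cut = next(i for i, c in reversed(list(enumerate(children))) if has_pool(c))
    return [c.name for c in children[:cut]]

# Assumed top-level layout: a features block with pooling inside, then the classifier.
densenet_children = [
    Layer("features", children=[Layer("conv0"), Layer("pool0", is_pool=True)]),
    Layer("last_linear"),
]

print(body_children(densenet_children))          # []
print(body_children(densenet_children, cut=-1))  # ['features']
```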
For other architectures, the cut value may not be as simple as -1; in such cases, you can pass a function that knows how to split the model.
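When a single index is not enough, create_body simply returns `cut(model)` for a callable `cut`. A minimal torch-free sketch of that dispatch follows; `StubDenseNet`, `take_features` and `apply_cut` are hypothetical names, and in real code the cut function would return a PyTorch module (for example the model's `features` block) rather than a string.

```python
from typing import Callable, Union

class StubDenseNet:
    # Hypothetical stand-in exposing a body block and a classifier.
    def __init__(self):
        self.features = "features-block"
        self.last_linear = "classifier"

def take_features(model):
    # A cut function decides explicitly which part of the model forms the body.
    return model.features

def apply_cut(model, cut: Union[int, Callable]):
    # Same dispatch as create_body: int -> slice the children, callable -> cut(model).
    if isinstance(cut, int):
        children = [model.features, model.last_linear]
        return children[:cut]
    elif callable(cut):
        return cut(model)
    raise TypeError("cut must be either integer or a function")

print(apply_cut(StubDenseNet(), take_features))  # features-block
print(apply_cut(StubDenseNet(), -1))             # ['features-block']
```

A function gives full control over the split, which helps when the body is not simply a prefix of the top-level children.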