It looks like create_body may not work with every PyTorch model as is, because it expects the model to have certain characteristics, such as the model creation function taking pretrained as its first argument.
Let’s take densenet from the pretrainedmodels package as an example and see how to use create_body on it.
Let’s first take a look at the create_body code.
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    "Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function)."
    model = arch(pretrained)
    cut = ifnone(cut, cnn_config(arch)['cut'])
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif isinstance(cut, Callable): return cut(model)
    else: raise NamedError("cut must be either integer or a function")
The function calls arch(pretrained), so whatever we pass as arch (typically our model constructor) receives pretrained as its first positional argument.
Let’s also look at how we can create a densenet model with pretrainedmodels.
densenet = pretrainedmodels.densenet121(num_classes=1000, pretrained='imagenet')
If you pass pretrainedmodels.densenet121 directly to create_body, it fails: create_body passes the pretrained value as the first positional argument, which densenet121 interprets as num_classes.
Let’s create a custom function which accepts arguments in a way that satisfies both create_body and densenet.
def densenet(pretrained):
    return pretrainedmodels.densenet121(pretrained='imagenet') if pretrained else pretrainedmodels.densenet121(pretrained=None)
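To see why the wrapper is needed without downloading any weights, here is a minimal sketch using only the standard library. `fake_densenet121` is a hypothetical stand-in that merely mirrors the two keyword arguments used in the call above; it is not the real `pretrainedmodels.densenet121`, and `wrapped` is an illustrative name for the same idea as the `densenet` function.

```python
import inspect

# Hypothetical stand-in: same keyword arguments as the densenet121 call above,
# but it only records what it was given instead of building a model.
def fake_densenet121(num_classes=1000, pretrained='imagenet'):
    return {'num_classes': num_classes, 'pretrained': pretrained}

# create_body calls arch(pretrained), i.e. with a single positional argument.
# Binding that call to the signature shows which parameter receives the value.
bound = inspect.signature(fake_densenet121).bind(True)
print(dict(bound.arguments))  # {'num_classes': True}

# A wrapper whose first parameter is `pretrained` routes the flag correctly.
def wrapped(pretrained):
    return fake_densenet121(pretrained='imagenet' if pretrained else None)

print(wrapped(True))   # {'num_classes': 1000, 'pretrained': 'imagenet'}
print(wrapped(False))  # {'num_classes': 1000, 'pretrained': None}
```

The wrapper translates the boolean flag that create_body supplies into the string-or-None value that the model constructor expects.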
Let’s try passing our custom function to create_body. It runs successfully but returns an empty Sequential object.
create_body(densenet)
It returns an empty object because fastai’s cnn_config method does not know anything about the densenet architecture, so create_body falls back to guessing the cut point from the pooling layers, and here that guess evidently lands on index 0. Taking a close look at the densenet model, we realize that we want create_body to cut it just before the last layer to create the model body. So passing -1 as cut should give us the body of densenet.
create_body(densenet,cut=-1)
For other architectures the cut value may not be as simple as -1; in such cases you can pass a function that knows how to split the model.
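As a rough illustration of passing a function as cut, the sketch below uses a small toy model instead of a real pretrained network. `ToyNet` and `cut_body` are hypothetical names: `cut_body` only mirrors the final int-or-callable dispatch of create_body shown earlier and is not part of fastai.

```python
import torch
from torch import nn

class ToyNet(nn.Module):
    # Hypothetical model: a feature extractor followed by a classifier head.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.last_linear = nn.Linear(8, 10)

    def forward(self, x):
        return self.last_linear(self.features(x).flatten(1))

def cut_body(model, cut):
    # Mirrors only the last three lines of create_body: int or callable.
    if isinstance(cut, int):
        return nn.Sequential(*list(model.children())[:cut])
    if callable(cut):
        return cut(model)
    raise TypeError("cut must be either integer or a function")

model = ToyNet()
body_int = cut_body(model, -1)                   # drop the last child module
body_fn = cut_body(model, lambda m: m.features)  # select the body explicitly

x = torch.randn(2, 3, 16, 16)
print(body_int(x).shape)  # torch.Size([2, 8, 1, 1])
print(body_fn(x).shape)   # torch.Size([2, 8, 1, 1])
```

With the real create_body, the same idea would presumably look like create_body(densenet, cut=lambda m: m.features), assuming the model exposes its body as a features attribute.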