The problem
I ran into this problem as well when using `unet_learner`. After digging through the source code, it appears that the cut is applied inconsistently.
The traceback shows that the error occurs near the end of `unet_learner`, at this line:

```python
learn.split(ifnone(split_on, meta['split']))
```
Tracing this back, you'll find that the split index for resnets is hard-coded to 6:

```python
def _resnet_split(m:nn.Module): return (m[0][6],m[1])
```
However, the model passed to it has already been cut by `create_body`, as mentioned in a previous comment. As a result, any cut that leaves the body with fewer than seven layers (for example, any positive `cut` below 7) will always fail for resnet models, because `m[0][6]` no longer exists. Looking at the other model types, I expect they have the same problem with cutting.
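A minimal pure-Python sketch of the failure, assuming `create_body` keeps the first `cut` top-level children (`list(model.children())[:cut]`). The model is replaced by a list of the child names of torchvision's resnet18, and the hypothetical helpers `cut_body` and `hardcoded_split` stand in for `create_body` and `_resnet_split`; no fastai or torch code is involved.

```python
# Top-level children of torchvision's resnet18, in order (names only, standing in for real modules).
RESNET18_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4",
                     "avgpool", "fc"]

def cut_body(children, cut):
    """Hypothetical stand-in for create_body: keep the first `cut` children (negative cut drops from the end)."""
    return children[:cut]

def hardcoded_split(body):
    """Hypothetical stand-in for _resnet_split: always indexes the body at 6, like m[0][6]."""
    return body[6]

def try_split(cut):
    """Return the layer the split lands on, or None if the cut body is too short."""
    body = cut_body(RESNET18_CHILDREN, cut)
    try:
        return hardcoded_split(body)
    except IndexError:
        return None  # fewer than 7 layers left, so index 6 does not exist

for cut in (-2, 8, 7, 6, 5):
    print(cut, try_split(cut))
```

The default cut of -2 and a cut of 7 both leave `layer3` at index 6, while cuts of 6 or less leave nothing there to split on.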
A solution
The model split parameters are retrieved via the `cnn_config` function, which in turn just looks up values in a dictionary whose keys are the model architectures and whose values are the split parameters. This makes it awkward to pass in the cut parameter. To get around this, I generated the split parameters explicitly within the `cnn_config` function and added a kwarg for the cut parameter. Here's my new `cnn_config` function, which currently only handles resnet18 correctly:
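```python
# Replaces cnn_config in the fastai source; relies on the torch, partial, models,
# model_meta and _default_meta names already available there
def cnn_config(arch, **kwargs):
    "Get the metadata associated with `arch`."
    torch.backends.cudnn.benchmark = True
    cut = kwargs.get('cut')
    # Only resnet18 gets a split function bound to the requested cut for now
    if cut is not None and arch == models.resnet18:
        return {'cut': cut,
                'split': partial(_resnet_split, cut=cut)}
    else:
        # No cut given (or another arch): fall back to the stock lookup table
        return model_meta.get(arch, _default_meta)
```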
In addition, I had to add `**kwargs` to the `_resnet_split` function:
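```python
# Split a resnet style model
def _resnet_split(m:nn.Module, **kwargs):
    cut = kwargs.get('cut')
    if cut is not None:
        # Split at the last layer kept in the cut body (assumes a positive cut)
        return (m[0][cut-1], m[1])
    else:
        # Original behaviour: hard-coded split index
        return (m[0][6], m[1])
```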
You can find the full source code in my fork of the fastai library.
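To check the indexing of this fix without fastai, here is a pure-Python sketch that reuses a list of resnet18 child names as a stand-in for the real modules. `resnet_split` mirrors the modified `_resnet_split`, `functools.partial` binds `cut` the same way the new `cnn_config` does, and the hypothetical `make_model` builds a `(body, head)` tuple so that `m[0]` and `m[1]` index the same way as in the real model.

```python
from functools import partial

# Stand-in for resnet18's top-level children (names only, not real modules).
RESNET18_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4",
                     "avgpool", "fc"]

def resnet_split(m, **kwargs):
    """Mirror of the modified _resnet_split: body[cut-1] if cut is given, else body[6]."""
    cut = kwargs.get("cut")
    if cut is not None:
        return (m[0][cut - 1], m[1])  # last kept body layer, for a positive cut
    return (m[0][6], m[1])

def make_model(cut):
    """Hypothetical stand-in for a cut model: (body, head), body = first `cut` children."""
    return (RESNET18_CHILDREN[:cut], "head")

# Bind cut the same way the new cnn_config does.
split_fn = partial(resnet_split, cut=5)
print(split_fn(make_model(5)))           # ('layer1', 'head')
print(resnet_split(make_model(-2)))      # ('layer3', 'head')
```

With a negative `cut`, `cut - 1` counts from the end of the body instead (a cut of -2 lands on `layer2`), so the hack as written assumes a positive `cut`.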
Going beyond this hack
Getting this working for all resnet models should be as easy as modifying the `if` condition in the `cnn_config` function above. A more general solution for all model types is a bit harder. It would be easier if, rather than using the intermediate dictionary of split parameters, the `cnn_config` function directly called the model's corresponding `_<model>_split` function, passing in kwargs such as `cut`. I think the resulting code would also be much more transparent.
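One way to read that suggestion is sketched below in plain Python: instead of storing precomputed split parameters, the config function looks up the architecture's split function and binds `cut` at call time. Everything here is hypothetical (`SPLIT_FUNCS`, `DEFAULT_CUTS`, `cnn_config_general`, architectures named by strings, child names standing in for modules), and only resnet entries are filled in; other architectures would need their own split functions that accept `cut`.

```python
from functools import partial

# Stand-in for resnet18's top-level children (names only, not real modules).
RESNET18_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4",
                     "avgpool", "fc"]

def resnet_split(m, cut=None):
    """Split a (body, head) model at the last kept body layer, or at index 6 by default."""
    body, head = m
    idx = 6 if cut is None else cut - 1  # assumes a positive cut
    return (body[idx], head)

def default_split(m, cut=None):
    """Fallback for unknown architectures: only separate the head."""
    return (m[1],)

# Hypothetical registries keyed by architecture name.
SPLIT_FUNCS = {"resnet18": resnet_split, "resnet34": resnet_split}
DEFAULT_CUTS = {"resnet18": -2, "resnet34": -2}

def cnn_config_general(arch, cut=None):
    """Return metadata whose split function already knows the cut actually used."""
    split_fn = SPLIT_FUNCS.get(arch, default_split)
    return {"cut": DEFAULT_CUTS.get(arch) if cut is None else cut,
            "split": partial(split_fn, cut=cut)}

meta = cnn_config_general("resnet18", cut=5)
model = (RESNET18_CHILDREN[:meta["cut"]], "head")
print(meta["split"](model))  # ('layer1', 'head')
```

Binding `cut` with `partial` keeps the returned `split` a one-argument callable of the model, the same shape as the existing `meta['split']`, so the call site in `unet_learner` would not need to change.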