Cutting Pretrained Models

When I do:

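from fastai.vision import *  # fastai v1

arch = models.resnet34
cut = 5
meta = cnn_config(arch)
# keep only the first `cut` children of the pretrained model
body = create_body(arch(pretrained=True), ifnone(cut, meta['cut']))
body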

I get the model cut at the 5th block, as desired. However, when I pass cut to the learner, I get:

learn = create_cnn(data, models.resnet34, cut=5, pretrained=False, metrics=error_rate)
IndexError: index 6 is out of range

Is this a bug?


Hi @bluesky314,

Did you manage to understand why this error was being returned?

The source code of create_body is simple enough: nn.Sequential(*list(model.children())[:cut])

Somehow, however, every in-range value I try throws this error.

Any help would be greatly appreciated.
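The slicing quoted above, list(model.children())[:cut], can be illustrated without torch. The sketch below treats a ResNet as a plain list of its top-level child names, assuming torchvision's order (conv1, bn1, relu, maxpool, layer1 to layer4, avgpool, fc); cut_children is a hypothetical helper. A cut of 5 keeps everything up to layer1, while a negative cut such as -2 (which, as far as I know, is fastai v1's default cut for resnets) only drops avgpool and fc.

```python
# Stand-in for a torchvision ResNet: its top-level children, in order.
# Assumption: this matches the child order of torchvision's resnet18/34.
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4",
                   "avgpool", "fc"]


def cut_children(children, cut):
    """Hypothetical helper mirroring list(model.children())[:cut]."""
    return list(children)[:cut]


print(cut_children(RESNET_CHILDREN, 5))
# ['conv1', 'bn1', 'relu', 'maxpool', 'layer1']
print(cut_children(RESNET_CHILDREN, -2))
# ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4']
```

So a body cut at 5 has only five children, which matters for anything that later indexes into it.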


The problem

I ran into this problem as well using unet_learner. After digging through the source code, I found that the cut is applied inconsistently: it is applied to the body, but not taken into account when the model is split into layer groups.

The traceback shows that the error occurs near the end of the learner-creation function, at this line:

learn.split(ifnone(split_on, meta['split']))

Tracing this back, you’ll find that the split index is hard-coded to the value 6:

def _resnet_split(m:nn.Module): return (m[0][6],m[1])

where the passed model's body (m[0]) has already been cut by create_body, as mentioned in a previous comment, so it contains only the first cut children. Indexing m[0][6] therefore fails for any cut value below 7 on resnet models! Looking at the other model types, I suspect they have the same problem with cutting.
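A torch-free way to see this: treat the model as [body, head] (mirroring the nn.Sequential(body, head) the learner builds, as I understand it) and apply the hard-coded split to bodies cut at every positive value. Plain lists stand in for nn.Sequential, and resnet_split_hardcoded and failing_cuts are hypothetical names.

```python
# Plain lists stand in for nn.Sequential: the full model is [body, head].
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4",
                   "avgpool", "fc"]


def resnet_split_hardcoded(m):
    """Same indexing as the quoted _resnet_split: (m[0][6], m[1])."""
    return (m[0][6], m[1])


def failing_cuts(children):
    """Return every positive cut for which the hard-coded split raises IndexError."""
    failing = []
    for cut in range(1, len(children) + 1):
        model = [children[:cut], "head"]  # body already cut, as create_body does
        try:
            resnet_split_hardcoded(model)
        except IndexError:
            failing.append(cut)
    return failing


print(failing_cuts(RESNET_CHILDREN))  # [1, 2, 3, 4, 5, 6]
```

Every cut below 7 leaves fewer than seven children in the body, so index 6 is out of range, matching the reported IndexError.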

A solution

The model split parameters are retrieved via the cnn_config function, which in turn just looks them up in a dictionary whose keys are the model architectures and whose values are the metadata (cut and split parameters). This makes it awkward to pass in a custom cut. To get around this, I explicitly generated the split parameters within the cnn_config function and added a kwarg for the cut. Here's my new cnn_config function, which currently only handles resnet18 correctly:

def cnn_config(arch, **kwargs):
    "Get the metadata associated with `arch`."
    torch.backends.cudnn.benchmark = True
    cut = kwargs.get('cut')
    # Only resnet18 honors a custom cut for now; everything else uses the lookup table
    if cut is not None and arch == models.resnet18:
        return {'cut': cut,
                'split': partial(_resnet_split, cut=cut)}
    else:
        return model_meta.get(arch, _default_meta)

In addition, I had to add a cut kwarg to the _resnet_split function:

# Split a resnet style model
def _resnet_split(m:nn.Module, **kwargs):
    cut = kwargs.get('cut')
    if cut is not None:
        return (m[0][cut-1], m[1])
    else:
        return (m[0][6], m[1])

You can find the full source code on my forked repo of the fastai library.

Going beyond this hack

Getting this working for all resnet models should be as easy as modifying the if condition in the cnn_config function above. A more general solution for all model types is a bit harder. It would be easier if, rather than going through the intermediate dictionary of split parameters, the cnn_config function directly called the model's appropriate _<model>_split function, passing in kwargs such as cut. I think the resulting code would also be much more transparent.
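The more general design could look roughly like this sketch. Everything here is hypothetical: SPLIT_REGISTRY, cnn_config_sketch and resnet_split are illustrative names, architectures are keyed by strings instead of the real arch callables, cut is assumed positive when given, and the default cut of -2 is what I believe fastai v1 uses for resnets.

```python
from functools import partial

RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4",
                   "avgpool", "fc"]


def resnet_split(m, cut=None):
    """Split at the last kept child when a (positive) cut is given, else at index 6."""
    idx = 6 if cut is None else cut - 1
    return (m[0][idx], m[1])


# Hypothetical registry: architecture name -> (default cut, split function).
# Real code would key on the arch callables (models.resnet34, ...) instead.
SPLIT_REGISTRY = {
    "resnet18": (-2, resnet_split),
    "resnet34": (-2, resnet_split),
}


def cnn_config_sketch(arch_name, **kwargs):
    """Return {'cut', 'split'}, forwarding kwargs such as cut to the split function."""
    default_cut, split_fn = SPLIT_REGISTRY[arch_name]
    cut = kwargs.get("cut")
    if cut is None:
        return {"cut": default_cut, "split": split_fn}
    return {"cut": cut, "split": partial(split_fn, **kwargs)}


cfg = cnn_config_sketch("resnet34", cut=5)
model = [RESNET_CHILDREN[:cfg["cut"]], "head"]
print(cfg["split"](model))  # ('layer1', 'head')
```

Adding a new architecture then only means registering its split function, and any cut the caller supplies reaches that function directly.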


I posted an issue about this, and @sgugger pointed out a much simpler workaround, so I wanted to share it here for anyone else running into this issue. The vision learner functions accept a split_on function that overrides the default split behavior.

For anyone looking to cut resnet models, it’s as simple as using a function like this:

def split_fn(m, cut):
  return (m[0][cut-1], m[1])

Then pass cut=5 and split_on=partial(split_fn, cut=5) to the learner to keep the first five layers of the encoder and split at the last of them.
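Because split_on is called with the model only, partial binds the cut ahead of time. The sketch below uses plain lists as a stand-in model (an assumption for illustration) and defines the cut once as CUT, a hypothetical constant, so the value given to the learner and the value bound into split_on cannot drift apart.

```python
from functools import partial

RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4",
                   "avgpool", "fc"]


def split_fn(m, cut):
    # Split the cut body at its last kept layer; the head m[1] starts its own group.
    return (m[0][cut - 1], m[1])


CUT = 5  # define the cut once so the learner's cut and split_on always agree

# split_on receives only the model, so bind cut in advance.
split_on = partial(split_fn, cut=CUT)

model = [RESNET_CHILDREN[:CUT], "head"]  # stand-in for the cut model
print(split_on(model))  # ('layer1', 'head')
```

With fastai v1 this corresponds to passing cut=CUT and split_on=split_on to create_cnn (or unet_learner), as described above.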


Note that you wouldn’t necessarily want to have your first group go from 0 to cut-1 in that case. Maybe

def split_fn(m, cut):
  return (m[0][cut//2], m[1])

would be more appropriate (or any other custom value < cut).
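To see why, here is a simplified sketch of the resulting layer groups. It assumes, as I understand fastai v1's model splitting, that each split layer starts a new group; it works on top-level children rather than fastai's flattened layer list, and layer_groups is a hypothetical helper. Splitting at cut-1 leaves a middle group of a single child, while cut//2 spreads the body more evenly.

```python
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4",
                   "avgpool", "fc"]


def layer_groups(body, head, split_idx):
    """Simplified grouping: new groups start at body[split_idx] and at the head.

    Assumption: operates on top-level children, not fastai's flattened model.
    """
    return [body[:split_idx], body[split_idx:], [head]]


cut = 5
body = RESNET_CHILDREN[:cut]
print(layer_groups(body, "head", cut - 1))
# [['conv1', 'bn1', 'relu', 'maxpool'], ['layer1'], ['head']]
print(layer_groups(body, "head", cut // 2))
# [['conv1', 'bn1'], ['relu', 'maxpool', 'layer1'], ['head']]
```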


Has the v2 code been improved, so that we don't need to hack around this?