Invoke vision_learner() while offline, no torch cache

Consider this code:

# Items/labels are listed in train.csv; file names get the "train_images"
# folder prefix and the ".png" suffix.
train_dataloader = ImageDataLoaders.from_csv(
    apt_dataset,
    csv_fname="train.csv",
    folder="train_images",
    suff=".png",
    seed=42,
    bs=batch_size,
    # Per-item resize to 512, then batch-level augmentation at the same size.
    item_tfms=Resize(512),
    batch_tfms=aug_transforms(size=512),
)
# Pretrained resnet152 backbone: this is the step that wants the .pth weights.
apt_learner = vision_learner(train_dataloader, resnet152, metrics=error_rate)
apt_learner.lr_find()

I need to run it in a clean environment where I may not be able to download .pth files, and the torch cache might be empty. But I could provide the right version of resnet152.pth as part of the dataset.

How do I tell vision_learner() to skip downloading the .pth and use instead my offline file?

I guess I could use shell commands to recreate /root/.cache/torch/hub/checkpoints/... but that seems like a hack.

I tried this myself and couldn’t find a trivial way to achieve it.
(Assuming you want to use a torchvision model such as ResNet rather than one from the timm library …) In general, you can load pretrained weights (in your case) with `weights = torch.load('path/to/resnet152.pth')` and pass them to the model via:

# `weights` is the state dict returned by torch.load(...) as shown above.
model = resnet152()
model.load_state_dict(weights)

…now `model` holds the pretrained weights. Putting this into practice is a bit tricky, since you most likely have to replace the head to fit your task, initialize the new layers, and so on (…the tedious stuff that fastai hides away). So I think the easiest way is to copy the functions that fastai uses and modify them.

The code below does this. I marked the bits I changed with `#-#`; the original functions are `create_vision_model` and `vision_learner`, if you want to compare them.

#-# added `pretrained_path` parameter
def create_custom_model(arch, n_out, pretrained=True, pretrained_path=None, cut=None, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                        concat_pool=True, pool=True, lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None):
    """Create custom vision architecture, optionally using a local weights file.

    Same as fastai's `create_vision_model`, except that when `pretrained_path`
    is given the backbone is built with `arch()` and its weights are loaded from
    that file, instead of calling `arch(pretrained=pretrained)` (which may
    need to download them).

    `pretrained_path` must hold a state dict matching `arch` (e.g. torchvision's
    `resnet152.pth`); `load_state_dict` is strict, so a mismatch raises.
    `pretrained` is still forwarded to `create_body` in both cases.
    """
    meta = model_meta.get(arch, _default_meta)
    #-# start of added part:
    if pretrained_path:
        # map_location='cpu' allows reading a checkpoint saved from CUDA tensors
        # on a machine without a GPU; load_state_dict copies the values into
        # the freshly built model regardless of where they were loaded.
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        weights = torch.load(pretrained_path, map_location='cpu')
        model = arch()
        model.load_state_dict(weights)
    #-# end
    else:
        model = arch(pretrained=pretrained)
    body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))
    # Only infer the feature count when fastai builds the head itself.
    nf = num_features_model(nn.Sequential(*body.children())) if custom_head is None else None
    return add_head(body, nf, n_out, init=init, head=custom_head, concat_pool=concat_pool, pool=pool,
                    lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn, bn_final=bn_final, lin_first=lin_first, y_range=y_range)

#-# Added pretrained_path parameter
@delegates(create_custom_model)
def custom_learner(dls, arch, normalize=True, n_out=None, pretrained=True, pretrained_path=None,
        # learner args
        loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
        model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
        # model & head args
        cut=None, init=nn.init.kaiming_normal_, custom_head=None, concat_pool=True, pool=True,
        lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None, **kwargs):
    """Build a vision learner from `dls` and `arch`, optionally with local weights.

    Mirrors fastai's `vision_learner` for non-timm architectures (the timm
    branch is removed). `pretrained_path` is forwarded to `create_custom_model`,
    which loads the backbone weights from that file instead of downloading them.
    `pretrained` is still what gets passed to `_add_norm` and what decides
    whether the learner is frozen after creation.
    """
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    meta = model_meta.get(arch, _default_meta)
    #-# `cut` is now forwarded as well: it is accepted by this function but was
    #-# never passed on, so a custom cut was silently ignored.
    model_args = dict(cut=cut, init=init, custom_head=custom_head, concat_pool=concat_pool, pool=pool, lin_ftrs=lin_ftrs, ps=ps,
                      first_bn=first_bn, bn_final=bn_final, lin_first=lin_first, y_range=y_range, **kwargs)

    #-# Removed timm option

    if normalize: _add_norm(dls, meta, pretrained)
    #-# Pass pretrained_path
    model = create_custom_model(arch, n_out, pretrained=pretrained, pretrained_path=pretrained_path, **model_args)

    splitter = ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn, moms=moms)
    if pretrained: learn.freeze()
    # keep track of args for loggers
    store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)
    return learn

Make sure this does what you expect: I only spot-checked a few values in a couple of layers, but those matched the values from the “online” model.

# Build the learner from the local weights file, then spot-check one parameter tensor.
learn = custom_learner(dls, resnet152, pretrained=True, pretrained_path='resnet152.pth')
layer_i =6  # arbitrary index into the model's parameter list
print(list(learn.model.parameters())[layer_i].shape)
# Displayed by the notebook: the top-left 2x2 of kernel [0, 0] of that parameter.
list(learn.model.parameters())[layer_i][0,0,:2,:2]
torch.Size([64, 64, 3, 3])
tensor([[-8.0483e-09,  2.6911e-08],
        [ 1.6532e-08,  3.1734e-08]], device='cuda:0')
# Reference: the "online" model with weights obtained the usual way, for comparison.
model = resnet152(pretrained = True)
print(list(model.parameters())[layer_i].shape)
# Should show the same values as the slice taken from `learn.model`.
list(model.parameters())[layer_i][0,0,:2,:2]
torch.Size([64, 64, 3, 3])

tensor([[-8.0483e-09,  2.6911e-08],
        [ 1.6532e-08,  3.1734e-08]], grad_fn=<SliceBackward0>)

Also: If someone else knows a better way to do this I’m interested too :smile:

Oh… just to state the obvious: I think your idea of putting the `.pth` file into the cache is way more elegant.