I tried my hand at this and couldn’t find a trivial way to achieve it.
(Assuming that you want to use a torchvision model like resnet, not something from the timm library …) In general you can load pretrained weights with (in your case): weights = torch.load('path/to/resnet152.pth')
and pass them to the model via:
model = resnet152()
# `weights` is the state dict returned by `torch.load(...)` above
model.load_state_dict(weights)
…now `model` holds the pretrained weights. It’s a bit tricky to put this into action, since you most likely have to change the head to fit your task, initialize the new layers, and so on and so forth (…the tedious stuff that fastai hides away), so I think the easiest way is to copy the functions that fastai uses and modify them.
The code below does this. I marked the bits I changed with `#-#`; the original functions are `create_vision_model` and `vision_learner`, if you want to compare them.
#-# added `pretrained_path` parameter
def create_custom_model(arch, n_out, pretrained=True, pretrained_path=None, cut=None, n_in=3,
                        init=nn.init.kaiming_normal_, custom_head=None, concat_pool=True, pool=True,
                        lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False,
                        y_range=None):
    """Create a custom vision architecture, optionally loading weights from a local file.

    Adapted from fastai's `create_vision_model`. When `pretrained_path` is given,
    `arch` is instantiated without arguments and the state dict stored at that
    path is loaded into it; otherwise `arch(pretrained=pretrained)` is called.

    Args:
        arch: callable that builds the backbone model (e.g. `resnet152`).
        n_out: passed to `add_head` as the number of outputs.
        pretrained: passed to `arch` when no `pretrained_path` is given, and to
            `create_body` in either case.
        pretrained_path: path to a saved state dict for `arch`; when set it takes
            precedence over `pretrained` for building the backbone.
        cut: passed to `create_body`; falls back to `meta['cut']` when `None`.
        n_in: passed to `create_body`.
        init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final,
        lin_first, y_range: forwarded to `add_head` (`custom_head` as `head`).

    Returns:
        Whatever `add_head` returns for the cut body.
    """
    meta = model_meta.get(arch, _default_meta)
    #-# start of added part:
    if pretrained_path:
        # Map tensors to CPU so a checkpoint saved from GPU tensors also loads on
        # a machine without CUDA.
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        weights = torch.load(pretrained_path, map_location='cpu')
        model = arch()
        model.load_state_dict(weights)
    else:
    #-# end
        model = arch(pretrained=pretrained)
    body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))
    # With a custom head the feature count is not needed, so skip computing it.
    nf = num_features_model(nn.Sequential(*body.children())) if custom_head is None else None
    return add_head(body, nf, n_out, init=init, head=custom_head, concat_pool=concat_pool, pool=pool,
                    lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn, bn_final=bn_final, lin_first=lin_first,
                    y_range=y_range)
#-# Added pretrained_path parameter
@delegates(create_custom_model)
def custom_learner(dls, arch, normalize=True, n_out=None, pretrained=True, pretrained_path=None,
                   # learner args
                   loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None,
                   path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                   moms=(0.95,0.85,0.95),
                   # model & head args
                   cut=None, init=nn.init.kaiming_normal_, custom_head=None, concat_pool=True, pool=True,
                   lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None,
                   **kwargs):
    """Build a vision learner from `dls` and `arch`, optionally with weights from a local file.

    Adapted from fastai's `vision_learner` (the timm branch is removed). The model
    is built by `create_custom_model`, which loads the state dict found at
    `pretrained_path` when one is given.

    Args:
        dls: data loaders handed to `Learner`; also used to infer `n_out` via `get_c`.
        arch: callable that builds the backbone model.
        normalize: when true, `_add_norm` is called on `dls`.
        n_out: number of outputs; inferred from `dls` when `None`.
        pretrained: forwarded to `create_custom_model`.
        pretrained_path: path to a saved state dict, forwarded to `create_custom_model`.
        loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd,
        wd_bn_bias, train_bn, moms: forwarded to `Learner` (`splitter` falls back
            to `meta['split']`).
        cut: accepted for signature parity; NOTE(review): not forwarded to
            `create_custom_model` here, same as in the original snippet.
        init, custom_head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final,
        lin_first, y_range, **kwargs: forwarded to `create_custom_model`.

    Returns:
        The constructed `Learner`, frozen when pretrained weights are in use.
    """
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    meta = model_meta.get(arch, _default_meta)
    model_args = dict(init=init, custom_head=custom_head, concat_pool=concat_pool, pool=pool,
                      lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn, bn_final=bn_final,
                      lin_first=lin_first, y_range=y_range, **kwargs)
    #-# Removed timm option
    # `create_custom_model` loads weights whenever `pretrained_path` is set, so the
    # model counts as pretrained in that case even if `pretrained` is False.
    uses_pretrained = bool(pretrained or pretrained_path)
    if normalize: _add_norm(dls, meta, uses_pretrained)
    #-# Pass pretrained_path
    model = create_custom_model(arch, n_out, pretrained=pretrained, pretrained_path=pretrained_path,
                                **model_args)
    splitter = ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter,
                    cbs=cbs, metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias,
                    train_bn=train_bn, moms=moms)
    if uses_pretrained: learn.freeze()
    # keep track of args for loggers
    store_attr('arch,normalize,n_out,pretrained,pretrained_path', self=learn, **kwargs)
    return learn
Make sure that this achieves what you expect; I only checked a couple of layers for a couple of values, but those matched the values from the “online” model.
# Build the learner from the local checkpoint and peek at one parameter tensor.
learn = custom_learner(dls, resnet152, pretrained=True, pretrained_path='resnet152.pth')
layer_i = 6
custom_param = list(learn.model.parameters())[layer_i]
print(custom_param.shape)
# Last expression of the cell: shows a 2x2 corner of the first kernel.
custom_param[0, 0, :2, :2]
torch.Size([64, 64, 3, 3])
tensor([[-8.0483e-09, 2.6911e-08],
[ 1.6532e-08, 3.1734e-08]], device='cuda:0')
# Same inspection on the model with the downloaded weights, for comparison.
model = resnet152(pretrained=True)
reference_param = list(model.parameters())[layer_i]
print(reference_param.shape)
# Last expression of the cell: shows the same 2x2 corner as above.
reference_param[0, 0, :2, :2]
torch.Size([64, 64, 3, 3])
tensor([[-8.0483e-09, 2.6911e-08],
[ 1.6532e-08, 3.1734e-08]], grad_fn=<SliceBackward0>)
Also: if someone else knows a better way to do this, I’m interested too.