Using timm pretrained models for semantic segmentation

I have been following Walk with fastai v2 and experimenting with its multiclass semantic segmentation notebook.
When I train on any ResNet I do not get this error, but when I create a timm model and pass it to unet_learner, I get this error:

TypeError: forward() got an unexpected keyword argument 'pretrained'

Here is how I create the model:

import timm
from fastai.vision.all import *
pretrained_model = timm.create_model('vit_base_patch16_224_in21k')
learn = unet_learner(dls, pretrained_model, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt, cbs=callbacks)

Any help is appreciated!



To specify an architecture for unet_learner, you must pass a function that returns the desired backbone, not the backbone itself. Python’s functools.partial can be used to construct a function that yields vit_base_patch16_224_in21k:

from functools import partial

pretrained_model = partial(timm.create_model, 'vit_base_patch16_224_in21k')
learn = unet_learner(dls, pretrained_model, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt, cbs=callbacks)
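The reason a bare model instance raises the original TypeError is that unet_learner invokes its arch argument as arch(pretrained=...) to build the backbone; calling an nn.Module that way routes the keyword to forward(). A toy sketch of the fix (the create_model below is a stand-in for illustration, not timm's actual function):

```python
from functools import partial

# Stand-in for timm.create_model: a factory that takes a model name
# plus a `pretrained` flag and returns a model (here, just a string).
def create_model(name, pretrained=False):
    return f"{name} (pretrained={pretrained})"

# partial pre-binds the model name, leaving a callable that still
# accepts `pretrained` -- exactly what unet_learner expects to call.
arch = partial(create_model, 'vit_base_patch16_224_in21k')
print(arch(pretrained=True))  # vit_base_patch16_224_in21k (pretrained=True)
```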

Thanks a lot for your help @BobMcDear!

I did exactly as you mentioned, but for some reason I get the following error. Any insights?

StopIteration                             Traceback (most recent call last)
<ipython-input-21-e9aff7a89710> in <module>
      5         ReduceLROnPlateau(monitor='valid_loss', min_delta=0.0001,patience=2, factor=1e-1, min_lr=0),GradientAccumulation(n_acc=Grad_acc)]
      6 opt = ranger
----> 7 learn = unet_learner(dls, pretrained_model, metrics=acc_camvid, self_attention=True, act_cls=Mish, opt_func=opt, cbs=callbacks, loss_func=loss_func)

~/anaconda3/envs/280322CubiAClone/lib/python3.6/site-packages/fastai/vision/ in unet_learner(dls, arch, normalize, n_out, pretrained, config, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
    218     img_size = dls.one_batch()[0].shape[-2:]
    219     assert img_size, "image size could not be inferred from data"
--> 220     model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)
    222     splitter=ifnone(splitter, meta['split'])

~/anaconda3/envs/280322CubiAClone/lib/python3.6/site-packages/fastai/vision/ in create_unet_model(arch, n_out, img_size, pretrained, cut, n_in, **kwargs)
    193     "Create custom unet architecture"
    194     meta = model_meta.get(arch, _default_meta)
--> 195     body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))
    196     model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)
    197     return model

~/anaconda3/envs/280322CubiAClone/lib/python3.6/site-packages/fastai/vision/ in create_body(arch, n_in, pretrained, cut)
     68     if cut is None:
     69         ll = list(enumerate(model.children()))
---> 70         cut = next(i for i,o in reversed(ll) if has_pool_type(o))
     71     if   isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
     72     elif callable(cut): return cut(model)



functools.partial(<function create_model at 0x7f99420ad730>, 'vit_base_patch16_224_in21k')


Found the solution here:

The source of the StopIteration exception is that U-Nets require hierarchical backbones, but the vision transformer is an isotropic architecture with a fixed resolution throughout the entire network and thus does not interact well with unet_learner. You should instead use a standard CNN or a hierarchical ViT.
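The cut-point search in create_body (visible in the traceback above) can be mimicked in plain PyTorch to see why the ViT fails: it walks the backbone's top-level children in reverse looking for a pooling layer to cut at. The toy models below are hypothetical stand-ins for the real backbones, not fastai or timm code:

```python
import torch.nn as nn

# Mimic of create_body's cut-point search: find the last top-level
# child that contains a pooling layer, or raise StopIteration.
def find_cut(model):
    pools = (nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)
    has_pool = lambda m: any(isinstance(x, pools) for x in m.modules())
    ll = list(enumerate(model.children()))
    return next(i for i, o in reversed(ll) if has_pool(o))

# A CNN-style backbone ends in pooling -> a valid cut point is found.
cnn_like = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())
print(find_cut(cnn_like))  # 2

# An isotropic ViT-style stack has no pooling anywhere, so the
# generator is exhausted and StopIteration propagates -- the error above.
vit_like = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 768))
try:
    find_cut(vit_like)
except StopIteration:
    print("StopIteration: no pooling layer to cut at")
```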

This thread actually addresses an unrelated problem and is not applicable here.

I have changed the backbone to efficientnet_b0 once I found out that batch size wasn’t the issue!

Thanks for the explanation!
