Create a DynamicUnet using an architecture from timm

Hey, for a project I would like to use a Unet. I can do that with the following lines of code:

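from torch import nn
from torchvision.models import resnet34
from fastai.vision.models.unet import DynamicUnet
m = resnet34()
m = nn.Sequential(*list(m.children())[:-2])  # drop the pooling and classifier layers
model = DynamicUnet(m, 3, (img_size, img_size), norm_type=None)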

However, I would like to use a model from timm as the backbone of my architecture for better performance. I have tried the following:

m2 = timm.create_model('efficientnet_es', num_classes=0, global_pool='')
m2 = m = nn.Sequential(*list(m.children()))
model = DynamicUnet(m2, 3, img_size=(224,224))

which gives me this error:

lib/python3.9/site-packages/timm/models/levit.py:292, in Attention.forward(self, x)
    290     x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
    291 else:
--> 292     B, N, C = x.shape
    293     q, k, v = self.qkv(x).view(
    294         B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
    295     q = q.permute(0, 2, 1, 3)

ValueError: too many values to unpack (expected 3)

I also tried this version:

m2 = timm.create_model('efficientnet_es', num_classes=0, global_pool='')
model = DynamicUnet(m2, 3, img_size=(224,224))

which gives me this error:

lib/python3.9/site-packages/fastai/callback/hook.py:51, in Hooks.__init__(self, ms, hook_func, is_forward, detach, cpu)
     50 def __init__(self, ms, hook_func, is_forward=True, detach=True, cpu=False):
---> 51     self.hooks = [Hook(m, hook_func, is_forward, detach, cpu) for m in ms]

TypeError: 'EfficientNet' object is not iterable

It looks to me as if it is just a matter of DynamicUnet not working directly with that model, which brings me to these three questions:

  • Is there a quick fix to make DynamicUnet work with this particular model?
  • Given that some models in timm can already extract features for a Unet using features_only, is there a plan to extend the Unet creation functionality of fastai to those models?
  • If that is the case, should I look into it more and make a PR?

Thanks for your answers!

Hi there!

I had never tried this, so I ran it on my machine.
I think the issue comes from a typo in your code rather than from DynamicUnet not working:

m2 = m = nn.Sequential(list....

I do not think the = m belongs there, does it?
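
To illustrate why the extra = m matters, here is a small pure-Python sketch. FakeModel, buggy and fixed are hypothetical stand-ins invented for this example (they are not timm or fastai objects); the point is only that in a chained assignment the right-hand side is evaluated once, using whatever m currently refers to, and the result is then bound to both names.

```python
class FakeModel:
    """Hypothetical stand-in for a model that exposes .children()."""
    def __init__(self, layers):
        self.layers = layers

    def children(self):
        return iter(self.layers)


def buggy(old, new):
    m, m2 = old, new
    # Chained assignment: the right-hand side reads the OLD model `m`,
    # then both `m2` and `m` are rebound to that result.
    m2 = m = list(m.children())
    return m2, m


def fixed(old, new):
    m, m2 = old, new
    # Only `m2` is rebound, and its own children are used.
    m2 = list(m2.children())
    return m2, m


old_model = FakeModel(["conv", "pool"])
new_model = FakeModel(["stem", "blocks", "head"])

buggy_m2, buggy_m = buggy(old_model, new_model)
fixed_m2, fixed_m = fixed(old_model, new_model)
```

In the buggy version, m2 ends up holding the layers of the old model and the new model is discarded, which would explain why the traceback above points at a different architecture than the one that was just created.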

Without that, it worked on a computer with timm 0.6.12, fastai 2.7.11 and torch 1.13.1.

Edit: it also works for efficientnet_es.
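
As for the second error ('EfficientNet' object is not iterable), the traceback shows fastai looping over the encoder with for m in ms, so the encoder has to be iterable. Below is a minimal sketch using only torch, with a hypothetical TinyNet standing in for the timm model: a plain nn.Module cannot be iterated, while the same layers wrapped in nn.Sequential can. It only demonstrates the iterability difference, not a full DynamicUnet build.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model standing in for a timm backbone."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.stem(x))


net = TinyNet()

# A plain nn.Module does not define __iter__, so iterating over it fails.
try:
    iter(net)
    plain_is_iterable = True
except TypeError:
    plain_is_iterable = False

# Wrapping the child modules in nn.Sequential gives an iterable,
# indexable container of the same layers.
wrapped = nn.Sequential(*list(net.children()))
layer_names = [type(layer).__name__ for layer in wrapped]
```

This is why the nn.Sequential(*list(m2.children())) step is still needed before passing a timm model to DynamicUnet.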


Indeed, it was a typo. Thanks a lot!
