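I ran into the following error while trying to use the new timm integration. The cause was that I had installed timm *after* importing `fastai.vision.all`. As a result, `timm` was never imported into `fastai.vision.learner`, and `TimmBody` failed with `NameError: name 'timm' is not defined`. Restarting the notebook kernel, so that the fastai import ran again with timm already installed, fixed the problem.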
Here is the code I ran:

```python
learn = vision_learner(dls, 'ghostnet_050', pretrained=False, metrics=accuracy, opt_func=Adam, wd=1e-5)
```
and here is the stack trace:

```
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_38235/3334928029.py in <module>
----> 1 learn = vision_learner(dls, 'ghostnet_050', pretrained=False, metrics=accuracy, opt_func=Adam, wd=1e-5)
/data/fastai/vision/learner.py in vision_learner(dls, arch, normalize, n_out, pretrained, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
202 meta = model_meta.get(arch, _default_meta)
203 if normalize: _add_norm(dls, meta, pretrained)
--> 204 if isinstance(arch, str): model = create_timm_model(arch, n_out, default_split, pretrained, **kwargs)
205 else: model = create_vision_model(arch, n_out, pretrained=pretrained, **kwargs)
206
/data/fastai/vision/learner.py in create_timm_model(arch, n_out, cut, pretrained, n_in, init, custom_head, concat_pool, **kwargs)
178 concat_pool=True, **kwargs):
179 "Create custom architecture using `arch`, `n_in` and `n_out` from the `timm` library"
--> 180 body = TimmBody(arch, pretrained, None, n_in)
181 nf = body.model.num_features
182 return add_head(body, nf, n_out, init=init, head=custom_head, concat_pool=concat_pool, pool=body.needs_pool, **kwargs)
/data/fastai/vision/learner.py in __init__(self, arch, pretrained, cut, n_in)
167 def __init__(self, arch:str, pretrained:bool=True, cut=None, n_in:int=3):
168 super().__init__()
--> 169 model = timm.create_model(arch, pretrained=pretrained, num_classes=0, in_chans=n_in)
170 self.needs_pool = model.default_cfg.get('pool_size', None)
171 self.model = model if cut is None else cut_model(model, cut)
NameError: name 'timm' is not defined
```