I am trying to use FixRes method from https://github.com/facebookresearch/FixRes within fastai learner.
There is an example code for loading pretrained models provided in the repo:
https://github.com/facebookresearch/FixRes#pre-trained-networks
How could I use this and integrate it in `cnn_learner`, like: `arch = resnext101_32x48d_wsl` and `learner = cnn_learner(data, arch, [accuracy])`? Obviously I tried and it didn’t work.
I searched the forum for a way to load custom weights into a custom model but couldn’t make it work.
I usually pass my custom models to the Learner like so: `learn = Learner(data, model)`, and it works fine. Have you tried it? I think that `cnn_learner` only works for some architectures predefined in the library.
Hope that helps
Hi Nathan,
Is the model an architecture like `resnet50` or an instance like `resnet50()`?
I tried `Learner(data, model)` as you suggested and encountered this error:
`/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in __post_init__(self)
163 "Setup path,metrics, callbacks and ensure model directory exists."
164 self.path = Path(ifnone(self.path, self.data.path))
--> 165 self.model = self.model.to(self.data.device)
166 self.loss_func = self.loss_func or self.data.loss_func
167 self.metrics=listify(self.metrics)
AttributeError: module 'FixRes.imnet_finetune.resnext_wsl' has no attribute 'to'`
Can you show me an example of your custom model or how to properly construct one @NathanHub?
Thank you
I was able to use the method with `arch = resnext101_32x48d_wsl` and pass it to `learner = cnn_learner(data, arch)`, then load the model provided in the FixRes repo by calling `learner.model` blah blah
Actually it didn’t work out, because when the `arch` is passed to `cnn_learner`, the keys in its `state_dict` don’t keep their original names, so nothing matched between the two `state_dict`s.
I’m still looking for a workaround.
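For anyone hitting the same key mismatch, here is a minimal sketch of one possible workaround: pairing the entries of the two state dicts by position instead of by name. Everything in it is an assumption for illustration: `remap_by_position` is a hypothetical helper, `FakeTensor` is a stand-in so the example runs without PyTorch, and the key names and shapes are made up. It only makes sense if both sides list the same layers in the same order, which usually holds for the body of the network but not for a replaced head.

```python
from collections import OrderedDict, namedtuple

# Stand-in for a tensor: only the .shape attribute matters for this sketch.
# The "tag" field just records where a value came from.
FakeTensor = namedtuple("FakeTensor", ["shape", "tag"])


def remap_by_position(src_state, dst_state):
    """Return src_state's values stored under dst_state's keys, paired by order.

    Raises ValueError if the entry counts differ or any pair of shapes differs.
    """
    if len(src_state) != len(dst_state):
        raise ValueError("state dicts have different numbers of entries")
    remapped = OrderedDict()
    for (src_key, src_val), (dst_key, dst_val) in zip(
        src_state.items(), dst_state.items()
    ):
        if tuple(src_val.shape) != tuple(dst_val.shape):
            raise ValueError(f"shape mismatch: {src_key} vs {dst_key}")
        remapped[dst_key] = src_val
    return remapped


# Keys as named in the original checkpoint (hypothetical names and shapes).
pretrained = OrderedDict([
    ("conv1.weight", FakeTensor((64, 3, 7, 7), "pretrained")),
    ("bn1.weight", FakeTensor((64,), "pretrained")),
])

# Keys after the same layers were repacked into an index-based container.
wrapped = OrderedDict([
    ("0.weight", FakeTensor((64, 3, 7, 7), "random init")),
    ("1.weight", FakeTensor((64,), "random init")),
])

# Matching by name finds nothing, which is the problem described above.
matched_by_name = set(pretrained) & set(wrapped)

# Matching by position carries the pretrained values over to the new keys.
remapped = remap_by_position(pretrained, wrapped)
```

With real PyTorch objects the result could be handed to `load_state_dict` on the wrapped body, since real state dicts are ordered and their tensors expose `.shape`. The shape check is there to fail early if the two orderings do not actually line up.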
Here is a simple gist where I create a Learner from a custom model.
Also, when you use `cnn_learner`, you should specify `pretrained=False`, as it is `True` by default.
I revisited the docs and can confirm that `cnn_learner` takes `base_arch` as a `Callable`, while `Learner` takes `model` as a module. Your suggestion worked for me. Thank you, have a good day @NathanHub
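To make the distinction concrete, here is a small dependency-free sketch of the three kinds of object that got mixed up in this thread: an imported Python file (what caused the `AttributeError` above), an architecture function, and a model instance. All names in it (`TinyModel`, `tiny_arch`, `arch_file`, `kind`, `move_to_device`) are hypothetical stand-ins, not fastai or FixRes API.

```python
import inspect
import types


class TinyModel:
    """Stand-in for an nn.Module instance: it has a .to() method."""

    def to(self, device):
        self.device = device
        return self


def tiny_arch(pretrained=False):
    """Stand-in for an architecture function such as resnet50."""
    return TinyModel()


# Stand-in for an imported file such as resnext_wsl.py.
arch_file = types.ModuleType("resnext_wsl")
arch_file.tiny_arch = tiny_arch


def kind(obj):
    """Classify an object as a Python module, an architecture, or an instance."""
    if isinstance(obj, types.ModuleType):
        return "python module"
    if inspect.isfunction(obj) or inspect.isclass(obj):
        return "architecture (callable)"
    if hasattr(obj, "to"):
        return "model instance"
    return "unknown"


def move_to_device(model, device="cpu"):
    """Mimics the failing line in the traceback: model.to(device)."""
    return model.to(device)
```

Under these assumptions, `tiny_arch` is the kind of thing `cnn_learner` expects, `tiny_arch()` is the kind of thing `Learner` expects, and `arch_file` is neither: calling `move_to_device(arch_file)` raises the same "module ... has no attribute 'to'" error as in the traceback.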