Missing 'opt' argument, using a custom model

Hi Everyone,
I am trying to create a custom model for a semantic segmentation task, based on this article.

My custom model:

net = AttU_Net()
criterion = nn.CrossEntropyLoss()


def opt_func(params, **kwargs):
    """Build a fastai-compatible SGD optimizer over ``params``.

    fastai's ``OptimWrapper`` expects the parameters *and* the PyTorch
    optimizer class (its ``opt`` argument) and constructs the optimizer
    itself. Passing an already-built ``optim.SGD(...)`` instance as the only
    argument leaves ``opt`` unset, which raises
    ``TypeError: __init__() missing 1 required positional argument: 'opt'``.

    ``kwargs`` carries the hyper-parameters that ``Learner.create_opt``
    supplies (e.g. ``lr=self.lr``). They are forwarded rather than dropped, so
    the learner's learning rate is actually used.
    """
    return OptimWrapper(params, optim.SGD, **kwargs)


learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func)
# Set the learning rate via fit() so fastai's hyper-parameter handling
# applies it; a value hard-coded inside opt_func would bypass the learner.
learn.fit(5, lr=0.001)

When I run the code above, I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_33/895282974.py in <module>
      3 def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
      4 learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func)
----> 5 learn.fit(5)
      6 #learn = unet_learner(dls, resnet34, loss_func=CrossEntropyLossFlat(axis=1), metrics=[foreground_acc, cust_foreground_acc])
      7 # learn.lr_find()

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    214     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
    215         with self.added_cbs(cbs):
--> 216             if reset_opt or not self.opt: self.create_opt()
    217             if wd is None: wd = self.wd
    218             if wd is not None: self.opt.set_hypers(wd=wd)

/opt/conda/lib/python3.7/site-packages/fastai/learner.py in create_opt(self)
    150             if 'lr' in self.opt_func.keywords:
    151                 self.lr = self.opt_func.keywords['lr']
--> 152         self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
    153         if not self.wd_bn_bias:
    154             for p in self._bn_bias_state(True ): p['do_wd'] = False

/tmp/ipykernel_33/895282974.py in opt_func(params, **kwargs)
      1 net = AttU_Net()
      2 criterion = nn.CrossEntropyLoss()
----> 3 def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
      4 learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func)
      5 learn.fit(5)

TypeError: __init__() missing 1 required positional argument: 'opt'

59 / 5.000

Is this perhaps due to the custom architecture used?

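You should read the official port of my article in the fastai documentation. The error is not caused by your custom architecture. In current fastai, `OptimWrapper` takes the parameters *and* the PyTorch optimizer class (`opt`) and builds the optimizer itself, so you need to pass `opt` explicitly instead of an already-constructed `optim.SGD(...)`:

```python
opt_func = partial(OptimWrapper, opt=optim.SGD)
learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func)
learn.fit(5, lr=0.001)
```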

1 Like