Error in a Learner Subclass

I am working on instance segmentation in this other post.

However, I am facing the issue that the Torchvision Mask R-CNN forward method needs both images and targets during training, but only images during validation.

In addition, it outputs losses during training and detections during inference.

During training, the model expects both the input tensors and the targets (a list of dictionaries), each containing:

  • boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x between 0 and W and values of y between 0 and H
  • labels (Int64Tensor[N]): the class label for each ground-truth box
  • masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance

The model returns a Dict[Tensor] during training, containing the classification and regression losses for both the RPN and the R-CNN, and the mask loss.

During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows:

  • boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x between 0 and W and values of y between 0 and H
  • labels (Int64Tensor[N]): the predicted labels for each image
  • scores (Tensor[N]): the scores of each prediction
  • masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in the 0-1 range. To obtain the final segmentation masks, the soft masks can be thresholded, generally with a value of 0.5 (mask >= 0.5)
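
The two behaviours can be seen directly on the torchvision model; here is a minimal sketch with random data (shapes and labels are arbitrary):

import torch, torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
targets = [{'boxes':  torch.tensor([[10., 20., 100., 200.]]),
            'labels': torch.tensor([1], dtype=torch.int64),
            'masks':  torch.zeros(1, h, w, dtype=torch.uint8)}
           for h, w in [(300, 400), (500, 400)]]

model.train()
loss_dict = model(images, targets)            # Dict[str, Tensor] of losses

model.eval()
with torch.no_grad():
    detections = model(images)                # List[Dict[str, Tensor]], one per image
binary_masks = detections[0]['masks'] >= 0.5  # threshold the soft masks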

Owing to this fact, I decided to subclass Learner:

class Mask_RCNN_Learner(Learner):
    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                 moms=(0.95,0.85,0.95)):
        super().__init__(dls, model, loss_func, opt_func, lr, splitter, cbs,
                 metrics, path, model_dir, wd, wd_bn_bias, train_bn,
                 moms)
        
    def _do_epoch_train(self):
        try:
            self.dl = self.dls.train;                                     self('begin_train')
            
            # Modification
            self.n_iter = len(self.dl)
            for i, b in enumerate(self.dl):               # unpack directly; 'i, b = *o' is a SyntaxError
                self.iter = i
                try:
                    self._split(b);                                       self('begin_batch_train')
                    # training mode: forward needs images *and* targets, and returns the loss dict
                    loss_dict = self.model(*self.xb, *self.yb);       self('after_pred_train')
                    if len(self.yb) == 0: continue                   # no targets: skip this batch ('return' would end the epoch)
                    self.loss = sum(loss for loss in loss_dict.values()); self('after_loss_train')
                    if not self.training: return
                    self.loss.backward();                                 self('after_backward_train')
                    self.opt.step();                                      self('after_step_train')
                    self.opt.zero_grad()
                except CancelBatchException:                              self('after_cancel_batch_train')
                finally:                                                  self('after_batch_train')
            
        except CancelTrainException:                                      self('after_cancel_train')
        finally:                                                          self('after_train')

    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dls[ds_idx]
        try:
            self.dl = dl;                                                 self('begin_validate')
            with torch.no_grad():
                # Modification
                self.n_iter = len(self.dl)
                for i, b in enumerate(self.dl):           # unpack directly, as in the training loop
                    self.iter = i
                    try:
                        self._split(b);                                   self('begin_batch_validate')
                        # eval mode: forward takes only images and returns detections
                        self.pred = self.model(*self.xb);             self('after_pred_validate')
                        self.loss = self.loss_func(self.pred, *self.yb); self('after_loss_validate')

                        # COMPUTING METRICS
                    except CancelBatchException:                          self('after_cancel_batch_validate')
                    finally:                                          self('after_batch_validate')
                
        except CancelValidException:                                      self('after_cancel_validate')
        finally:                                                          self('after_validate')
    
    @log_args(but='cbs')
    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            if wd is None: wd = self.wd
            if wd is not None: self.opt.set_hypers(wd=wd)
            self.opt.set_hypers(lr=self.lr if lr is None else lr)

            try:
                self._do_begin_fit(n_epoch)
                for epoch in range(n_epoch):
                    try:
                        self.epoch=epoch;          self('begin_epoch')
                        self._do_epoch_train()
                        self._do_epoch_validate()
                    except CancelEpochException:   self('after_cancel_epoch')
                    finally:                       self('after_epoch')

            except CancelFitException:             self('after_cancel_fit')
            finally:                               self('after_fit') 

The problem I am facing is that I don’t know whether this is the proper way of doing it.

My first version unpacked each batch with i, b = *o; when I executed a cell with that code, I got the following error:

 File "<ipython-input-33-447bef347370>", line 19
    self._split(b);                                  self('begin_batch_train')
              ^
SyntaxError: can't use starred expression here
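
A bare starred expression is only valid in contexts such as argument lists or extended unpacking targets (i, *rest = o), so the pair coming out of enumerate has to be unpacked either as a plain tuple or directly in the for statement, as in the corrected loops above:

for o in enumerate(dl):
    i, b = o                  # plain tuple unpacking; i, b = *o is a SyntaxError

# or, more idiomatically, unpack in the for statement itself:
for i, b in enumerate(dl):
    pass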

In addition, I don’t know how to control what inputs are passed to the metrics parameter of the Learner __init__ method.
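
From reading the fastai source, my understanding is that a plain-function metric is called as metric(learn.pred, *learn.yb) during validation, so whatever ends up in self.pred is what the metric receives. A toy sketch (the metric itself is made up, and it assumes self.pred holds the List[Dict[Tensor]] detections produced by the validation loop above):

# hypothetical metric: average number of predicted instances per image
def avg_num_detections(preds, *targs):
    return sum(len(p['boxes']) for p in preds) / len(preds)

learn = Mask_RCNN_Learner(dls, model, metrics=avg_num_detections)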


You may use a callback to drop or modify the output of the model according to the state (training/inference).

Example from the transformers tutorial.

class DropOutput(Callback):
    def after_pred(self): self.learn.pred = self.pred[0]

https://github.com/fastai/fastai2/blob/master/nbs/39_tutorial.transformers.ipynb
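
The same idea should extend to Mask R-CNN. A rough sketch (untested; MaskRCNNAdapter and passthrough_loss are illustrative names, not fastai API, and it assumes the dataloaders yield (images, targets) batches):

class MaskRCNNAdapter(Callback):
    def begin_batch(self):
        # Mask R-CNN wants the targets at forward time during training,
        # so append them to xb; in validation xb stays (images,)
        if self.training: self.learn.xb = self.xb + self.yb
    def after_pred(self):
        # in training mode the model's output *is* the loss dict; reduce it
        if self.training: self.learn.pred = sum(self.pred.values())

def passthrough_loss(pred, *yb):
    # during training the adapter already produced the scalar loss
    return pred

learn = Learner(dls, model, loss_func=passthrough_loss, cbs=MaskRCNNAdapter())

Validation would still need separate handling, since in eval mode the model returns detections rather than losses.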