Trainable ensemble learner for image segmentation

Dear All,

I’m trying to build a model that combines the predictions of several segmentation models. The code below takes a list of model paths, loads the corresponding models into self.models, and freezes them by setting requires_grad=False on their parameters. The forward method runs each of these models to get predictions, which are stacked and then passed on to the combiner unet_learner. Is this a sensible approach? Also, could you please point out what I am doing wrong, so I can get rid of the error posted at the end of this post? The class looks like this:

class EnsembleModel(nn.Module):
    def __init__(self, model_paths):
        super(EnsembleModel, self).__init__()
        self.models = []
        
        # Load the ensemble models using the file names as paths
        for path in model_paths:
            learn = unet_learner(
                dls,
                partial(timm.create_model, path),
                metrics=[foreground_acc, DiceMulti()],
                self_attention=True,
                loss_func=FocalLossFlat(axis=1)
            ).to_fp16()

            learn.load(path)
            self.models.append(learn.model)
        
        # Freeze the parameters of individual models
        for model in self.models:
            for param in model.parameters():
                param.requires_grad = False
        
        # Define the combiner model
        self.combiner = unet_learner(
            dls,
            resnet18,
            metrics=[foreground_acc, DiceMulti()],
            self_attention=True,
            loss_func=FocalLossFlat(axis=1)
        )
        
    def forward(self, x):
        ensemble_masks = []
        
        # Generate masks from each ensemble model
        for model in self.models:
            mask = model(x)
            # Remove the class dimension using argmax
            mask = torch.argmax(mask, dim=1)
            ensemble_masks.append(mask)
        
        # Combine the ensemble predictions
        combined_pred = torch.stack(ensemble_masks, dim=1)
        
        # Pass the combined predictions through the combiner model
        final_output = self.combiner.model(combined_pred)
        
        return final_output
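
In case it helps, this is how I expect the tensor shapes to evolve inside forward (assuming dls yields image batches of shape (bs, 3, H, W) and each saved model predicts n_classes channels):

# Shape walkthrough for forward(), under the assumptions above:
#   model(x)                           -> (bs, n_classes, H, W)  raw logits per model
#   torch.argmax(mask, dim=1)          -> (bs, H, W)             integer class ids
#   torch.stack(ensemble_masks, dim=1) -> (bs, n_models, H, W)   what self.combiner.model receives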

The ensemble class is then used for training: I create a unet_learner and replace its model with the ensemble before fitting:

# Step 2: Create ensemble learner
ensemble_model = EnsembleModel(['convnext_small_in22k', 'regnetx_032'])
learner = unet_learner(dls, resnet34, metrics=[foreground_acc, DiceMulti()], self_attention=True)
learner.model = ensemble_model
learner.fit_one_cycle(10, 1e-3)

The code generates the following error:

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In [7], line 2
      1 # Train the ensemble learner
----> 2 learner.fit_one_cycle(10, 1e-3)

File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/schedule.py:119, in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
    116 lr_max = np.array([h['lr'] for h in self.opt.hypers])
    117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    118           'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:264, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
    262 self.opt.set_hypers(lr=self.lr if lr is None else lr)
    263 self.n_epoch = n_epoch
--> 264 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:253, in Learner._do_fit(self)
    251 for epoch in range(self.n_epoch):
    252     self.epoch=epoch
--> 253     self._with_events(self._do_epoch, 'epoch', CancelEpochException)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:247, in Learner._do_epoch(self)
    246 def _do_epoch(self):
--> 247     self._do_epoch_train()
    248     self._do_epoch_validate()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:239, in Learner._do_epoch_train(self)
    237 def _do_epoch_train(self):
    238     self.dl = self.dls.train
--> 239     self._with_events(self.all_batches, 'train', CancelTrainException)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
    198 def _with_events(self, f, event_type, ex, final=noop):
--> 199     try: self(f'before_{event_type}');  f()
    200     except ex: self(f'after_cancel_{event_type}')
    201     self(f'after_{event_type}');  final()

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:205, in Learner.all_batches(self)
    203 def all_batches(self):
    204     self.n_iter = len(self.dl)
--> 205     for o in enumerate(self.dl): self.one_batch(*o)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:233, in Learner.one_batch(self, i, b)
    231 def one_batch(self, i, b):
    232     self.iter = i
--> 233     b = self._set_device(b)
    234     self._split(b)
    235     self._with_events(self._do_one_batch, 'batch', CancelBatchException)

File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:226, in Learner._set_device(self, b)
    225 def _set_device(self, b):
--> 226     model_device = next(self.model.parameters()).device
    227     dls_device = getattr(self.dls, 'device', default_device())
    228     if model_device == dls_device: return to_device(b, dls_device)

StopIteration: 
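
From the traceback, the failing call is next(self.model.parameters()) inside Learner._set_device, so it looks as though the Learner does not see any parameters on my EnsembleModel (calling next() on an empty parameter iterator raises exactly this StopIteration). A minimal check I could run, assuming the objects above are in scope (n_registered is just a throwaway name):

# Count the parameters PyTorch has registered on the ensemble module;
# if this prints 0, next(ensemble_model.parameters()) will raise StopIteration.
n_registered = sum(1 for _ in ensemble_model.parameters())
print(n_registered)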