Dear All,
I’m trying to create a model to combine the predictions from three segmentation models. The code below takes in model paths, load in self.models, mark them not to train via requires_grad=False
. The forward
method calls these models to get predictions which are collated and eventually passed onto the combiner unet_learner
. Is this a sensible approach? Besides, can you please guide what am I doing wrong to get rid of the error posted at the end of the post:
class EnsembleModel(nn.Module):
def __init__(self, model_paths):
super(EnsembleModel, self).__init__()
self.models = []
# Load the ensemble models using the file names as paths
for path in model_paths:
learn = unet_learner(
dls,
partial(timm.create_model, path),
metrics=[foreground_acc, DiceMulti()],
self_attention=True,
loss_func=FocalLossFlat(axis=1)
).to_fp16()
learn.load(path)
self.models.append(learn.model)
# Freeze the parameters of individual models
for model in self.models:
for param in model.parameters():
param.requires_grad = False
# Define the combiner model
self.combiner = unet_learner(
dls,
resnet18,
metrics=[foreground_acc, DiceMulti()],
self_attention=True,
loss_func=FocalLossFlat(axis=1)
)
def forward(self, x):
ensemble_masks = []
# Generate masks from each ensemble model
for model in self.models:
mask = model(x)
# Remove the class dimension using argmax
mask = torch.argmax(mask, dim=1)
ensemble_masks.append(mask)
# Combine the ensemble predictions
combined_pred = torch.stack(ensemble_masks, dim=1)
# Pass the combined predictions through the combiner model
final_output = self.combiner.model(combined_pred)
return final_output
The ensemble class is used for training the model below:
# Step 2: Create ensemble learner
ensemble_model = EnsembleModel(['convnext_small_in22k', 'regnetx_032'])
learner = unet_learner(dls, resnet34, metrics=[foreground_acc, DiceMulti()], self_attention=True)
learner.model = ensemble_model
learner.fit_one_cycle(10, 1e-3)
The code generate the following error:
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
Cell In [7], line 2
1 # Train the ensemble learner
----> 2 learner.fit_one_cycle(10, 1e-3)
File ~/mambaforge/lib/python3.10/site-packages/fastai/callback/schedule.py:119, in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt, start_epoch)
116 lr_max = np.array([h['lr'] for h in self.opt.hypers])
117 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
118 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 119 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=start_epoch)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:264, in Learner.fit(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)
262 self.opt.set_hypers(lr=self.lr if lr is None else lr)
263 self.n_epoch = n_epoch
--> 264 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:253, in Learner._do_fit(self)
251 for epoch in range(self.n_epoch):
252 self.epoch=epoch
--> 253 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:247, in Learner._do_epoch(self)
246 def _do_epoch(self):
--> 247 self._do_epoch_train()
248 self._do_epoch_validate()
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:239, in Learner._do_epoch_train(self)
237 def _do_epoch_train(self):
238 self.dl = self.dls.train
--> 239 self._with_events(self.all_batches, 'train', CancelTrainException)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:205, in Learner.all_batches(self)
203 def all_batches(self):
204 self.n_iter = len(self.dl)
--> 205 for o in enumerate(self.dl): self.one_batch(*o)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:233, in Learner.one_batch(self, i, b)
231 def one_batch(self, i, b):
232 self.iter = i
--> 233 b = self._set_device(b)
234 self._split(b)
235 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
File ~/mambaforge/lib/python3.10/site-packages/fastai/learner.py:226, in Learner._set_device(self, b)
225 def _set_device(self, b):
--> 226 model_device = next(self.model.parameters()).device
227 dls_device = getattr(self.dls, 'device', default_device())
228 if model_device == dls_device: return to_device(b, dls_device)
StopIteration: