How to reload and train the learner model for new data

Hey everyone, I have trained my RetinaNet model on some data with 3 classes (including background). Now I want to use that learner model to train on a new dataset, which loads successfully but has only 2 classes (including background).
How can I train my previously saved learner model for this new data?
This is the code for the training and loading data:

train_images = list(np.random.choice(training_set, train_samples_per_scanner))
print('training_images =',len(train_images))
valid_images = list(np.random.choice(valid_set, val_samples_per_scanner))
print('validation_images =',len(valid_images))

batch_size = 64

do_flip = True
flip_vert = True 
max_rotate = 90 
max_zoom = 1.1 
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75 
p_lighting = 0.75 

tfms = get_transforms(do_flip=do_flip,

train, valid = ObjectItemListSlide(train_images), ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate,num_workers=0).normalize()
model = RetinaNet(encoder, n_classes=data.train_ds.c, 
                  n_anchors=len(scales) * len(ratios), 
                  sizes=[size[0] for size in sizes], 
                  chs=channels, # number of hidden layers for the classification head
                  n_conv=n_conv # Number of hidden layers

voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
learn = Learner(data, model, loss_func=crit, 

I saved this model's state_dict using torch.save as follows: torch.save(learn.model.state_dict(), PATH)

When I try to load this model for new data it shows an error:



/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for RetinaNet:
	size mismatch for classifier.3.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([2, 128, 3, 3]).
	size mismatch for classifier.3.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).

How can I make my older model state_dict work for present data?
Also, I tried using:,PATH)
It loads without any error, but when I start training, it throws an error in the first iteration:

max_learning_rate = 1e-3
cyc_len = 50
learn.fit_one_cycle(cyc_len, max_learning_rate,callbacks=[SaveModelCallback(learn, monitor='train_loss', name='best_loss')])

The error

     13 learn.fit_one_cycle(cyc_len, max_learning_rate,callbacks=[SaveModelCallback(learn, monitor='train_loss',
---> 14                                                                             name='model`')])

7 frames
/usr/local/lib/python3.7/dist-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
     21     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
     22                                        final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 23     learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
     25 def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr,  moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         else: self.opt.lr,self.opt.wd = lr,wd
    199         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
    104             if not cb_handler.skip_validate and not learn.data.empty_val:
    105                 val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
--> 106                                        cb_handler=cb_handler, pbar=pbar)
    107             else: val_loss=None
    108             if cb_handler.on_epoch_end(val_loss): break

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in validate(model, dl, loss_func, cb_handler, pbar, average, n_batch)
     61             if not is_listy(yb): yb = [yb]
     62             nums.append(first_el(yb).shape[0])
---> 63             if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
     64             if n_batch and (len(nums)>=n_batch): break
     65         nums = np.array(nums, dtype=np.float32)

/usr/local/lib/python3.7/dist-packages/fastai/callback.py in on_batch_end(self, loss)
    306         "Handle end of processing one batch with `loss`."
    307         self.state_dict['last_loss'] = loss
--> 308         self('batch_end', call_mets = not self.state_dict['train'])
    309         if self.state_dict['train']:
    310             self.state_dict['iteration'] += 1

/usr/local/lib/python3.7/dist-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
    248         "Call through to all of the `CallbakHandler` functions."
    249         if call_mets:
--> 250             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
    251         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

/usr/local/lib/python3.7/dist-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
    239     def _call_and_update(self, cb, cb_name, **kwargs)->None:
    240         "Call `cb_name` on `cb` and update the inner state."
--> 241         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    242         for k,v in new.items():
    243             if k not in self.state_dict:

/usr/local/lib/python3.7/dist-packages/object_detection_fastai/callbacks/callbacks.py in on_batch_end(self, last_output, last_target, **kwargs)
    153             num_boxes = len(bbox_gt) * 3
    154             for box, cla, scor in list(zip(bbox_pred, preds, scores))[:num_boxes]:
--> 155                 temp = BoundingBox(imageName=str(self.imageCounter), classId=self.metric_names_original[cla], x=box[0], y=box[1],
    156                                    w=box[2], h=box[3], typeCoordinates=CoordinatesType.Absolute, classConfidence=scor,
    157                                    bbType=BBType.Detected, format=BBFormat.XYWH, imgSize=(self.size, self.size))

IndexError: list index out of range

Does anyone know where it is going wrong, and how I can use the model trained on 3 classes for the 2-class data?
@jeremy @sgugger tagging for support.
Thank you all in advance for this forum.

Don’t do that.

1 Like

Okay, sorry for mentioning.