Hey everyone, I have trained my RetinaNet model on a dataset with 3 classes (including background). Now I want to use that trained learner on a new dataset, which loads successfully but has only 2 classes (including background).
How can I continue training my previously saved learner model on this new data?
This is the code for loading the data and training:
# Training pipeline: sample slide images, define augmentations, build the
# DataBunch, the RetinaNet model, the Pascal VOC metric and the Learner.
# NOTE(review): np.random.choice samples with replacement by default
# (replace=True), so train_images / valid_images may contain duplicates --
# confirm this is intended.
train_images = list(np.random.choice(training_set, train_samples_per_scanner))
print('training_images =',len(train_images))
valid_images = list(np.random.choice(valid_set, val_samples_per_scanner))
print('validation_images =',len(valid_images))
# Batch size used when the DataBunch is created below (bs=batch_size).
batch_size = 64
# Augmentation settings, forwarded unchanged to get_transforms.
do_flip = True
flip_vert = True
max_rotate = 90
max_zoom = 1.1
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75
p_lighting = 0.75
tfms = get_transforms(do_flip=do_flip,
flip_vert=flip_vert,
max_rotate=max_rotate,
max_zoom=max_zoom,
max_lighting=max_lighting,
max_warp=max_warp,
p_affine=p_affine,
p_lighting=p_lighting)
# Wrap the sampled images in train/valid item lists; each item's label is
# read from its `y` attribute.
train, valid = ObjectItemListSlide(train_images), ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
# tfm_y=True so the targets are transformed together with the images.
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
# bb_pad_collate: collate function for bounding-box targets (presumably pads
# the per-image box lists to a common length -- confirm).
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate,num_workers=0).normalize()
# The classification head is sized from the dataset's class count
# (data.train_ds.c). A state_dict saved from a model built on 3-class data
# therefore has a differently shaped head (classifier.3.weight / .bias) than
# a model built on 2-class data; that is the size mismatch reported in this post.
model = RetinaNet(encoder, n_classes=data.train_ds.c,
n_anchors=len(scales) * len(ratios),
sizes=[size[0] for size in sizes],
chs=channels, # presumably the channel width of the head's hidden convs (n_conv sets their count) -- confirm
final_bias=final_bias,
n_conv=n_conv # Number of hidden layers
)
# Metric class names: every class except index 0 (presumably background).
voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
# Bare expression: only displays `voc` when run as a notebook cell.
voc
# callback_fns are instantiated with the learner at fit time (cb(self)).
learn = Learner(data, model, loss_func=crit,
callback_fns=[BBMetrics,ShowGraph],metrics=[voc])
I saved this model using `torch.save` as follows:
# Save only the model's state_dict (its parameters and buffers), not the Learner.
model_weights = learn.model.state_dict()
torch.save(model_weights, PATH)
When I try to load this model for the new data, it throws an error:
# NOTE(review): this raises the size-mismatch RuntimeError shown below. The
# checkpoint's classification head (classifier.3.weight / classifier.3.bias)
# was saved from a model built for 3 classes, but the current model was built
# with n_classes=data.train_ds.c for the 2-class data ([3, ...] vs [2, ...]).
# To reuse all other weights, load only the tensors whose shapes match and
# pass strict=False so the skipped head keys are tolerated, e.g.:
#   ckpt = torch.load(PATH)
#   own = learn.model.state_dict()
#   ckpt = {k: v for k, v in ckpt.items() if k in own and v.shape == own[k].shape}
#   learn.model.load_state_dict(ckpt, strict=False)
# strict=False on its own is not enough: it ignores missing/unexpected keys
# but still rejects shape mismatches.
learn.model.load_state_dict(torch.load(PATH))
Error:
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1496 if len(error_msgs) > 0:
1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
1500
RuntimeError: Error(s) in loading state_dict for RetinaNet:
size mismatch for classifier.3.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([2, 128, 3, 3]).
size mismatch for classifier.3.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).
How can I make my older model's state_dict work for the present data?
I also tried saving the whole model instead:
# Pickle the whole nn.Module (architecture + weights), not just its state_dict.
# Fixed: the module is `torch` (`Torch` is a NameError), and the Learner built
# above is named `learn`.
# NOTE(review): the unpickled module keeps the classification head it was
# built with (3 outputs for the original data) and does not adapt to the
# 2-class data; this is a likely cause of the metric IndexError below -- confirm.
torch.save(learn.model, PATH)
It loads without any error, but when I call `learn.fit_one_cycle`, it throws an error during the first epoch:
# Hyperparameters for the one-cycle run below: cyc_len is forwarded to
# learn.fit as the number of epochs, max_learning_rate as the peak LR.
max_learning_rate = 1e-3
cyc_len = 50
# NOTE(review): this value is not passed to fit_one_cycle. The DataBunch was
# already built with bs=batch_size earlier, so this line does not change the
# training batch size unless `data` is rebuilt afterwards.
batch_size=16
# SaveModelCallback monitors 'train_loss' and saves under the name 'best_loss'.
# NOTE(review): the IndexError reported below is raised by the PascalVOC
# metric callback during validation (self.metric_names_original[cla]). A likely
# cause is a model head predicting more classes than the metric's class-name
# list built from the current data -- confirm.
learn.fit_one_cycle(cyc_len, max_learning_rate,callbacks=[SaveModelCallback(learn, monitor='train_loss', name='best_loss')])
The error:
13 learn.fit_one_cycle(cyc_len, max_learning_rate,callbacks=[SaveModelCallback(learn, monitor='train_loss',
---> 14 name='model`')])
7 frames
/usr/local/lib/python3.7/dist-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
21 callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
22 final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 23 learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
24
25 def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr, moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,
/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
198 else: self.opt.lr,self.opt.wd = lr,wd
199 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
--> 200 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
201
202 def create_opt(self, lr:Floats, wd:Floats=0.)->None:
/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
104 if not cb_handler.skip_validate and not learn.data.empty_val:
105 val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
--> 106 cb_handler=cb_handler, pbar=pbar)
107 else: val_loss=None
108 if cb_handler.on_epoch_end(val_loss): break
/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in validate(model, dl, loss_func, cb_handler, pbar, average, n_batch)
61 if not is_listy(yb): yb = [yb]
62 nums.append(first_el(yb).shape[0])
---> 63 if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
64 if n_batch and (len(nums)>=n_batch): break
65 nums = np.array(nums, dtype=np.float32)
/usr/local/lib/python3.7/dist-packages/fastai/callback.py in on_batch_end(self, loss)
306 "Handle end of processing one batch with `loss`."
307 self.state_dict['last_loss'] = loss
--> 308 self('batch_end', call_mets = not self.state_dict['train'])
309 if self.state_dict['train']:
310 self.state_dict['iteration'] += 1
/usr/local/lib/python3.7/dist-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
248 "Call through to all of the `CallbakHandler` functions."
249 if call_mets:
--> 250 for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
251 for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
252
/usr/local/lib/python3.7/dist-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
239 def _call_and_update(self, cb, cb_name, **kwargs)->None:
240 "Call `cb_name` on `cb` and update the inner state."
--> 241 new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
242 for k,v in new.items():
243 if k not in self.state_dict:
/usr/local/lib/python3.7/dist-packages/object_detection_fastai/callbacks/callbacks.py in on_batch_end(self, last_output, last_target, **kwargs)
153 num_boxes = len(bbox_gt) * 3
154 for box, cla, scor in list(zip(bbox_pred, preds, scores))[:num_boxes]:
--> 155 temp = BoundingBox(imageName=str(self.imageCounter), classId=self.metric_names_original[cla], x=box[0], y=box[1],
156 w=box[2], h=box[3], typeCoordinates=CoordinatesType.Absolute, classConfidence=scor,
157 bbType=BBType.Detected, format=BBFormat.XYWH, imgSize=(self.size, self.size))
IndexError: list index out of range
Does anyone know where it is going wrong, and how I can make the model trained on 3 classes work with 2-class data?
@jeremy @sgugger tagging for support.
Thank you all in advance for this forum.
Harshit