I’m playing around with the PlantCLEF 2016 dataset. My model has trouble telling some of the classes apart, so I’ve decided to try training in two stages.
First I train the model on all classes except the ones that are hard to tell apart, then I save the model.
Then I’d like to train on the whole dataset, but I’m getting a tensor size mismatch (obviously, since the second dataset has more classes):
RuntimeError: Error(s) in loading state_dict for Sequential:
size mismatch for 1.8.weight: copying a param with shape torch.Size([934, 512]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
size mismatch for 1.8.bias: copying a param with shape torch.Size([934]) from checkpoint, the shape in current model is torch.Size([1000]).
Code (excluded_classes is a list of class names; I didn’t include it because it’s long):
Edit: I missed exactly what you were doing.
As @ste says, you should reserve room in your model for the extra classes. You can then probably also avoid calling cnn_learner again and instead just update the data inside the learner (or construct a basic Learner from the model and the new data). The model from your first training run has to be identical to the one in your second for the weights to load, so cnn_learner isn’t needed the second time.
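A minimal PyTorch sketch of why the weights only load when both models are identical. It assumes the 512-feature input and the 934/1000 class counts from the error message; `make_head` is a hypothetical helper standing in for the final linear layer that cnn_learner sizes from the number of classes.

```python
import torch
import torch.nn as nn

N_FEATURES = 512      # input size of the final layer, from the error message
N_ALL_CLASSES = 1000
N_EASY_CLASSES = 934

def make_head(n_out: int) -> nn.Linear:
    # Hypothetical stand-in for the last layer of the cnn_learner model
    return nn.Linear(N_FEATURES, n_out)

# Stage 1 sized only for the easy classes: its weights cannot be loaded later.
easy_only = make_head(N_EASY_CLASSES)
try:
    make_head(N_ALL_CLASSES).load_state_dict(easy_only.state_dict())
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = "size mismatch" in str(err)

# Stage 1 with room reserved for every class: identical shapes, so loading works.
reserved = make_head(N_ALL_CLASSES)
stage2 = make_head(N_ALL_CLASSES)
stage2.load_state_dict(reserved.state_dict())
weights_match = torch.equal(stage2.weight, reserved.weight)

print(mismatch_raised, weights_match)
```

The reserved head simply never sees examples of the excluded classes during the first stage; its extra outputs are there so that the second stage can start from the same weights.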
Create a list all_classes and pass it as the classes parameter in both calls to label_from_func. You’re still passing different data to the two steps, but your model reserves room for the future classes (data.c should be the same in both passes; otherwise you’d have to tweak your weights yourself to make room for the new classes…)
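To see why the same class list has to go into both labelling calls, here is a plain-Python sketch of a class-to-index mapping. The class names and the `build_vocab` helper are hypothetical, and this only illustrates the idea; it is not fastai’s implementation.

```python
from typing import Dict, Iterable, List

def build_vocab(labels: Iterable[str]) -> Dict[str, int]:
    # Map each class name to an output index, in sorted order
    return {name: i for i, name in enumerate(sorted(set(labels)))}

all_classes: List[str] = ["acer", "betula", "quercus", "salix"]  # hypothetical names
excluded_classes: List[str] = ["betula"]                         # the hard-to-separate ones
easy_labels = [c for c in all_classes if c not in excluded_classes]

# Vocab inferred from the easy data alone: fewer outputs and shifted indices
inferred = build_vocab(easy_labels)
# Vocab built from all_classes in both passes: same size and indices every time
shared = build_vocab(all_classes)

print(len(inferred), len(shared))                 # 3 4
print(inferred["quercus"], shared["quercus"])     # 1 2
```

With the inferred vocabulary, the number of outputs (data.c) changes between passes and "quercus" would move from output 1 to output 2; with the shared one, both stay fixed.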
So I pass all classes, but not the data for the excluded classes? That post is where I got the idea in the first place.
Is there an easy way to filter out the files belonging to given classes from the dataset?
I’ve tried something like this, but I still can’t overwrite train_ds in the DataBunch:
train_ds = [item for item in data.train_ds if str(item) not in excluded_classes]
all_items = ImageList.from_folder(path_img)
easy_items = ImageList.from_folder(path_img).filter_by_func(lambda x: get_label_from_xml(x) not in excluded_classes)
(Note that filter_by_func works in place, so you need to apply it to a new ImageList.)
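The filtering step and the in-place caveat can be illustrated without fastai. `ToyItemList` is a hypothetical stand-in that mimics an in-place `filter_by_func` (it keeps the items for which the function returns True, modifies itself and returns itself); `label_of` plays the role of get_label_from_xml, and the file names are made up.

```python
from typing import Callable, List

class ToyItemList:
    """Hypothetical stand-in for an ImageList whose filter works in place."""
    def __init__(self, items: List[str]):
        self.items = list(items)

    def filter_by_func(self, func: Callable[[str], bool]) -> "ToyItemList":
        # Keep only the items for which func returns True, mutating self
        self.items = [o for o in self.items if func(o)]
        return self

def label_of(path: str) -> str:
    # Hypothetical labeller: the class name is the part before the first "_"
    return path.split("_", 1)[0]

files = ["acer_1.jpg", "betula_1.jpg", "quercus_1.jpg", "betula_2.jpg"]
excluded_classes = {"betula"}

# Pitfall: filtering the same object you keep as "all items" shrinks it too
all_items = ToyItemList(files)
easy_items = all_items.filter_by_func(lambda p: label_of(p) not in excluded_classes)
print(len(all_items.items))                          # 2

# Fix: build a fresh list for the filtered copy
all_items = ToyItemList(files)
easy_items = ToyItemList(files).filter_by_func(lambda p: label_of(p) not in excluded_classes)
print(len(all_items.items), len(easy_items.items))   # 4 2
```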
Construct a DataBunch from each set of items (split, label, making sure to use all classes for the partial set as well, normalise, etc.). Now you have two DataBunches, easy_data and all_data. Use cnn_learner(easy_data, arch) to create the learner and train it for a bit on just the easy data. Then, rather than constructing a new cnn_learner, do learn.data = all_data as in that post, and train some more, now with all the data but the same model.
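Outside fastai, the same schedule looks roughly like this in plain PyTorch: one model whose head is sized for all classes, trained first on a loader containing only the easy classes and then, without rebuilding the model, on a loader with every class. This is only an analogue of learn.data = all_data, not fastai’s code; the tiny linear model, the random tensors, the `EXCLUDED` indices and the `train` helper are hypothetical placeholders for the real CNN, images and classes.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
N_FEATURES, N_ALL_CLASSES = 16, 10
EXCLUDED = [3, 7]  # hypothetical indices of the hard-to-separate classes

# Synthetic stand-in for images; labels index the shared all_classes vocabulary
x = torch.randn(200, N_FEATURES)
y = torch.randint(0, N_ALL_CLASSES, (200,))
easy_mask = torch.ones_like(y, dtype=torch.bool)
for c in EXCLUDED:
    easy_mask &= y != c

all_data = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
easy_data = DataLoader(TensorDataset(x[easy_mask], y[easy_mask]), batch_size=32, shuffle=True)

# One model, head sized for every class from the start
model = nn.Linear(N_FEATURES, N_ALL_CLASSES)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def train(loader: DataLoader, epochs: int) -> float:
    # Plain training loop over the given loader; returns the last batch loss
    for _ in range(epochs):
        for xb, yb in loader:
            loss = loss_fn(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return loss.item()

easy_loss = train(easy_data, epochs=3)  # stage 1: easy classes only
all_loss = train(all_data, epochs=3)    # stage 2: same model, all classes
```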
You might also want to try unfreezing after training on easy_data for a bit and training some more, then switching to all_data and freezing again, and repeating the train-unfreeze-train cycle on all_data.
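In fastai, learn.freeze() and learn.unfreeze() toggle whether the pretrained body is trained. A rough plain-PyTorch approximation of the suggested schedule, assuming a hypothetical two-part model (a `body` standing in for the pretrained CNN and a `head` sized for all classes) and a hypothetical `set_trainable` helper, might look like this. fastai’s version also leaves BatchNorm layers trainable by default, which this sketch ignores.

```python
import torch.nn as nn

N_FEATURES, N_ALL_CLASSES = 512, 1000

model = nn.Sequential(
    nn.Sequential(nn.Linear(64, N_FEATURES), nn.ReLU()),  # "body": stand-in for the pretrained CNN
    nn.Linear(N_FEATURES, N_ALL_CLASSES),                 # "head": room for every class
)
body, head = model[0], model[1]

def set_trainable(module: nn.Module, trainable: bool) -> None:
    # Freezing here means: stop computing gradients for these parameters
    for p in module.parameters():
        p.requires_grad = trainable

def n_trainable(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

# Schedule from the answer above: (which data, is the body trainable?)
schedule = [
    ("easy_data", False),  # train the head on the easy classes
    ("easy_data", True),   # unfreeze, train some more
    ("all_data", False),   # switch to all classes, freeze again
    ("all_data", True),    # unfreeze and train once more
]
counts = []
for data_name, body_trainable in schedule:
    set_trainable(body, body_trainable)
    set_trainable(head, True)
    counts.append(n_trainable(model))
    print(data_name, counts[-1])
    # ...run the training loop on data_name here...
```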