I was thinking of using Fastai (v2) to experiment with self-training/self-distillation, where I use a trained classifier’s predicted probabilities as pseudo-labels to train a student network on. This seems to be a pretty promising area for improving a model’s generalization performance. However, I’m not sure how to set the predicted probabilities as new labels. I assume I will have to create a custom dataloader for this. Does anyone have any resources for doing this in either fastai v1 or v2?
I recently used knowledge distillation on MobileNetV2 to increase accuracy from 84% to 90%! I used fastai v1 for this. First, I trained a ResNet50. Then I looped over all the images and created a Python dict. The keys of that dict are filenames and the values are the output logits from the ResNet50 (taken before softmax — this is not technically knowledge distillation as in Geoffrey Hinton’s paper, which uses a temperature with softmax, but using only logits worked quite well). Assuming the output variable stores the logits, use output.numpy().tolist() to convert it to a Python list. I then stored this dict in a JSON file.
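A rough sketch of that export step as a plain PyTorch loop (the helper name `load_image_tensor` and the file path are placeholders, not the author’s actual code):

```python
import json
import torch

# Hypothetical sketch: export a trained teacher's raw logits (pre-softmax)
# to a JSON file keyed by filename, as described above.
def export_teacher_logits(teacher, filenames, load_image_tensor, out_path):
    teacher.eval()
    logits_by_file = {}
    with torch.no_grad():
        for fname in filenames:
            x = load_image_tensor(fname)          # shape: (1, C, H, W)
            output = teacher(x)                   # raw logits, pre-softmax
            logits_by_file[fname] = output.squeeze(0).numpy().tolist()
    with open(out_path, "w") as f:
        json.dump(logits_by_file, f)
    return logits_by_file
```

In practice you would batch the images rather than running them one at a time, but the stored format is the same.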
For student training, I wrote a method that accepts a filename and returns the ResNet50 logits for it. Then I used label_from_func, passed that method, and used TensorDataSet as the label type. Finally, I changed the loss function to MSE and trained MobileNetV2.
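A minimal sketch of such a label method, assuming the JSON file from the previous step (the path and function names are illustrative, not the author’s code). In fastai v1, the returned function would be the one passed to `label_from_func`:

```python
import json
import torch

# Hypothetical sketch: build a filename -> teacher-logits label function
# from the stored JSON dict, so the student can be trained as a
# regression problem against the teacher's logits.
def make_label_func(json_path):
    with open(json_path) as f:
        teacher_logits = json.load(f)   # {filename: [logit_1, ..., logit_n]}
    def logits_for(fname):
        return torch.tensor(teacher_logits[fname], dtype=torch.float32)
    return logits_for
```

With the fastai v1 data block API this would be used something like `.label_from_func(make_label_func('teacher_logits.json'), ...)`, with a float/tensor label class so fastai treats it as regression.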
I also tried using an ensemble of ResNet34 and ResNet50, took the average of their logits to train MobileNetV2, and got an additional 2% increase in accuracy on my dataset.
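The ensemble variant amounts to averaging the teachers’ logits element-wise before storing them as targets — a small sketch, assuming each model is a callable returning logits of the same shape:

```python
import torch

# Sketch: run each teacher on the same batch and average their logits
# element-wise; the result is stored as the student's regression target.
def ensemble_logits(models, x):
    with torch.no_grad():
        outputs = [m(x) for m in models]      # each: (batch, n_classes)
    return torch.stack(outputs).mean(dim=0)   # average over the ensemble
```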
That’s really cool! Do you have any code you may be able to share?
I don’t think I can share the code because it’s from my workplace. I am planning to write a basic blog post about it soon, after studying a bit more theory. Till then, try this approach and comment if you face any issues.
So the Python dict of logits you’re storing to use as pseudo-labels, is it something like

```python
dict[filename1] = [logit_1, logit_2, ..., logit_n]
```
Also, why did you use MSE as the loss function? I believe the knowledge distillation paper uses cross-entropy on the predicted class probabilities.
Also, thanks for your help. I plan to start working on this during my free time, but if you ever publish your blog post I’d be happy to read it.
Yup, the dict was like this (I had 6 classes) -
dict[filename] = [logit_class1,…,logit_class6]
This was loaded directly into a torch tensor in the label_from_func method.
I used cross-entropy for training the teacher model. For training the student model, I used MSE. If I am not wrong, we are trying to minimise the distance between the teacher’s logit vectors and the student’s logit vectors. This is the main purpose of using logits: soft labels provide more information about the class distribution than hard one-hot encoded labels do.
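The logit-matching objective described above fits in a few lines; for comparison, the temperature-softened loss from Hinton’s paper is sketched alongside (the temperature `T` is a hyperparameter of that variant, not something used in the approach above):

```python
import torch
import torch.nn.functional as F

# Logit matching, as described above: minimise the distance between
# the teacher's and student's logit vectors.
def logit_mse_loss(student_logits, teacher_logits):
    return F.mse_loss(student_logits, teacher_logits)

# For comparison, Hinton-style distillation softens both sets of logits
# with a temperature T and matches the resulting distributions via KL
# divergence (scaled by T^2 so gradients stay comparable across T).
def hinton_kd_loss(student_logits, teacher_logits, T=4.0):
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)
```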
Another thing, which I think I read somewhere (or maybe it’s my own idea), is to take the MSE loss of both the FC layers at the head of the CNN and the logits, average them, and use that as the loss. It’s kind of inspired by the NoGAN approach, where intermediate conv layers from VGG16 are also used to train a ResNet34. The paper also recommends trying to predict both soft labels and hard labels and using their average as the loss.
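A rough sketch of that feature-plus-logit idea, assuming the student and teacher expose same-sized intermediate features (if they don’t, a small linear projection on the student side would be needed; names here are illustrative):

```python
import torch
import torch.nn.functional as F

# Sketch: average an MSE loss on intermediate features (e.g. the FC layer
# at the head of the CNN) with an MSE loss on the logits. Assumes both
# networks produce same-sized features.
def feature_and_logit_loss(student_feats, student_logits,
                           teacher_feats, teacher_logits):
    feat_loss = F.mse_loss(student_feats, teacher_feats)
    logit_loss = F.mse_loss(student_logits, teacher_logits)
    return (feat_loss + logit_loss) / 2
```

In practice the intermediate activations are usually captured with forward hooks on both networks during the training loop.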
Thanks for the write-up Jayesh!
So in short, you changed the classification problem into a multi-category regression problem, using the teacher’s logits as targets.
I came up with this implementation in a short period of time, so there might be better ways of doing this and achieving better results (it’s always good to check GitHub). But it acts as a good starting point.
What was your loss function? I tried using nn.MSELoss as well as a custom

```python
def custom_mse(predicted, target):
    total_mse = 0
    for i in range(target.shape[0]):
        total_mse += nn.MSELoss(predicted[i], target[i])
    return total_mse
```

but I get

RuntimeError: bool value of Tensor with more than one value is ambiguous
You should be using an instance of MSELoss here, i.e. nn.MSELoss().
Changing it to

```python
def custom_mse(predicted, target):
    total_mse = 0
    for i in range(target.shape[0]):
        total_mse += nn.MSELoss()(predicted[i], target[i])
    return total_mse
```

seems to have fixed the issue, thanks!
Heuristically, is there a difference between using MSE vs NLL loss (with soft labels)? I would think that there isn’t, because both loss functions are trying to match the predicted probabilities. But I could be missing something.
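One concrete way to see that they aren’t equivalent: cross-entropy with soft labels only depends on the softmax of the logits, so it cannot distinguish logit vectors that differ by a constant shift, while MSE on raw logits can. A small illustrative sketch (names and numbers are my own, not from the thread):

```python
import torch
import torch.nn.functional as F

# Cross-entropy against a soft (probability-vector) target distribution.
def soft_cross_entropy(student_logits, soft_targets):
    log_probs = F.log_softmax(student_logits, dim=-1)
    return -(soft_targets * log_probs).sum(dim=-1).mean()

teacher_logits = torch.tensor([[2.0, 1.0, 0.0]])
soft_targets = F.softmax(teacher_logits, dim=-1)

student_a = teacher_logits.clone()   # matches the teacher exactly
student_b = teacher_logits + 5.0     # same softmax, different raw logits
```

Here `soft_cross_entropy` gives the same value for `student_a` and `student_b` (softmax is shift-invariant), while the MSE between `student_b` and the teacher’s logits is large even though the predicted distributions are identical.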
Upon further inspection, it seemed like custom_mse had several errors, but using loss_func = nn.MSELoss() seems to give the right result. I must have forgotten to add the () when I was initially testing it. So using the default MSE loss seems to be fine.
Share your Blog Post Please
I have a blog post that covers such techniques. This area is called semi-supervised learning and there are a bunch of ways you can formulate it.
See if the images there help.