Knowledge Distillation in fast.ai

Hello everyone,

Wanted to bring up this very important, and somewhat neglected, area of DL research.

Given the importance fast.ai places on transfer learning, and the incredible results we get out of it, would it make sense to add an implementation of some version of knowledge distillation (KD) to the library?

I understand this doesn’t need to be a class or object, but it would be pretty nice to have a reasonably efficient way to apply KD in code. If that is asking too much, then some pointers on how the current version of fast.ai could be used for this would be very helpful for me!

Thanks for your thoughts.

Kind regards,
Theodore.


The best way to handle a project like this is to just start coding! If you get stuck along the way, let us know and we can try to help :blush:


In its simplest form, knowledge distillation means using a larger teacher model to create the labels that you use to train the student model.

So instead of getting the labels from the training set, you would first run the teacher model on the batch to get the labels, and then run the student model on the batch as usual. Not sure how easy fast.ai makes this, but it shouldn’t be too difficult with the new callbacks system.

You can try many variations of this, such as using a weighted combination of the original labels from the training set and the prediction from the teacher model.
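A minimal plain-PyTorch sketch of that idea (not fastai-specific; distill_step, the weight w and the toy models are hypothetical): the teacher runs on the batch without gradients, and the student is trained on a weighted mix of the teacher's predicted distribution and the original labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_step(student, teacher, x, y, optimizer, w=0.5):
    """One training step where the teacher supplies (part of) the labels.

    w is the weight on the teacher's predictions; (1 - w) stays on the
    original training-set labels y. w=1.0 trains on the teacher alone.
    """
    teacher.eval()
    with torch.no_grad():                      # the teacher only produces labels
        teacher_probs = F.softmax(teacher(x), dim=1)
    student.train()
    logits = student(x)
    log_probs = F.log_softmax(logits, dim=1)
    # cross-entropy against the teacher's full probability distribution
    soft_loss = -(teacher_probs * log_probs).sum(dim=1).mean()
    # ordinary cross-entropy against the dataset labels
    hard_loss = F.cross_entropy(logits, y)
    loss = w * soft_loss + (1 - w) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# toy usage: a bigger teacher and a smaller student on random data
torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
student = nn.Linear(20, 5)
opt = torch.optim.SGD(student.parameters(), lr=0.1)
x, y = torch.randn(32, 20), torch.randint(0, 5, (32,))
loss = distill_step(student, teacher, x, y, opt)
```

Setting w=1.0 gives the "teacher creates the labels" version; values in between give the weighted combination.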


Thanks, will do so soon, I hope! I’m still in the process of generating data and training teacher models. Will add any developments here.

I am trying knowledge distillation on the dynamic U-Net. However, I need some way to access the outputs before the softmax is applied. I am trying to use callbacks for this, but I don’t understand how to get the output just before the softmax and apply a softmax temperature to it.
@jeremy

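You can get the model’s outputs right before the loss is computed with a callback like this. In fastai v1 the model itself returns raw logits and the softmax is applied inside the loss function, so last_output is already pre-softmax:

from typing import Any
from fastai.vision import *   # fastai v1: Callback, cnn_learner, models, error_rate

class OnLossBegin(Callback):
    def on_loss_begin(self, last_output, **kwargs)->Any:
        print(len(last_output))              # batch size of the raw (pre-softmax) output
        return {'last_output': last_output}  # a returned dict updates the callback state

learn = cnn_learner(data, models.resnet34, metrics=error_rate, callbacks=[OnLossBegin()])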

If you just want to replace the default loss function with a custom one, define a new function and assign it to learn.loss_func.
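For the softmax-temperature part of the question, a sketch assuming plain PyTorch: since the loss receives the raw logits, the temperature can be applied inside a custom loss. softmax_t and TemperatureCE are hypothetical names; any callable taking (output, target) can serve as the loss.

```python
import torch
import torch.nn.functional as F

def softmax_t(logits, T=1.0):
    """Softmax with temperature T over the class dimension; T > 1 flattens the distribution."""
    return F.softmax(logits / T, dim=1)

class TemperatureCE:
    """Hypothetical custom loss: cross-entropy on temperature-scaled logits."""
    def __init__(self, T=1.0):
        self.T = T

    def __call__(self, out, target):
        return F.cross_entropy(out / self.T, target)

logits = torch.tensor([[4.0, 1.0, 0.0]])
sharp = softmax_t(logits, T=1.0)    # most mass on class 0
soft = softmax_t(logits, T=10.0)    # much flatter distribution
loss = TemperatureCE(T=2.0)(logits, torch.tensor([0]))
```

With a fastai v1 learner this would be used as learn.loss_func = TemperatureCE(T=2.). Because the softmax is taken over dim=1, the same code also applies to segmentation outputs of shape (N, C, H, W) such as those of the dynamic U-Net.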


However, in the case of knowledge distillation, you would need to pass three inputs to the custom loss function, i.e. the student predictions, the teacher predictions and the true labels.

You can create a custom DataBunch that loads three values (X, y and y_pred_teacher); however, AFAIK there is no provision in the callbacks to send the three values (y_pred_student, y, y_pred_teacher) to the loss function.
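One way to get y_pred_teacher for such a DataBunch is to run the trained teacher once over the data and store its logits next to the inputs. A plain-PyTorch sketch with the hypothetical helper precompute_teacher_logits; it only makes sense without random data augmentation, since the stored logits must correspond to exactly the inputs the student sees.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()                      # no autograd graph for the teacher's forward passes
def precompute_teacher_logits(teacher, x, batch_size=256):
    """Run the (already trained) teacher once over all inputs and return its logits.

    The DataLoader is not shuffled, so row i of the result lines up with x[i].
    """
    teacher.eval()
    loader = DataLoader(TensorDataset(x), batch_size=batch_size)
    outs = [teacher(xb) for (xb,) in loader]
    return torch.cat(outs)

# toy usage: store (x, y, teacher_logits) side by side
torch.manual_seed(0)
teacher = nn.Linear(6, 4)
x = torch.randn(1000, 6)
y = torch.randint(0, 4, (1000,))
t = precompute_teacher_logits(teacher, x, batch_size=128)
ds = TensorDataset(x, y, t)
```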

In such a case, you would have to modify loss_batch in basic_train.py to accept 3 values and change loss = loss_func(out, *yb) to loss = loss_func(out, *yb, *y_pred_teacher).

You would also need to change every call to loss_batch, e.g.
loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
to
loss_batch(model, xb, yb, y_pred_teacher, loss_func, cb_handler=cb_handler)

@sgugger Please let me know if my understanding is right.

No: if your dataloader returns a tuple of targets with 2 values, (x, (y, y_teacher)), then *yb will unpack that tuple and you can have a loss function with the following signature:

def my_loss_func(out, y, y_teacher):
...
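To illustrate the suggestion, a plain-PyTorch sketch (KDDataset and the loss’s T/alpha defaults are hypothetical; a fastai v1 LabelList would need its own item list to produce such tuples): PyTorch’s default collate keeps the nested tuple, so yb arrives as [y, y_teacher] and a call shaped like loss_func(out, *yb) fills both arguments.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

class KDDataset(Dataset):
    """Hypothetical dataset yielding (x, (y, y_teacher)) so the targets form a tuple."""
    def __init__(self, x, y, teacher_logits):
        self.x, self.y, self.t = x, y, teacher_logits

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], (self.y[i], self.t[i])

def my_loss_func(out, y, y_teacher, T=4.0, alpha=0.5):
    # soft term: KL between temperature-softened teacher and student distributions
    soft = F.kl_div(F.log_softmax(out / T, dim=1), F.softmax(y_teacher / T, dim=1),
                    reduction='batchmean') * (T * T)
    # hard term: ordinary cross-entropy against the true labels
    return alpha * soft + (1 - alpha) * F.cross_entropy(out, y)

torch.manual_seed(0)
x = torch.randn(10, 8)
y = torch.randint(0, 3, (10,))
t = torch.randn(10, 3)
xb, yb = next(iter(DataLoader(KDDataset(x, y, t), batch_size=4)))
model = torch.nn.Linear(8, 3)
out = model(xb)
loss = my_loss_func(out, *yb)   # same call shape as loss_func(out, *yb) in loss_batch
```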

I have modified the ImageList to load filenames and the target tuple (y, t_teach), but it is now read as a MultiCategoryList:

Train: LabelList (48000 items)
x: ImageList
Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)
y: MultiCategoryList
6;6,0;0,1;1,6;6,1;1
Path: /;

I also created the custom loss function:

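import torch.nn as nn
import torch.nn.functional as F

def loss_fn_kd(outputs, labels, teacher_outputs):
    alpha = 0.9  # weight of the soft (teacher) term
    T = 20       # softmax temperature
    # 'batchmean' gives the true KL divergence per sample; the default 'mean' also divides by the number of classes
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),
                                                    F.softmax(teacher_outputs/T, dim=1))
    hard_loss = F.cross_entropy(outputs, labels)
    # T*T keeps the soft-term gradients on a similar scale to the hard-label term
    return soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)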

However, the code errors out during fit_one_cycle:

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     29     if not loss_func: return to_detach(out), to_detach(yb[0])
---> 30     loss = loss_func(out, *yb)
     31 
     32     if opt is not None:

TypeError: loss_fn_kd() missing 1 required positional argument: 'teacher_outputs'

I’d appreciate it if you could have a look at the gist.

If anybody is still interested in doing Knowledge Distillation with fastai, I have implemented an example here.

The implementation is inspired by this paper, which also has an implementation in pure PyTorch.

The way I implemented it was to create a Callback that takes the logits of a teacher model and returns the DistillationLoss at each step.

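from fastai.vision import *   # fastai v1: Learner, LearnerCallback
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillation(LearnerCallback):
    def __init__(self, learn:Learner, teacher:Learner, T:float=20., α:float=0.7):
        super().__init__(learn)
        self.teacher = teacher
        self.T, self.α = T, α
        self.teacher.model.eval()   # frozen teacher: no dropout, fixed BatchNorm statistics

    def on_backward_begin(self, last_input, last_output, last_target, **kwargs):
        with torch.no_grad():       # no gradients flow into the teacher
            teacher_output = self.teacher.model(last_input)
        new_loss = DistillationLoss(last_output, last_target, teacher_output, self.T, self.α)
        return {'last_loss': new_loss}   # replaces the loss that gets backpropagated

def DistillationLoss(y, labels, teacher_scores, T, alpha):
    # soft term: KL between temperature-softened teacher and student distributions
    soft = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y/T, dim=-1), F.softmax(teacher_scores/T, dim=-1))
    # hard term: ordinary cross-entropy against the true labels
    hard = F.cross_entropy(y, labels)
    return soft * (T*T * 2.0 * alpha) + hard * (1. - alpha)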
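Written out, the DistillationLoss above computes the following, where $z_s$ are the student logits (y in the code), $z_t$ the teacher logits (teacher_scores), $\sigma$ the softmax, $N$ the batch size and $\mathrm{KL}(p\,\|\,q) = \sum_c p_c \log(p_c/q_c)$; the factor 2 is taken directly from the code:

$$
\mathcal{L} = 2\alpha T^2 \cdot \frac{1}{N}\sum_{n=1}^{N} \mathrm{KL}\!\left(\sigma\!\left(z_t^{(n)}/T\right) \,\Big\|\, \sigma\!\left(z_s^{(n)}/T\right)\right) + (1-\alpha)\cdot\frac{1}{N}\sum_{n=1}^{N} \mathrm{CE}\!\left(z_s^{(n)},\, y^{(n)}\right)
$$

The first sum is what reduction='batchmean' produces, and the second is the default mean reduction of F.cross_entropy.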

You can then use it by first training your teacher and then passing the callback when you train your student model:

student_learner.fit_one_cycle(3, 1e-3, callbacks=[KnowledgeDistillation(student_learner, teacher=teacher_learner)])

Nice, thanks for sharing!