Creating a fit_distill function

I’ve been trying to create a knowledge distillation training framework in fastai by modifying the fit function to take in the outputs of both a teacher and a student learner. The code so far looks something like this:

def distillation(y, labels, teacher_scores, T, alpha):
    return nn.KLDivLoss()(F.log_softmax(y/T), F.softmax(teacher_scores/T)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
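
Side note: with the version above I get the softmax and KLDivLoss deprecation warnings shown in the output below. Making the dim and reduction explicit would presumably look something like this (though 'batchmean' divides only by the batch size, so alpha and T might need retuning):

def distillation(y, labels, teacher_scores, T, alpha):
    # explicit dim and reduction to avoid the deprecation warnings
    kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y/T, dim=1),
                                                  F.softmax(teacher_scores/T, dim=1))
    return kd_loss * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)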

def loss_batch_distill(model:nn.Module, teacher_model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
               cb_handler:Optional[CallbackHandler]=None):
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = ifnone(cb_handler, CallbackHandler())
    if not is_listy(xb): xb = [xb]
    if not is_listy(yb): yb = [yb]
    out = model(*xb)
    with torch.no_grad(): teacher_scores = teacher_model(*xb) # teacher forward pass; no gradients needed for the teacher
    out = cb_handler.on_loss_begin(out)

    if not loss_func: return to_detach(out), to_detach(yb[0])
    loss = loss_func(out, *yb, teacher_scores, T=20., alpha = 0.7)

    if opt is not None:
        loss,skip_bwd = cb_handler.on_backward_begin(loss)
        if not skip_bwd:                     loss.backward()
        if not cb_handler.on_backward_end(): opt.step()
        if not cb_handler.on_step_end():     opt.zero_grad()

    return loss.detach().cpu()

def fit_distill(epochs:int, learn:BasicLearner, learn_teacher:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None)->None:
    "Fit the `model` on `data` and learn using `loss_func` and `opt`."
    assert len(learn.data.train_dl) != 0, f"""Your training dataloader is empty, can't train a model.
        Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements)."""
    cb_handler = CallbackHandler(callbacks, metrics)
    pbar = master_bar(range(epochs))
    cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)

    exception=False
    try:
        for epoch in pbar:
            learn.model.train()
            learn_teacher.model.eval() # setting teacher to eval
            cb_handler.set_dl(learn.data.train_dl)
            cb_handler.on_epoch_begin()
            for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
                xb, yb = cb_handler.on_batch_begin(xb, yb)
                loss = loss_batch_distill(learn.model, learn_teacher.model, xb, yb, learn.loss_func, learn.opt, cb_handler) # distillation version of loss_batch
                if cb_handler.on_batch_end(loss): break

            if not cb_handler.skip_validate and not learn.data.empty_val:
                val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func,
                                       cb_handler=cb_handler, pbar=pbar)
            else: val_loss=None
            if cb_handler.on_epoch_end(val_loss): break
    except Exception as e:
        exception = e
        raise
    finally: cb_handler.on_train_end(exception)

When I try to run

learn_student = cnn_learner(data, models.resnet18, pretrained = True, loss_func = distillation,
               metrics = [accuracy])

fit_distill(3, learn_student, learn) # learn is my teacher learner

I get

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1958: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.
  warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size."

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-35-207d9d4098ab> in <module>()
      1 fit_distill(3,
      2             learn_student,
----> 3             learn)

2 frames

/usr/local/lib/python3.6/dist-packages/fastai/callback.py in step(self)
     53                 for p in pg1['params']: p.data.mul_(1 - wd*lr)
     54                 if self.bn_wd:
---> 55                     for p in pg2['params']: p.data.mul_(1 - wd*lr)
     56             self.set_val('weight_decay', listify(0, self._wd))
     57         self.opt.step()

TypeError: unsupported operand type(s) for *: 'float' and 'slice'

Does anyone know how I can fix this error? The notebook can be found here.

It looks like you are not properly setting the lr values in the optimizer. Be sure to copy the latest version of Learner.fit.
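
For context, in fastai v1 Learner.fit converts the lr (which can be a slice for discriminative learning rates) into per-layer-group values and creates or updates the optimizer before calling fit; if that setup is skipped, a raw slice can end up inside the OptimWrapper, which would explain the 'float' and 'slice' error. Roughly, and simplified from basic_train.py (the exact code depends on your fastai version):

def fit(self, epochs:int, lr:Union[Floats,slice]=defaults.lr,
        wd:Floats=None, callbacks:Collection[Callback]=None)->None:
    lr = self.lr_range(lr)                       # expand a slice into per-layer-group lrs
    if wd is None: wd = self.wd
    if not getattr(self, 'opt', False): self.create_opt(lr, wd)  # build the OptimWrapper
    else: self.opt.lr,self.opt.wd = lr,wd
    callbacks = [cb(self) for cb in self.callback_fns] + listify(callbacks)
    fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)

Your fit_distill wrapper would want to do the same lr/wd/optimizer setup before its training loop.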

Is there a way to modify the Learner object and create a Learner.distill method without going in and modifying basic_train.py? I am thinking of doing something like this:

def distill(self, learn_teacher:BasicLearner, epochs:int, lr:Union[Floats,slice]=defaults.lr,
        wd:Floats=None, callbacks:Collection[Callback]=None)->None:
    ...
    fit_distill(epochs, self, learn_teacher, metrics=self.metrics, callbacks=self.callbacks+callbacks)

learn_student.distill = distill

Will this work?
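
Also, I'm not sure a plain attribute assignment like that will pass self automatically, so I'm guessing I'd need to bind it to the instance or patch the class instead, something like:

import types

# bind the function to this particular learner instance
learn_student.distill = types.MethodType(distill, learn_student)
learn_student.distill(learn, 3)   # teacher learner, 3 epochs

# or patch the Learner class so every learner gets a distill method
Learner.distill = distill
learn_student.distill(learn, 3)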