Current implementation of FocalLossFlat() conflicts with the MixUp() callback

FocalLossFlat() uses self.reduce to control reduction, but the MixUp() callback only checks self.reduction.
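
To see which attribute is actually in play on your install, a quick check like this can help (purely illustrative: the tensor shapes and the names x/y are made up, and the exact behaviour depends on your fastai version):

from fastai.vision.all import *

loss = FocalLossFlat()
print(getattr(loss, 'reduce', None), getattr(loss, 'reduction', None))

# MixUp's NoneReduce sets `reduction = 'none'` and expects one loss value per
# item back; if the loss keeps reducing via its own `reduce` attribute instead,
# MixUp has nothing to interpolate per sample.
x, y = torch.randn(4, 3), torch.randint(0, 3, (4,))   # made-up shapes: 4 samples, 3 classes
loss.reduction = 'none'
print(loss(x, y).shape)   # a 0-dim result here means the loss was still reduced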

After modifying some code, it works fine:

from fastai.vision.all import *   # provides partial, torch, MixUp, reduce_loss, FocalLossFlat, cnn_learner, ...

class NoneReduce_v2():
    "Like fastai's NoneReduce, but toggles the loss function's `reduce` attribute instead of `reduction`"
    def __init__(self, loss_func): self.loss_func,self.old_red = loss_func,None

    def __enter__(self):
        if hasattr(self.loss_func, 'reduce'):
            # remember the current setting and switch reduction off while mixing
            self.old_red = self.loss_func.reduce
            self.loss_func.reduce = 'none'
            return self.loss_func
        else: return partial(self.loss_func, reduce='none')

    def __exit__(self, type, value, traceback):
        # restore the original setting when leaving the context
        if self.old_red is not None: self.loss_func.reduce = self.old_red

class MixUp_v2(MixUp):
    "MixUp whose loss wrapper uses NoneReduce_v2, so it also works with losses that expose `reduce` (e.g. FocalLossFlat)"
    def __init__(self, alpha=.4): super().__init__(alpha)

    def lf(self, pred, *yb):
        # at validation time, fall back to the unmodified loss function
        if not self.training: return self.old_lf(pred, *yb)
        # compute per-item losses for both target sets and interpolate them with lam
        with NoneReduce_v2(self.old_lf) as lf:
            loss = torch.lerp(lf(pred, *self.yb1), lf(pred, *yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduce', 'mean'))

learn = cnn_learner(
    dls,
    resnet50,
    loss_func=FocalLossFlat(),
    metrics=accuracy,
    cbs=MixUp_v2())

However, an error is raised if learn.lr_find() is run before learn.fine_tune():

RuntimeError: expected dtype long int for `weights` but got dtype float

If I skip learn.lr_find(), then everything is alright.
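
For completeness, a minimal sketch of that workaround, i.e. going straight to fine_tune() with a hand-picked learning rate (the epoch count and base_lr below are placeholders, not tuned values):

# learn.lr_find()                 # running this first triggers the RuntimeError above
learn.fine_tune(5, base_lr=1e-3)  # 5 epochs and base_lr=1e-3 are placeholder values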
