FocalLossFlat() uses self.reduce to control reduction, but the MixUp() callback only checks self.reduction.
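A quick way to see the mismatch (a minimal sketch; the exact attribute values printed depend on your fastai version):

from fastai.vision.all import FocalLossFlat, CrossEntropyLossFlat

# Inspect which reduction-control attribute each loss actually exposes;
# MixUp's NoneReduce only looks for `reduction`.
for loss in (FocalLossFlat(), CrossEntropyLossFlat()):
    print(type(loss).__name__,
          '| reduce:', getattr(loss, 'reduce', '<missing>'),
          '| reduction:', getattr(loss, 'reduction', '<missing>'))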
After modifying the code as below, it works fine:
from functools import partial
import torch
from fastai.vision.all import *

class NoneReduce_v2():
    "Context manager that temporarily sets a loss function's `reduce` attribute to 'none'"
    def __init__(self, loss_func): self.loss_func,self.old_red = loss_func,None
    def __enter__(self):
        if hasattr(self.loss_func, 'reduce'):
            self.old_red = self.loss_func.reduce
            self.loss_func.reduce = 'none'
            return self.loss_func
        else: return partial(self.loss_func, reduce='none')
    def __exit__(self, type, value, traceback):
        # Restore the original reduction mode on exit
        if self.old_red is not None: self.loss_func.reduce = self.old_red

class MixUp_v2(MixUp):
    def __init__(self, alpha=.4): super().__init__(alpha)
    def lf(self, pred, *yb):
        if not self.training: return self.old_lf(pred, *yb)
        # Compute per-item losses for both target sets, then mix them with `lam`
        with NoneReduce_v2(self.old_lf) as lf:
            loss = torch.lerp(lf(pred, *self.yb1), lf(pred, *yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduce', 'mean'))
learn = cnn_learner(
dls,
resnet50,
loss_func=FocalLossFlat(),
metrics=accuracy,
cbs=MixUp_v2())
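For reference, this is the call sequence that triggers the error (the epoch count is arbitrary):

learn.lr_find()      # raises the RuntimeError below
learn.fine_tune(5)   # works fine if lr_find() is skipped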
However, an error is raised if learn.lr_find() is run before learn.fine_tune():
RuntimeError: expected dtype long int for `weights` but got dtype float
If I skip learn.lr_find(), everything runs fine.
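A possible workaround (a sketch, not a verified fix: Learner does provide remove_cb/add_cb, but the suggestion attribute returned by lr_find(), e.g. valley, varies across fastai releases) is to run lr_find() without the MixUp callback and re-attach it afterwards:

learn.remove_cb(MixUp_v2)       # detach the callback by class
lr = learn.lr_find().valley     # suggestion attribute depends on fastai version
learn.add_cb(MixUp_v2())        # re-attach before training
learn.fine_tune(5, base_lr=lr)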