Hi…
I suspect there is a bug in mixup.py. I get an error when I train a model with MSE loss, with the output type being FloatList.
Below is the error trail. The output size is bs x classes, so in my case it is bs x 1; the `if` condition computes the length of the target's size as 2, then converts the target to long, and that is when the loss raises the error below.
What is the purpose of computing two losses here and converting the target to long instead of float?
if len(target.size()) == 2:
---> 56 loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())
57 d = (loss1 * target[:,2] + loss2 * (1-target[:,2])).mean()
58 else: d = self.crit(output, target)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
441 @weak_script_method
442 def forward(self, input, target):
--> 443 return F.mse_loss(input, target, reduction=self.reduction)
444
445
/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
2255 else:
2256 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-> 2257 ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
2258 return ret
2259
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'target'
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."

    def __init__(self, crit, reduction='mean'):
        """Wrap `crit` so it returns per-sample losses that mixup can blend.

        `crit` is switched to `reduction='none'` (either via its attribute or
        a `partial`); the final reduction is applied in `forward` according to
        `reduction` ('mean', 'sum', or anything else for no reduction).
        """
        super().__init__()
        if hasattr(crit, 'reduction'):
            self.crit = crit
            self.old_red = crit.reduction  # saved so callers can restore it later
            setattr(self.crit, 'reduction', 'none')
        else:
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction

    def forward(self, output, target):
        """Compute the (possibly mixed-up) loss of `output` against `target`.

        A mixup target is stacked as three columns `[y1, y2, lam]`, so we only
        take the interpolation branch when the target has exactly 3 columns.
        Checking `len(target.size()) == 2` alone would wrongly cast a plain
        (bs, 1) float regression target to long and crash MSELoss — the bug
        reported above.
        """
        if len(target.size()) == 2 and target.size(1) == 3:
            # Classification mixup: blend the losses of the two integer labels
            # with weight lam = target[:, 2].
            loss1 = self.crit(output, target[:, 0].long())
            loss2 = self.crit(output, target[:, 1].long())
            # Keep per-sample values; the reduction below does the reducing.
            # (Calling .mean() here would make reduction='sum' incorrect.)
            d = loss1 * target[:, 2] + loss2 * (1 - target[:, 2])
        else:
            d = self.crit(output, target)
        if self.reduction == 'mean':
            return d.mean()
        elif self.reduction == 'sum':
            return d.sum()
        return d