Bug in Mixup?

Hi…
I suspect there is a bug in mixup.py. I get an error when I train a model with MSE loss and an output of type FloatList. The output size is bs x classes, which in my case is bs x 1, so the target has two dimensions; that makes the if condition below fire (len(target.size()) == 2), the target gets converted to long, and the loss then raises the error shown in the trail below.
What is the purpose of calculating two losses here, and of converting the target to long instead of float?

if len(target.size()) == 2:
    loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())
    d = (loss1 * target[:,2] + loss2 * (1-target[:,2])).mean()
else: d = self.crit(output, target)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    441     @weak_script_method
    442     def forward(self, input, target):
--> 443         return F.mse_loss(input, target, reduction=self.reduction)
    444 
    445 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
   2255     else:
   2256         expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-> 2257         ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
   2258     return ret
   2259 

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'target'
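
The cast is easy to reproduce outside of fastai; a minimal sketch, assuming nn.MSELoss with reduction='none' (which is what MixUpLoss sets):

import torch
import torch.nn as nn

# Once the stacked bs x 3 mixup target makes len(target.size()) == 2 true,
# the .long() cast feeds integer targets to F.mse_loss, which needs floats.
crit = nn.MSELoss(reduction='none')
output = torch.randn(4)                # regression predictions
stacked = torch.randn(4, 3)            # mixup target [y1, y2, lambda]
print(crit(output, stacked[:, 0]))     # float target: works
crit(output, stacked[:, 0].long())     # RuntimeError: expected Float, got Long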


class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."
    
    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if hasattr(crit, 'reduction'): 
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else: 
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction
        
    def forward(self, output, target):
        if len(target.size()) == 2:
            loss1, loss2 = self.crit(output,target[:,0].long()), self.crit(output,target[:,1].long())
            d = (loss1 * target[:,2] + loss2 * (1-target[:,2])).mean()
        else: d = self.crit(output, target)
        if self.reduction == 'mean': return d.mean()
        elif self.reduction == 'sum': return d.sum()
        return d
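
For reference, the two losses implement the mixup objective: with stack_y=True the callback packs [y1, y2, lambda] into a bs x 3 target, and MixUpLoss unpacks it to compute lambda * crit(output, y1) + (1-lambda) * crit(output, y2). The .long() casts assume the targets are class indices for a classification loss, which is exactly what breaks MSE with float targets. A minimal sketch of a dtype-preserving variant for bs x 1 regression outputs (mixup_mse is my name, not fastai's):

import torch
import torch.nn as nn

def mixup_mse(crit, output, target):
    # target is the stacked bs x 3 tensor [y1, y2, lambda] built by the callback
    y1, y2, lam = target[:, 0:1], target[:, 1:2], target[:, 2]
    loss1 = crit(output, y1)  # keep float targets: no .long() cast for regression
    loss2 = crit(output, y2)
    return (loss1.squeeze(1) * lam + loss2.squeeze(1) * (1 - lam)).mean()

crit = nn.MSELoss(reduction='none')
output = torch.randn(4, 1)                                    # bs x 1 regression output
target = torch.cat([torch.randn(4, 2), torch.rand(4, 1)], 1)  # [y1, y2, lambda]
print(mixup_mse(crit, output, target))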

Here is the fix I made for targets of size 1 x H x W:

import numpy as np
import torch
from fastai.basic_train import Learner, LearnerCallback  # fastai v1 import path (assumed)

class MixUpCallback1(LearnerCallback):
    "Callback that creates the mixed-up input and target."
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=False):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
    
    def on_train_begin(self, **kwargs):
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
        
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        # Draw one lambda per sample from Beta(alpha, alpha) and keep
        # max(lambda, 1-lambda) so the original sample dominates the mix.
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        # Shuffle the batch to pick a mixing partner for each sample.
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, x1, lambd]
        else:
            # Broadcast lambda over all non-batch dimensions of the input.
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            # Stack [y1, y2, lambda] into a bs x 3 tensor for MixUpLoss to unpack.
            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            # Broadcast lambda over all non-batch dimensions of the target
            # (here 1 x H x W) and mix the float targets directly.
            out_shape = [lambd.size(0)] + [1 for _ in range(len(y1.shape) - 1)]
            # .half() matches my mixed-precision setup; use .float() for fp32 training.
            new_target = last_target.half() * lambd.view(out_shape) + y1.half() * (1-lambd).view(out_shape)
        return {'last_input': new_input, 'last_target': new_target}
    
    def on_train_end(self, **kwargs):
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()
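
To use the patched callback, attach it in place of the stock mixup; a hypothetical fastai v1 setup (data and model are placeholders):

from functools import partial
import torch.nn as nn

# stack_y=False so the float 1 x H x W targets are mixed directly
# in the else branch above, avoiding MixUpLoss and the .long() cast.
learn = Learner(data, model, loss_func=nn.MSELoss())
learn.callback_fns.append(partial(MixUpCallback1, alpha=0.4, stack_y=False))
learn.fit_one_cycle(1)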