Problem creating custom loss function


(David) #1

I am trying to create and use a custom loss function. When my initial attempts failed, I decided to take a step back and reimplement (via cut and paste) the standard loss function used with a unet Learner in my own notebook. I thought this would be a good way to check my understanding of the sizes of the tensor inputs and to see where the inputs differed between the standard loss function and the ones I had first created.

To my disappointment, my “cut and paste” loss function also fails: an exception is thrown during lr_find.

/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1786     if input.size(0) != target.size(0):
   1787         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1788                          .format(input.size(0), target.size(0)))
   1789     if dim == 2:
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (65536) to match target batch_size (8192).

I would appreciate some insight into what I am doing wrong.

The initial standard fastai code, which works:

wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
print('Loss func ', learn.loss_func)

Output:
Loss func FlattenedLoss of CrossEntropyLoss()

Here is the code I pasted in (and renamed) that fails:

class MyFlattenedLoss():
    "Same as `func`, but flattens input and target."
    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d

    def __repr__(self): return f"My FlattenedLoss of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        
        # move the class axis to the last position before flattening
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        
        if self.floatify: target = target.float()
        # collapse input to [N, n_classes] when is_2d, else to [N]
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        print('floatify', self.floatify, ' 2d ', self.is_2d)
        print('kwargs ', kwargs)
        print('Func ', self.func)
        print('target view ', target.view(-1).shape)
        return self.func.__call__(input, target.view(-1), **kwargs)    
    


def MyCrossEntropyFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    return MyFlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
learn.loss_func = MyCrossEntropyFlat()
print('Loss func ', learn.loss_func)

Output:
Loss func My FlattenedLoss of CrossEntropyLoss()

The exception occurs when calling lr_find:

lr_find(learn)

Note that the learner is set up with a batch size of 8, there are 256 classes, and the images are resized to [32, 32].

The following output is captured before the exception:

input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([65536, 32])
target shape  torch.Size([8, 1, 32, 32])
floatify False  2d  True
kwargs  {}
Func  CrossEntropyLoss()
target view  torch.Size([8192])
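To make the mismatch concrete, here is the same flattening reproduced in plain PyTorch, using the shapes from the trace above (a standalone sketch, not fastai code):

import torch

inp = torch.randn(8, 256, 32, 32)  # model output: [bs, n_classes, H, W]
tgt = torch.zeros(8, 1, 32, 32)    # segmentation mask: [bs, 1, H, W]

# With the default axis=-1, transpose(-1, -1) is a no-op, so view(-1, 32)
# folds the class dimension into the batch dimension instead of keeping
# it as the logits dimension
inp_flat = inp.transpose(-1, -1).contiguous().view(-1, inp.shape[-1])
tgt_flat = tgt.view(-1)
print(inp_flat.shape, tgt_flat.shape)  # torch.Size([65536, 32]) torch.Size([8192])

nll_loss then sees batch sizes 65536 vs 8192 and raises the ValueError above.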

(Renato Hermoza) #2

Try learn.loss_func = MyCrossEntropyFlat(axis=1); axis 1 is the channel that indicates the labels.
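
With axis=1 the class dimension is moved to the end before flattening, so the flattened batch sizes line up. A minimal sketch in plain PyTorch, assuming the shapes from your trace:

import torch

inp = torch.randn(8, 256, 32, 32)       # [bs, n_classes, H, W]
tgt = torch.zeros(8, 1, 32, 32).long()  # [bs, 1, H, W]

# axis=1 moves the 256-class channel to the last position before flattening
inp = inp.transpose(1, -1).contiguous().view(-1, inp.shape[-1])  # [8192, 256]
tgt = tgt.view(-1)                                               # [8192]

loss = torch.nn.CrossEntropyLoss()(inp, tgt)  # batch sizes now match
print(inp.shape, tgt.shape, loss.item())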


(David) #3

Thank you! Specifying the axis index solved the issue.
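
For completeness, the working setup is the original code with the axis argument added:

learn = unet_learner(data, models.resnet34, metrics=[], wd=1e-2)
learn.loss_func = MyCrossEntropyFlat(axis=1)
lr_find(learn)  # runs without the batch-size ValueError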