 # Problem creating custom loss function

(David) #1

I am trying to create and use a custom loss function. When my initial attempts failed I decided to take a step back and implement (through cut and paste) the standard loss function used with a unet Learner in my own notebook. I thought this would be a good way to check my understanding of the size of the tensor inputs and see where the inputs differed between the standard loss function and the ones I first created.

To my disappointment, my “cut and paste” loss function also fails: an exception is thrown during `lr_find`.

```
/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1786     if input.size(0) != target.size(0):
   1787         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1788                          .format(input.size(0), target.size(0)))
   1789     if dim == 2:
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (65536) to match target batch_size (8192).
```
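For what it's worth, both numbers in the error can be reproduced with simple arithmetic from the shapes involved (batch size 8, 256 classes, 32×32 images, as noted further down); a quick sketch under those assumptions:

```python
batch, n_classes, h, w = 8, 256, 32, 32

# The model output is [8, 256, 32, 32]. With the default axis=-1 the
# transpose is a no-op, so view(-1, input.shape[-1]) groups rows by the
# last spatial dim (32) instead of the class dim (256):
input_rows = (batch * n_classes * h * w) // w
print(input_rows)    # 65536

# The target is [8, 1, 32, 32]; view(-1) yields one entry per pixel:
target_rows = batch * 1 * h * w
print(target_rows)   # 8192
```

The loss function then sees 65536 input rows against 8192 target entries, which is exactly the mismatch the `ValueError` reports.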

I would appreciate some insight into what I am doing wrong.

Initial standard fastai code which does work:

```python
wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
print('Loss func ', learn.loss_func)
```

Output:
Loss func FlattenedLoss of CrossEntropyLoss()

Here is the code I’ve pasted in (and renamed) that fails.

```python
class MyFlattenedLoss():
    "Same as `func`, but flattens input and target."
    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d

    def __repr__(self): return f"My FlattenedLoss of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        print('input shape ', input.shape)
        print('target shape ', target.shape)

        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()

        print('input shape ', input.shape)
        print('target shape ', target.shape)

        if self.floatify: target = target.float()
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)

        print('input shape ', input.shape)
        print('target shape ', target.shape)
        print('floatify', self.floatify, ' 2d ', self.is_2d)
        print('kwargs ', kwargs)
        print('Func ', self.func)
        print('target view ', target.view(-1).shape)
        return self.func.__call__(input, target.view(-1), **kwargs)

def MyCrossEntropyFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    return MyFlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
learn.loss_func = MyCrossEntropyFlat()
print('Loss func ', learn.loss_func)
```

Output:
Loss func My FlattenedLoss of CrossEntropyLoss()

The exception occurs when calling `lr_find`:

```python
lr_find(learn)
```

Note that the learner is set up with a batch size of 8, there are 256 classes, and the images are resized to [32, 32].

The following output is captured before the exception:

```
input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([8, 256, 32, 32])
target shape  torch.Size([8, 1, 32, 32])
input shape  torch.Size([65536, 32])
target shape  torch.Size([8, 1, 32, 32])
floatify False  2d  True
kwargs  {}
Func  CrossEntropyLoss()
target view  torch.Size([8192])
```
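The shapes above can be reproduced outside the Learner with standalone tensors (a sketch; the random values just stand in for real model output):

```python
import torch

inp = torch.randn(8, 256, 32, 32)                  # [batch, classes, H, W]
tgt = torch.zeros(8, 1, 32, 32, dtype=torch.long)  # one class index per pixel

# With the default axis=-1, transpose(-1, -1) is a no-op, so
# view(-1, input.shape[-1]) groups rows by the last spatial dim (32)
# rather than the class dim -- producing the 65536 the error reports:
flat_inp = inp.transpose(-1, -1).contiguous().view(-1, inp.shape[-1])
print(flat_inp.shape)        # torch.Size([65536, 32])
print(tgt.view(-1).shape)    # torch.Size([8192])
```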

(Renato Hermoza) #2

Try: `learn.loss_func = MyCrossEntropyFlat(axis=1)`; axis 1 is the channel dimension that holds the class scores.
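A minimal check of why `axis=1` works, using standalone tensors of the same shapes rather than the actual Learner:

```python
import torch
import torch.nn as nn

inp = torch.randn(8, 256, 32, 32)             # raw per-class scores
tgt = torch.randint(0, 256, (8, 1, 32, 32))   # class index per pixel

# Flatten as MyFlattenedLoss does, but with axis=1, so the class dim
# is moved last before the view. Each row is now one pixel's 256 scores:
flat_inp = inp.transpose(1, -1).contiguous().view(-1, 256)   # [8192, 256]
flat_tgt = tgt.transpose(1, -1).contiguous().view(-1)        # [8192]

# Batch sizes now match, so the loss computes without error:
loss = nn.CrossEntropyLoss()(flat_inp, flat_tgt)
print(loss.shape)   # torch.Size([]) -- a 0-dim scalar loss
```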

(David) #3

Thank you! Specifying the axis index solved the issue.