I am trying to create and use a custom loss function. When my initial attempts failed, I decided to take a step back and copy-and-paste the standard loss function used with a unet_learner into my own notebook. I thought this would be a good way to check my understanding of the tensor input sizes and to see where the inputs differed between the standard loss function and the ones I had first written.
To my disappointment, my “cut and paste” loss function also fails: an exception is thrown during lr_find.
/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1786     if input.size(0) != target.size(0):
   1787         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1788                          .format(input.size(0), target.size(0)))
   1789     if dim == 2:
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (65536) to match target batch_size (8192).
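For context, the same ValueError can be reproduced outside fastai with dummy tensors of these sizes (a minimal sketch in plain PyTorch; the names inp/tgt are my own):

import torch
import torch.nn.functional as F

# dummy tensors with the mismatched batch sizes from the traceback;
# nll_loss expects log-probabilities, but the shape check fires before
# any math, so random values are fine for the repro
inp = torch.randn(65536, 32)               # flattened input: 65536 rows, 32 columns
tgt = torch.zeros(8192, dtype=torch.long)  # flattened target: 8192 elements
F.nll_loss(inp, tgt)                       # raises the same ValueError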
I would appreciate some insight into what I am doing wrong.
The initial, standard fastai code, which does work:
wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
print('Loss func ', learn.loss_func)
Output:
Loss func FlattenedLoss of CrossEntropyLoss()
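For reference, the default loss object stores its flattening settings as plain attributes (per the FlattenedLoss definition pasted below), so they can be printed for comparison; a quick check along these lines:

# inspect the default FlattenedLoss settings (attribute names taken from the class pasted below)
print(learn.loss_func.axis, learn.loss_func.floatify, learn.loss_func.is_2d)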
Here is the code I pasted in (and renamed), which fails:
class MyFlattenedLoss():
    "Same as `func`, but flattens input and target."
    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d

    def __repr__(self): return f"My FlattenedLoss of {self.func}"

    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        if self.floatify: target = target.float()
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        print('input shape ', input.shape)
        print('target shape ', target.shape)
        print('floatify', self.floatify, ' 2d ', self.is_2d)
        print('kwargs ', kwargs)
        print('Func ', self.func)
        print('target view ', target.view(-1).shape)
        return self.func.__call__(input, target.view(-1), **kwargs)

def MyCrossEntropyFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    return MyFlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
wd=1e-2
learn = unet_learner(data, models.resnet34, metrics=[], wd=wd)
learn.loss_func = MyCrossEntropyFlat()
print('Loss func ', learn.loss_func)
Output:
Loss func My FlattenedLoss of CrossEntropyLoss()
The exception occurs when calling lr_find:
lr_find(learn)
Note that the learner is set up with a batch size of 8, there are 256 classes, and the images have been resized to [32,32].
The following output is captured before the exception:
input shape torch.Size([8, 256, 32, 32])
target shape torch.Size([8, 1, 32, 32])
input shape torch.Size([8, 256, 32, 32])
target shape torch.Size([8, 1, 32, 32])
input shape torch.Size([65536, 32])
target shape torch.Size([8, 1, 32, 32])
floatify False 2d True
kwargs {}
Func CrossEntropyLoss()
target view torch.Size([8192])
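To double-check where 65536 and 8192 come from, I replayed the same transpose/view steps on dummy tensors of the shapes printed above (a sketch assuming the default axis=-1 is in effect, as in my MyCrossEntropyFlat() call):

import torch

inp = torch.randn(8, 256, 32, 32)                  # batch of 8, 256 classes, 32x32 images
tgt = torch.zeros(8, 1, 32, 32, dtype=torch.long)
inp = inp.transpose(-1, -1).contiguous()           # axis=-1: swapping the last dim with itself is a no-op
print(inp.view(-1, inp.shape[-1]).shape)           # torch.Size([65536, 32])
print(tgt.view(-1).shape)                          # torch.Size([8192])

So with axis=-1 the input is flattened down to its last dimension (8*256*32 = 65536 rows of 32) while the target flattens to 8*1*32*32 = 8192 elements, which matches the batch-size mismatch reported in the ValueError.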