I am training a `unet_learner` for segmentation on 512x512 aerial images. I'm noticing that the model performs well near the center of each image, where there is a lot of context, and more poorly near the boundary. I would like to change the loss function so that it computes the loss only over the central 256x256 region of each image and ignores the areas near the boundary. Is it possible to do this?
Write a custom loss function. At the start, slice the central 256x256 pixels out of both the outputs and the targets, then call the original loss function on the cropped tensors.
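A minimal sketch of that idea in plain PyTorch, assuming logits shaped `(N, C, H, W)` and integer masks shaped `(N, H, W)`. `CenterCropLoss` is a hypothetical name, not a fastai class. For a 512x512 image, keeping the central 256x256 means dropping `border = (512 - 256) / 2 = 128` pixels on every side.

```python
import torch
import torch.nn as nn


class CenterCropLoss(nn.Module):
    """Hypothetical wrapper: apply `base_loss` only to the central region.

    Assumes `output` is (N, C, H, W) logits and `target` is (N, H, W) class indices.
    """

    def __init__(self, base_loss: nn.Module, border: int = 128):
        super().__init__()
        self.base_loss = base_loss
        self.border = border

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        b = self.border
        h, w = output.shape[-2:]
        # Slice the last two (spatial) dims; h - b instead of -b keeps border=0 working
        output = output[..., b:h - b, b:w - b]
        target = target[..., b:h - b, b:w - b]
        return self.base_loss(output, target)


# A 512x512 image with border=128 leaves the central 256x256 region
loss_func = CenterCropLoss(nn.CrossEntropyLoss(), border=128)
output = torch.randn(2, 3, 512, 512)
target = torch.randint(0, 3, (2, 512, 512))
loss = loss_func(output, target)
```

Because the crop happens before the base loss is called, predictions in the outer 128-pixel band have no effect on the loss value or its gradients.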
Thanks @kushaj. I tried that with the following modification of the existing default `CrossEntropyFlat` function:
```python
import functools

import torch.nn as nn
from torch import Tensor


def CrossEntropyFlatCrop(*args, axis:int=-1, border:int=128, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but crops `border` pixels from each side, then flattens input and target."
    return FlattenedLossCrop(nn.CrossEntropyLoss, *args, axis=axis, border=border, **kwargs)

class FlattenedLossCrop():
    "Same as `func`, but crops `border` pixels from each side, then flattens input and target."
    def __init__(self, func, *args, axis:int=-1, border:int=128, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d
        self.border = border  # pixels ignored on each side: (512 - 256) / 2 = 128
        functools.update_wrapper(self, self.func)
    def __repr__(self): return f"FlattenedLossCrop of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v
    def __call__(self, input:Tensor, target:Tensor, **kwargs):
        # Crop the central region from the last two (spatial) dims of both tensors,
        # assuming input is (N, C, H, W) and target is (N, H, W) or (N, 1, H, W)
        b = self.border
        h, w = input.shape[-2:]
        input = input[..., b:h-b, b:w-b]
        target = target[..., b:h-b, b:w-b]
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        if self.floatify: target = target.float()
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        return self.func.__call__(input, target.view(-1), **kwargs)
...
learn.loss_func=CrossEntropyFlatCrop(axis=1, border=128)
```
But I get an error saying:
`TypeError: 'FlattenedLossCrop' object is not callable`
Use a partial function. "Callable" means something Python can call like a function: either a function or an instance of a class that defines `__call__`.
```python
from functools import partial
partial(CrossEntropyFlatCrop, axis=1, border=128)
```
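A small pure-Python sketch of what Python treats as callable. The class and function names are hypothetical, not fastai's. It shows that an instance is callable exactly when its class defines `__call__`, that `functools.partial` objects are always callable, and what the error looks like otherwise.

```python
from functools import partial


class LossWithCall:
    """Hypothetical loss-like class: defining __call__ makes its instances callable."""

    def __call__(self, output, target):
        return abs(output - target)


class LossWithoutCall:
    """Hypothetical class with no __call__, so its instances cannot be called."""

    def compute(self, output, target):
        return abs(output - target)


def make_loss():
    # A factory function returning a loss object, in the style of CrossEntropyFlatCrop
    return LossWithCall()


with_call = LossWithCall()
without_call = LossWithoutCall()

print(callable(with_call))           # True: the class defines __call__
print(callable(without_call))        # False: no __call__ on the class
print(callable(partial(make_loss)))  # True: partial objects are always callable
print(with_call(3.0, 1.0))           # 2.0

try:
    without_call(3.0, 1.0)
except TypeError as err:
    message = str(err)
    print(message)                   # 'LossWithoutCall' object is not callable
```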
Thanks @kushaj, that worked. I'm curious, though, how the existing source code works without a `partial`. Is that handled somewhere else?
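The learner calls `loss_func(output, target)` for every batch, so whatever you assign must be callable. The built-in `CrossEntropyFlat` needs no `partial` because it returns a `FlattenedLoss` instance, and `FlattenedLoss` defines `__call__`, so the instance itself can be called like a function.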
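A rough sketch of what happens to `loss_func` on each batch, written in plain PyTorch rather than fastai's actual training loop. `CenterCropCE`, `train_step`, and the toy `nn.Conv2d` model are hypothetical stand-ins; the point is that the loss object is called directly with `(output, target)` and must return a tensor that supports `.backward()`.

```python
import torch
import torch.nn as nn


class CenterCropCE:
    """Hypothetical stand-in for FlattenedLossCrop: a plain class with __call__."""

    def __init__(self, border: int = 2):
        self.border = border
        self.func = nn.CrossEntropyLoss()

    def __call__(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        b = self.border
        h, w = output.shape[-2:]
        # Keep only the central region of the logits (N, C, H, W) and mask (N, H, W)
        return self.func(output[..., b:h - b, b:w - b], target[..., b:h - b, b:w - b])


def train_step(model, loss_func, opt, xb, yb):
    # Roughly what a training loop does per batch: call the loss object directly
    output = model(xb)
    loss = loss_func(output, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item()


torch.manual_seed(0)
model = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # toy "segmentation" model, 4 classes
opt = torch.optim.SGD(model.parameters(), lr=0.1)
xb = torch.randn(2, 3, 8, 8)
yb = torch.randint(0, 4, (2, 8, 8))

loss_func = CenterCropCE(border=2)  # assign the instance itself; no partial needed
step_loss = train_step(model, loss_func, opt, xb, yb)
print(step_loss)
```

Since `CrossEntropyFlatCrop` is a factory, a `partial` of it is a factory too: calling it with `(output, target)` would pass those tensors on to the `nn.CrossEntropyLoss` constructor rather than compute a loss. Assigning the instance that `CrossEntropyFlatCrop` returns directly to `learn.loss_func` avoids that.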