How do you change the validation loss function?

I am training a unet_learner for segmentation on 512x512 aerial images. I've noticed that the model performs well near the center of each image, where there is plenty of context, and worse near the boundary. I would like to change the loss function so that it only computes the loss over the central 256x256 region of each image and ignores the areas near the boundary. Is this possible?


Write a custom loss function. At the start, crop the central 256x256 pixels from both the outputs and the targets, then call the original loss on the cropped tensors.
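A minimal sketch of this suggestion in plain PyTorch, assuming (N, C, H, W) logits and (N, H, W) integer class targets (squeeze any extra channel dimension on the targets first). The names `center_crop` and `CenterCropLoss` are hypothetical, not part of fastai. For a 512x512 image, keeping the central 256x256 region means dropping (512 - 256) / 2 = 128 pixels from each edge.

```python
import torch
import torch.nn as nn


def center_crop(t: torch.Tensor, size: int) -> torch.Tensor:
    """Keep the central `size` x `size` window of the last two (H, W) dims."""
    h, w = t.shape[-2:]
    top, left = (h - size) // 2, (w - size) // 2
    return t[..., top:top + size, left:left + size]


class CenterCropLoss(nn.Module):
    """Hypothetical wrapper: crop outputs and targets, then call the original loss."""

    def __init__(self, base_loss: nn.Module, size: int = 256):
        super().__init__()
        self.base_loss = base_loss
        self.size = size

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # output: (N, C, H, W) logits; target: (N, H, W) class indices
        return self.base_loss(center_crop(output, self.size),
                              center_crop(target, self.size))


# nn.CrossEntropyLoss accepts (N, C, H, W) logits with (N, H, W) targets directly
loss_fn = CenterCropLoss(nn.CrossEntropyLoss(), size=256)
output = torch.randn(2, 3, 512, 512)
target = torch.randint(0, 3, (2, 512, 512))
loss = loss_fn(output, target)
```

Because the crop happens before the base loss is called, predictions in the 128-pixel margin have no effect on the loss value or its gradients.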

Thanks @kushaj, I tried that with the following modification to the existing default CrossEntropyFlat function:

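import functools
import torch.nn as nn
from torch import Tensor

def CrossEntropyFlatCrop(*args, axis:int=-1, border:int=64, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but crops `border` pixels from each edge, then flattens input and target."
    return FlattenedLossCrop(nn.CrossEntropyLoss, *args, axis=axis, border=border, **kwargs)

class FlattenedLossCrop():
    "Same as `func`, but crops `border` pixels from each edge, then flattens input and target."
    def __init__(self, func, *args, axis:int=-1, border:int=64, floatify:bool=False, is_2d:bool=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d
        # pixels ignored along each of the four edges (border=128 keeps the central 256x256 of a 512x512 image)
        self.border = border
        functools.update_wrapper(self, self.func)

    def __repr__(self): return f"FlattenedLossCrop of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs):
        # Crop out the center: drop `border` pixels from each edge of the last two (H, W) dims,
        # which works for (N, C, H, W) predictions and (N, H, W) or (N, 1, H, W) targets
        b = self.border
        if b > 0:
            h, w = input.shape[-2:]
            input = input[..., b:h-b, b:w-b]
            target = target[..., b:h-b, b:w-b]
        # Apply the same transpose to both so the class axis is last and the flattened pixel order matches
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        if self.floatify: target = target.float()
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        return self.func.__call__(input, target.view(-1), **kwargs)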

...

learn.loss_func=CrossEntropyFlatCrop(axis=1, border=64)

But I get an error saying:

TypeError: 'FlattenedLossCrop' object is not callable

Use a partial function. Python needs a callable here, i.e. something it can call like a function, whereas you are passing an object created by a function:
partial(CrossEntropyFlatCrop, axis=1, border=64)
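A minimal pure-Python sketch of what `functools.partial` does: it stores the function together with the keyword arguments, and each call forwards its own positional arguments plus the stored keywords to that function. `make_loss` is a hypothetical stand-in for a loss factory such as `CrossEntropyFlatCrop`.

```python
from functools import partial


def make_loss(*args, axis=-1, border=64):
    """Hypothetical stand-in for a loss factory such as CrossEntropyFlatCrop."""
    return {"args": args, "axis": axis, "border": border}


factory = partial(make_loss, axis=1, border=64)

# partial stores the function and the keyword arguments ...
print(factory.func is make_loss, factory.keywords)  # True {'axis': 1, 'border': 64}

# ... and each call forwards its positional arguments plus the stored keywords
result = factory("input", "target")
print(result)  # {'args': ('input', 'target'), 'axis': 1, 'border': 64}
```

With a factory like `CrossEntropyFlatCrop`, this means `partial(CrossEntropyFlatCrop, axis=1, border=64)(input, target)` calls `CrossEntropyFlatCrop(input, target, axis=1, border=64)`, so the positional arguments go to the factory rather than to the loss object it returns.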

Thanks @kushaj, that worked. I'm curious, though, how the existing source code works without a partial. Is that handled somewhere else?

Without a partial you get an error, because the source code uses the loss as a callable (function), and that is not possible without a callable.
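For reference, Python treats an instance as callable when its class defines `__call__`, and the built-in `callable()` reports this. A minimal sketch with hypothetical classes `CropLoss` and `NoCall` (neither is from fastai), showing that calling an instance whose class lacks `__call__` raises the same kind of `TypeError` quoted above.

```python
class CropLoss:
    """Hypothetical loss object; defining __call__ makes its instances callable."""

    def __init__(self, border=64):
        self.border = border

    def __call__(self, input, target):
        # Placeholder result standing in for a real loss computation
        return f"loss(border={self.border})"


class NoCall:
    """Same idea, but without __call__."""


loss_obj = CropLoss(border=64)
print(callable(loss_obj))   # True
print(loss_obj("x", "y"))   # loss(border=64)
print(callable(NoCall()))   # False

try:
    NoCall()("x", "y")
except TypeError as err:
    print(err)              # 'NoCall' object is not callable
```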