I hadn’t seen that yet, actually, so thank you. It is helpful, but it doesn’t look like it solves my issue. I can define a working function; the problem is that when I try to assign it with model.crit = MY_NEW_FUNCTION, it throws an error asking for the function’s inputs, which haven’t been calculated yet.
def weighted_mse_loss(input, target):
    weight = torch.ones(16)
    weight[0] = 3
    return torch.mean(weight * (input - target) ** 2)
m.crit=weighted_mse_loss()
Running that last line gives me this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-79-c956b86e8f4f> in <module>()
4 return torch.mean(weight * (input - target) ** 2)
5
----> 6 m.crit=weighted_mse_loss()
TypeError: weighted_mse_loss() missing 2 required positional arguments: 'input' and 'target'
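The traceback suggests the cause: the parentheses in `weighted_mse_loss()` call the function immediately, with no arguments, instead of handing the function itself to the model. Assigning the function object (no parentheses) should let the training loop call it later with predictions and targets. The sketch below assumes that `m.crit` is invoked as `crit(preds, targets)` during training and that the outputs have 16 columns; `m` here is a hypothetical `SimpleNamespace` stand-in, not the real model object, and passing `device=input.device` is an added safeguard that is not in the original function.

```python
import torch
from types import SimpleNamespace


def weighted_mse_loss(input, target):
    # Weight 3 on the first of 16 output columns, 1 on the rest (as in the post).
    # Assumption: the last dimension of input/target has size 16, so the weight
    # vector broadcasts across the batch dimension.
    # Added safeguard: create the weights on the same device as the input.
    weight = torch.ones(16, device=input.device)
    weight[0] = 3
    return torch.mean(weight * (input - target) ** 2)


# Hypothetical stand-in for the model object `m`.
m = SimpleNamespace(crit=None)

# Assign the function object itself: no parentheses, so nothing is called yet.
m.crit = weighted_mse_loss

# Later, the training code is expected to call it with predictions and targets.
preds = torch.zeros(4, 16)
targets = torch.ones(4, 16)
loss = m.crit(preds, targets)
print(loss.item())
```

With every squared error equal to 1, the loss is the mean weight, (3 + 15) / 16 = 1.125.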