Hi all,
I'm wondering how you handle more than one loss function: do you wrap them in a `BaseLoss`, or do you combine them in a callback?
Based on some initial tests, adding the second loss in a callback seems to behave a bit weirdly during backpropagation. I don't have proof of this; it's based solely on the printed loss values.
In addition, if I wrap the loss function, how can I print the value of each part at a given iteration? For example, in a callback we can do:
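```python
from fastai.callback.core import Callback

class MultiLossCallback(Callback):  # hypothetical name for the callback holding this method
    def after_loss(self):
        loss_1 = self.loss
        loss_2 = Loss_func_2(self.pred, *self.yb)  # self.yb is a tuple of targets
        if self.iter % 10 == 0:
            print(loss_1, loss_2)
        self.learn.loss = loss_1 + loss_2
```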
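A possible explanation for the odd backprop behaviour, as I understand recent fastai v2 versions: the learner computes `self.learn.loss_grad` from `loss_func`, exposes a clone of it as `self.learn.loss`, and calls `backward()` on `loss_grad`. If that is the case, reassigning only `self.learn.loss` in `after_loss` changes the printed/recorded value but not the gradients. The toy sketch below uses plain PyTorch (no fastai) to mimic that split; the two squared-error terms are hypothetical stand-ins for `Loss_func_1` and `Loss_func_2`.

```python
import torch


def simulated_step(also_update_loss_grad: bool):
    """Toy model of a learner that backpropagates one tensor and logs a copy.

    Assumption: this mirrors the loss_grad / loss split in recent fastai v2;
    it is not fastai code. Returns (logged loss, gradient of w).
    """
    w = torch.tensor(1.0, requires_grad=True)
    pred = 2 * w
    loss_grad = pred ** 2              # stand-in for Loss_func_1, used for backward
    loss = loss_grad.clone()           # the copy a callback sees as self.loss
    # --- what an after_loss callback might do ---
    loss_2 = (pred - 1.0) ** 2         # stand-in for Loss_func_2
    loss = loss + loss_2               # only changes the logged value
    if also_update_loss_grad:
        loss_grad = loss_grad + loss_2  # also changes what gets backpropagated
    loss_grad.backward()
    return loss.item(), w.grad.item()
```

Here `simulated_step(False)` returns `(5.0, 8.0)` and `simulated_step(True)` returns `(5.0, 12.0)`: the logged loss is identical, but only the second variant includes the gradient of the second term. Under that assumption, a callback on such a fastai version would also need to update the backward tensor, e.g. `self.learn.loss_grad = self.loss_grad + loss_2`.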
But as you can see, if I wrap the loss in `BaseLoss`, I don't have access to `self.iter`.
The code example would be:
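```python
from fastai.losses import BaseLoss
from fastai.torch_core import Module
from fastcore.basics import store_attr

class MyLoss(Module):
    y_int = True
    def __init__(self):
        store_attr()
    def forward(self, inp, targ):
        loss_1 = Loss_func_1(inp, targ)
        loss_2 = Loss_func_2(inp, targ)
        return loss_1 + loss_2

loss_func = BaseLoss(MyLoss, flatten=False)
```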
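One way to keep the combined loss inside the loss function and still log each part: let the loss module remember the latest value of each term, and leave the "every 10 iterations" decision to whatever owns an iteration counter. In fastai that would be a callback, which could presumably read the parts through `self.learn.loss_func.func` (`BaseLoss` keeps the wrapped module in its `func` attribute). The sketch below is plain PyTorch with its own training loop; `CombinedLoss`, `parts`, and `train_and_log` are hypothetical names, and MSE/L1 stand in for `Loss_func_1`/`Loss_func_2`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CombinedLoss(nn.Module):
    """Sum of two losses that also remembers each part for logging.

    Hypothetical example: MSE and L1 stand in for Loss_func_1 / Loss_func_2.
    """

    def __init__(self):
        super().__init__()
        self.parts = {}  # latest value of each term, as plain floats

    def forward(self, inp, targ):
        loss_1 = F.mse_loss(inp, targ)
        loss_2 = F.l1_loss(inp, targ)
        # .item() copies the values out for logging; the returned sum keeps the graph
        self.parts = {"loss_1": loss_1.item(), "loss_2": loss_2.item()}
        return loss_1 + loss_2


def train_and_log(n_iters=25, log_every=10):
    """Toy loop: the iteration counter lives here (a callback in fastai), not in the loss."""
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    loss_func = CombinedLoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    x, y = torch.randn(16, 3), torch.randn(16, 1)
    logged = []
    for it in range(n_iters):
        loss = loss_func(model(x), y)
        loss.backward()
        opt.step()
        opt.zero_grad()
        if it % log_every == 0:
            parts = loss_func.parts
            print(it, parts["loss_1"], parts["loss_2"])
            logged.append((it, parts["loss_1"], parts["loss_2"]))
    return logged
```

Calling `.item()` forces a device sync on every batch; storing `loss_1.detach()` instead and converting only when printing would avoid that.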
Many thanks