Hi!
First of all, thank you for the great resource!
I have some questions about @morgan's implementation. Why are you calculating weighted accuracy?
Also, did you try increasing the number of layers in the “mixed” classifier? If not, what would you expect to happen if you increased the complexity there?
I am doing some experiments with Image+Tabular data with high class imbalance. I am trying to use a weighted cross entropy as explained here, but I am not quite sure whether it is implemented correctly. Here is a code snippet:

    import torch.nn.functional as F
    from torch import nn

    def ModCELoss(pred, targ, weight=None, ce=True):
        pred = pred.softmax(dim=-1)       # convert raw outputs to probabilities
        targ = targ.flatten().long()      # class indices, shape (batch,)
        if ce:
            loss = F.cross_entropy(pred, targ, weight=weight)
        else:
            loss = F.binary_cross_entropy_with_logits(pred, targ, weight=weight)
        #loss = torch.mean(ce)
        return loss

    class myGradientBlending(nn.Module):
        def __init__(self, tab_weight=0.0, visual_weight=0.0, tab_vis_weight=1.0, loss_scale=1.0, weight=None, use_cel=True):
            "Expects weights for each model, the combined model, and an overall scale"
            super(myGradientBlending, self).__init__()
            self.tab_weight = tab_weight
            self.visual_weight = visual_weight
            self.tab_vis_weight = tab_vis_weight
            self.ce = use_cel
            self.scale = loss_scale
            self.weight = weight          # per-class weights passed on to the loss

        def forward(self, xb, yb):
            "Gathers `self.loss` for each model, weighs, then sums"
            tab_out, visual_out, tv_out = xb
            targ = yb
            tv_loss = ModCELoss(tv_out, targ, self.weight, self.ce) * self.scale
            t_loss = ModCELoss(tab_out, targ, self.weight, self.ce) * self.scale
            v_loss = ModCELoss(visual_out, targ, self.weight, self.ce) * self.scale
            weighted_t_loss = t_loss * self.tab_weight
            weighted_v_loss = v_loss * self.visual_weight
            weighted_tv_loss = tv_loss * self.tab_vis_weight
            loss = weighted_t_loss + weighted_v_loss + weighted_tv_loss
            return loss

Here I added a new weight argument to both ModCELoss and the myGradientBlending class. It seems to work well, but I would like to know what you think.
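One thing I am unsure about myself: as far as I know, F.cross_entropy applies log-softmax internally and expects raw logits, so calling pred.softmax(dim=-1) first would mean softmax is effectively applied twice. I also believe F.binary_cross_entropy_with_logits wants a float target with the same shape as the input, so that branch probably needs a one-hot target. Below is a dependency-free sketch in plain Python (no PyTorch) of what I think the weighted mean computes and how the extra softmax changes the loss. The helper names, the inverse-frequency weighting, and the toy numbers are my own assumptions for illustration, not part of the original implementation.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    log_total = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_total for x in xs]

def softmax(xs):
    """Softmax over a list of floats."""
    return [math.exp(v) for v in log_softmax(xs)]

def weighted_cross_entropy(logits, targets, class_weights=None):
    """Weighted mean of -log_softmax(logits)[target].

    Meant to mirror my understanding of F.cross_entropy(..., weight=...)
    with the default 'mean' reduction: the weighted sum is divided by the
    total weight of the targets in the batch.
    """
    n_classes = len(logits[0])
    if class_weights is None:
        class_weights = [1.0] * n_classes
    num = den = 0.0
    for row, t in zip(logits, targets):
        num += -class_weights[t] * log_softmax(row)[t]
        den += class_weights[t]
    return num / den

def inverse_frequency_weights(targets, n_classes):
    """Assumed weighting scheme: weight_c = N / (n_classes * count_c)."""
    counts = [max(targets.count(c), 1) for c in range(n_classes)]
    return [len(targets) / (n_classes * c) for c in counts]

# Toy batch (made-up numbers): three confident, correct predictions,
# with class 1 as the rare class.
logits = [[5.0, 0.0], [4.0, -1.0], [0.0, 5.0]]
targets = [0, 0, 1]
weights = inverse_frequency_weights(targets, n_classes=2)

# Loss computed on raw logits.
loss_from_logits = weighted_cross_entropy(logits, targets, weights)

# Same loss computed on probabilities, i.e. softmax applied before the
# cross entropy, as in ModCELoss above.
probs = [softmax(row) for row in logits]
loss_double_softmax = weighted_cross_entropy(probs, targets, weights)

print(loss_from_logits, loss_double_softmax)
```

If this reading is right, the loss on probabilities cannot get close to zero even for confident, correct predictions, which would flatten the gradients. In that case dropping the softmax line and passing the raw model outputs to F.cross_entropy would be the fix.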
Finally, how do you determine the proper weights for each loss? I do not quite understand this part. Could you elaborate a little?
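For context, my current reading of the Gradient Blending paper (please correct me if I am wrong) is that each branch gets a weight proportional to G / O², where G is how much the validation loss dropped between two checkpoints and O is how much the gap between validation and training loss grew over the same interval, with the weights then normalized to sum to 1. Here is a rough sketch of that computation as I understand it. The function name, the clamping of negative gains to zero, the eps term, and the loss values are all my own assumptions.

```python
def gradient_blending_weights(losses, eps=1e-8):
    """Estimate per-branch loss weights from losses at two checkpoints.

    `losses` maps a branch name to a tuple
    (train_before, val_before, train_after, val_after).
    """
    raw = {}
    for name, (train_0, val_0, train_1, val_1) in losses.items():
        # G: generalization gain, i.e. the drop in validation loss.
        gain = val_0 - val_1
        # O: increase of the val-minus-train gap (overfitting) over the interval.
        overfit = (val_1 - train_1) - (val_0 - train_0)
        # Assumption: a branch whose validation loss got worse gets weight 0.
        raw[name] = max(gain, 0.0) / (overfit ** 2 + eps)
    total = sum(raw.values())
    if total == 0:
        # No branch improved: fall back to equal weights.
        return {name: 1.0 / len(raw) for name in raw}
    return {name: w / total for name, w in raw.items()}

# Made-up losses for the three heads, each trained on its own for a few epochs.
measurements = {
    "tab": (1.0, 1.0, 0.5, 0.8),
    "visual": (1.0, 1.0, 0.6, 0.7),
    "tab_vis": (1.0, 1.0, 0.3, 0.9),
}
blend_weights = gradient_blending_weights(measurements)
print(blend_weights)
```

The resulting values would then be passed as tab_weight, visual_weight and tab_vis_weight. What I do not know is whether the weights in this implementation were estimated this way or simply tuned by hand.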
EDIT: And just one more: any hints on how ClassificationInterpretation should be modified to make it work with this kind of model?
Thanks!