Looking at this again, it’s actually even easier. All the `if softmax` branch is really doing is just `nn.CrossEntropyLoss`, as inside of it we’re already calculating the L1, so that can be skipped. `model.Scale` is simply multiplying by some weights. So in theory I believe our loss function should look like so:
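import torch.nn.functional as F

def ModCELoss(pred, targ, ce=True):
    "Expects raw logits in `pred`; both losses below apply their own softmax/sigmoid"
    targ = targ.flatten().long()
    if ce:
        # `F.cross_entropy` already averages over the batch
        loss = F.cross_entropy(pred, targ)
    else:
        # Assumes `targ` holds class indices, so one-hot encode to match `pred`
        targ = F.one_hot(targ, pred.shape[-1]).float()
        loss = F.binary_cross_entropy_with_logits(pred, targ)
    return loss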
from torch import nn

class GradientBlending(nn.Module):
    def __init__(self, audio_weight=0.0, visual_weight=0.0, av_weight=1.0, loss_scale=1.0, use_cel=True):
        "Expects weights for each model, the combined model, and an overall scale"
        super().__init__()
        self.audio_weight = audio_weight
        self.visual_weight = visual_weight
        self.av_weight = av_weight
        self.ce = use_cel
        self.scale = loss_scale

    def forward(self, audio_out, visual_out, av_out, targ):
        "Computes `ModCELoss` for each model, weighs, then sums"
        av_loss = ModCELoss(av_out, targ, self.ce) * self.scale
        a_loss = ModCELoss(audio_out, targ, self.ce) * self.scale
        v_loss = ModCELoss(visual_out, targ, self.ce) * self.scale
        # Use the attribute names set in `__init__`
        weighted_a_loss = a_loss * self.audio_weight
        weighted_v_loss = v_loss * self.visual_weight
        weighted_av_loss = av_loss * self.av_weight
        loss = weighted_a_loss + weighted_v_loss + weighted_av_loss
        return loss
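As a quick sanity check on the "softmax can be skipped" point, here's a minimal sketch with made-up logits (the tensors and variable names are purely illustrative). It compares `F.cross_entropy` on raw logits against doing `log_softmax` + `nll_loss` by hand, and against applying softmax first:

```python
import torch
import torch.nn.functional as F

# Hypothetical, hand-picked logits: 2 samples, 3 classes, both confidently correct
logits = torch.tensor([[10.0, 0.0, 0.0],
                       [0.0, 10.0, 0.0]])
targ = torch.tensor([0, 1])

# cross_entropy on raw logits
ce = F.cross_entropy(logits, targ)

# The same thing spelled out: log-softmax followed by negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=-1), targ)

# Softmax applied first, then cross_entropy (softmax ends up applied twice)
double = F.cross_entropy(logits.softmax(dim=-1), targ)

print("cross_entropy on logits:  ", ce.item())
print("log_softmax + nll_loss:   ", manual.item())
print("softmax then cross_entropy:", double.item())
```

The first two numbers should agree, while the third comes out noticeably larger even though the predictions are confidently correct, since squashing the logits into [0, 1] first flattens the distribution that `cross_entropy` sees.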
If anyone wants to check me, please tell me if I happened to miss anything. I still need to read up on what they used for `scale`; will update if I find it.
As for the other weights, they used three different datasets and tested a variety of weights on each; see the table below:
Dataset | Pre-Train | Model | Depth | Audio Weight | Visual Weight | AV Weight |
---|---|---|---|---|---|---|
Kinetics400 | NA | R3D | 50 | 0.014 | 0.630 | 0.356 |
Kinetics400 | None | ip-CSN | 152 | 0.009 | 0.485 | 0.506 |
Kinetics400 | IG-65M | ip-CSN | 152 | 0.070 | 0.485 | 0.445 |
AudioSet | None | R(2+1)D | 101 | 0.239 | 0.384 | 0.377 |
EPIC-Kitchen Noun | IG-65M | ip-CSN | 152 | 0.175 | 0.460 | 0.364 |
EPIC-Kitchen Verb | IG-65M | ip-CSN | 152 | 0.524 | 0.247 | 0.229 |
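To make the table a bit more concrete, here's a rough sketch that plugs a few of those rows into a weighted sum of three cross-entropy losses. The logits are random stand-ins, the batch size and class count are made up, and `blend_losses` is a hypothetical helper of mine, not something from their code:

```python
import torch
import torch.nn.functional as F

# (audio, visual, av) weights copied from three rows of the table
WEIGHTS = {
    "Kinetics400 / R3D-50": (0.014, 0.630, 0.356),
    "AudioSet / R(2+1)D-101": (0.239, 0.384, 0.377),
    "EPIC-Kitchen Verb / ip-CSN-152": (0.524, 0.247, 0.229),
}

def blend_losses(audio_out, visual_out, av_out, targ, weights):
    "Weighted sum of the per-head cross-entropy losses (hypothetical helper)"
    w_a, w_v, w_av = weights
    return (w_a * F.cross_entropy(audio_out, targ)
            + w_v * F.cross_entropy(visual_out, targ)
            + w_av * F.cross_entropy(av_out, targ))

torch.manual_seed(0)
n, n_classes = 4, 10  # made-up batch size and class count
targ = torch.randint(0, n_classes, (n,))
# Random stand-ins for the audio, visual, and audio-visual logits
outs = [torch.randn(n, n_classes) for _ in range(3)]

for name, w in WEIGHTS.items():
    loss = blend_losses(*outs, targ, w)
    print(f"{name}: weights sum to {sum(w):.3f}, blended loss = {loss.item():.4f}")
```

Each of these rows sums to 1, so the blended loss is a weighted average of the three individual losses rather than something on a different scale.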
I’ll try to get one of their datasets working (or a Kaggle multi-modal dataset) and see what works.