Gradient Blending (for multi-modal models) - In Progress

Looking at this again, it’s actually even easier. All the `if softmax` branch is really doing is nn.CrossEntropyLoss, which already calculates the softmax internally, so that step can be skipped. model.Scale is simply multiplying by some weights. So in theory I believe our loss function should look like so:

import torch.nn.functional as F
from torch import nn

def ModCELoss(pred, targ, ce=True):
    "Cross entropy on raw logits if `ce`, else binary cross entropy with logits"
    if ce:
        # F.cross_entropy applies log-softmax itself, so `pred` stays as raw logits
        loss = F.cross_entropy(pred, targ.flatten().long())
    else:
        # BCE-with-logits expects float targets shaped like `pred`
        loss = F.binary_cross_entropy_with_logits(pred, targ.float())
    # Both losses already default to reduction='mean', so no extra mean is needed
    return loss

class GradientBlending(nn.Module):
    def __init__(self, audio_weight=0.0, visual_weight=0.0, av_weight=1.0, loss_scale=1.0, use_cel=True):
        "Expects weights for each model, the combined model, and an overall scale"
        super().__init__()
        self.audio_weight = audio_weight
        self.visual_weight = visual_weight
        self.av_weight = av_weight
        self.ce = use_cel
        self.scale = loss_scale

    def forward(self, audio_out, visual_out, av_out, targ):
        "Gathers `ModCELoss` for each model, weighs, then sums"
        av_loss = ModCELoss(av_out, targ, self.ce) * self.scale
        a_loss = ModCELoss(audio_out, targ, self.ce) * self.scale
        v_loss = ModCELoss(visual_out, targ, self.ce) * self.scale
        # Attribute names must match the ones set in __init__
        weighted_a_loss = a_loss * self.audio_weight
        weighted_v_loss = v_loss * self.visual_weight
        weighted_av_loss = av_loss * self.av_weight
        loss = weighted_a_loss + weighted_v_loss + weighted_av_loss
        return loss
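
To get a feel for how the blended loss behaves, here is a minimal sketch on random tensors. `blended_ce` is a hypothetical compact stand-in for the class above (cross entropy only), and the batch size, class count, and inputs are made up; the weights are just one row from the paper's table, not a recommendation.

```python
import torch
import torch.nn.functional as F

def blended_ce(audio_out, visual_out, av_out, targ, weights=(0.0, 0.0, 1.0), scale=1.0):
    # Hypothetical helper: weighted sum of the per-branch cross entropy losses
    outs = (audio_out, visual_out, av_out)
    return sum(w * scale * F.cross_entropy(o, targ) for w, o in zip(weights, outs))

torch.manual_seed(0)
bs, n_classes = 8, 5  # made-up sizes
targ = torch.randint(0, n_classes, (bs,))
# Stand-ins for the raw logits of the audio, visual, and combined heads
audio_out, visual_out, av_out = (torch.randn(bs, n_classes, requires_grad=True) for _ in range(3))

loss = blended_ce(audio_out, visual_out, av_out, targ, weights=(0.014, 0.630, 0.356))
loss.backward()  # every branch receives a gradient, scaled by its weight
```

With the default weights (0.0, 0.0, 1.0) this reduces to plain cross entropy on the combined output, which seems like a handy sanity check before trying other weights.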

If anyone wants to check me, please tell me if I happened to miss anything :slight_smile: I still need to read up on what they used for scale, and will update if I find it.
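
One part that is easy to check is the claim that cross entropy already handles the softmax. A small sketch with made-up logits and targets: `F.cross_entropy` should match `log_softmax` followed by `nll_loss`, and applying softmax first should give a different number.

```python
import torch
import torch.nn.functional as F

# Made-up raw model outputs (logits) and class targets
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targ = torch.tensor([0, 2])

ce = F.cross_entropy(logits, targ)
# The same loss written out by hand: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=-1), targ)
# Softmax-ing first means softmax gets applied twice inside the loss
double = F.cross_entropy(logits.softmax(dim=-1), targ)

same_as_manual = torch.allclose(ce, manual)
same_as_double = torch.allclose(ce, double)
print(same_as_manual, same_as_double)
```

So the predictions should go into the loss as raw logits; an extra softmax in front does not error out, it just silently changes the loss.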

As for the other weights, they tested a variety of weights across three different datasets; see the table below:

| Dataset | Pre-Train | Model | Depth | Audio Weight | Visual Weight | AV Weight |
|---|---|---|---|---|---|---|
| Kinetics400 | NA | R3D | 50 | 0.014 | 0.630 | 0.356 |
| Kinetics400 | None | ip-CSN | 152 | 0.009 | 0.485 | 0.506 |
| Kinetics400 | IG-65M | ip-CSN | 152 | 0.070 | 0.485 | 0.445 |
| AudioSet | None | R(2+1)D | 101 | 0.239 | 0.384 | 0.377 |
| EPIC-Kitchen Noun | IG-65M | ip-CSN | 152 | 0.175 | 0.460 | 0.364 |
| EPIC-Kitchen Verb | IG-65M | ip-CSN | 152 | 0.524 | 0.247 | 0.229 |
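
Each row of weights in the table sums to roughly 1, so the blended loss is a weighted average of the three branch losses. A quick plain-Python sketch of that arithmetic; the dictionary keys are my own labels, and the per-branch loss values are hypothetical numbers chosen only for illustration.

```python
# (audio, visual, av) weights copied from the table above; keys are my own labels
weights = {
    "Kinetics400 R3D-50": (0.014, 0.630, 0.356),
    "Kinetics400 ip-CSN-152": (0.009, 0.485, 0.506),
    "Kinetics400 ip-CSN-152 IG-65M": (0.070, 0.485, 0.445),
    "AudioSet R(2+1)D-101": (0.239, 0.384, 0.377),
    "EPIC-Kitchen Noun": (0.175, 0.460, 0.364),
    "EPIC-Kitchen Verb": (0.524, 0.247, 0.229),
}

def blend(losses, w):
    # Weighted sum of the (audio, visual, av) losses
    return sum(l * wi for l, wi in zip(losses, w))

for name, w in weights.items():
    print(f"{name}: weights sum to {sum(w):.3f}")

# Hypothetical per-branch losses, not real results
branch_losses = (2.0, 1.0, 0.5)
blended = blend(branch_losses, weights["Kinetics400 R3D-50"])
print(f"blended loss: {blended:.3f}")
```

With these made-up losses the audio branch barely moves the total for Kinetics400, while the EPIC-Kitchen Verb weights would put more than half of the weight on audio.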

I’ll try to get one of their datasets working (or a Kaggle multi-modal dataset) and see what works.