Gradient Blending (for multi-modal models) - In Progress

Hey all,

Had a look at this and managed to get a model training (training maybe != working :sweat_smile:)

NOTEBOOK HERE

Data

I started with @muellerzr’s MixedDL notebook and the Kaggle Melanoma dataset. I used an upload with 224x224 images (3 GB), as the original DICOM dataset was 106 GB.

I used a balanced subset (1168 images total) because the original dataset is heavily unbalanced. This let me iterate quickly and made the accuracy numbers easy to interpret.
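
The post doesn't say how the balanced subset was built. One common way is to undersample every class down to the size of the smallest one. A minimal sketch of that idea, where `balanced_subset` and the toy labels are made up for illustration:

```python
import random
from collections import defaultdict

def balanced_subset(labels, seed=0):
    """Return sorted indices holding the same number of samples for every class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Every class is cut down to the size of the rarest class.
    n_per_class = min(len(idxs) for idxs in by_class.values())
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    keep = []
    for idxs in by_class.values():
        keep.extend(rng.sample(idxs, n_per_class))
    return sorted(keep)

# Toy labels: 95 negatives and 5 positives (not the real class counts).
labels = [0] * 95 + [1] * 5
subset = balanced_subset(labels)
```

The returned indices can then be used to filter the dataframe of image paths and tabular features before building the dataloaders.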

Model

The model code is below. I took the standard fastai models from tabular_learner and cnn_learner.

For the multi-modal classifier (or “mixed” classifier) I used PyTorch hooks to capture the inputs to the last layers of the vision and tabular models during the forward pass, concatenated them, and fed the result to a single linear layer.

Gradient Blending

I used the same Gradient Blending code that @muellerzr posted earlier in this thread.
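
For anyone who hasn't seen that code, the core idea is roughly a weighted sum of one loss per head. This is a rough sketch in plain Python, not the code from the thread, and it assumes one cross-entropy loss each for the tabular, vision and mixed heads. The function names and the example weights are made up for illustration:

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample, given raw logits and a class index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def blended_loss(head_logits, target, weights):
    """Weighted sum of the per-head losses, one weight per head."""
    return sum(w * cross_entropy(z, target) for w, z in zip(weights, head_logits))

# Toy logits for one sample from the tabular, vision and mixed heads.
tab, vis, mixed = [0.2, -0.1], [1.5, 0.3], [2.0, -1.0]
loss = blended_loss((tab, vis, mixed), target=0, weights=(0.2, 0.3, 0.5))
```

In a real training loop the same thing would be done on batches with `torch.nn.functional.cross_entropy`, so that gradients flow back through all three heads.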

Weights

Right now the weights are calculated statically. I think getting the online version of the weights working might be doable; otherwise we’re left with another 3 hyperparameters to tune :no_mouth:
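
As I understand the Gradient Blending paper, the static weights come from comparing train and validation losses at two checkpoints: each head gets a weight proportional to its generalization gain divided by the square of its increase in overfitting. A minimal sketch under that reading, where `gb_weights`, the `eps` guard and the loss values are all made up for illustration:

```python
def gb_weights(train_before, val_before, train_after, val_after, eps=1e-8):
    """Estimate one blending weight per head from losses at two checkpoints.

    Each argument is a list with one loss per head (e.g. tabular, vision, mixed).
    """
    raw = []
    for tb, vb, ta, va in zip(train_before, val_before, train_after, val_after):
        delta_g = vb - va                # generalization gain: drop in validation loss
        delta_o = (va - ta) - (vb - tb)  # growth of the validation/train gap
        raw.append(delta_g / (delta_o ** 2 + eps))  # eps avoids division by zero
    z = sum(raw)  # normalize so the weights sum to 1
    return [r / z for r in raw]

# Toy losses: all heads start at 0.70, then overfit by different amounts.
weights = gb_weights(
    train_before=[0.70, 0.70, 0.70], val_before=[0.70, 0.70, 0.70],
    train_after=[0.60, 0.40, 0.30], val_after=[0.65, 0.60, 0.60],
)
```

In this toy case the head that overfits least ends up with the largest weight. The sketch does not handle a head whose validation loss gets worse, which would give a negative weight.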

Multi-Modal Model

Model code:

import torch
from torch import nn

class TabVis(nn.Module):
    def __init__(self, tab_model, vis_model, num_classes=2):
        super().__init__()
        self.tab_model = tab_model
        self.vis_model = vis_model
        
        self.mixed_cls = nn.Linear(512+100, num_classes)
        self.tab_cls = nn.Linear(100, num_classes)
        self.vis_cls = nn.Linear(512, num_classes)
        
        #self.print_handle = self.tab_model.layers[2][0].register_forward_hook(printnorm)
        self.tab_handle = self.tab_model.layers[2][0].register_forward_hook(get_tab_logits)
        self.vis_handle = self.vis_model[-1][-1].register_forward_hook(get_vis_logits)
    
    def remove_my_hooks(self):
        self.tab_handle.remove()
        self.vis_handle.remove()
        #self.print_handle.remove()
        return None
        
    def forward(self, x_cat, x_cont, x_im):
        # Tabular Classifier
        tab_pred = self.tab_model(x_cat, x_cont)        
        
        # Vision Classifier
        vis_pred = self.vis_model(x_im)
        
        # Logits
        tab_logits = glb_tab_logits[0]   # Hook input is a tuple; [0] is the tensor fed to the hooked layer
        vis_logits = glb_vis_logits[0]   # Hook input is a tuple; [0] is the tensor fed to the hooked layer
        mixed = torch.cat((tab_logits, vis_logits), dim=1)
        
        # Mixed Classifier
        mixed_pred = self.mixed_cls(mixed)
        return (tab_pred, vis_pred, mixed_pred)

PyTorch Hooks

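glb_tab_logits = None  # set by the hook on every forward pass
def get_tab_logits(self, inp, out):
    global glb_tab_logits
    glb_tab_logits = inp  # tuple of inputs to the hooked layer

glb_vis_logits = None  # set by the hook on every forward pass
def get_vis_logits(self, inp, out):
    global glb_vis_logits
    glb_vis_logits = inp  # tuple of inputs to the hooked layer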
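
A possible alternative to the globals, as a sketch only: a small callable object can hold the captured tensor itself. `InputCatcher` and the toy `nn.Sequential` model here are made up for illustration and stand in for the fastai models:

```python
import torch
from torch import nn

class InputCatcher:
    """Forward hook that remembers the last input seen by the hooked module."""
    def __init__(self):
        self.value = None

    def __call__(self, module, inp, out):
        # `inp` is a tuple of the positional arguments passed to `module`.
        self.value = inp[0]

torch.manual_seed(0)
# Toy stand-in for a model body followed by a final classification layer.
body = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

catcher = InputCatcher()
handle = body[-1].register_forward_hook(catcher)  # hook the last layer

x = torch.randn(3, 8)
out = body(x)             # the hook fires during this forward pass
features = catcher.value  # activations that were fed into the last layer
handle.remove()           # detach the hook once it is no longer needed
```

Keeping one catcher per hooked model avoids module-level state, which should make it easier to run several models in the same notebook.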

Metrics

I'm not sure of the best strategy for calculating accuracy. I calculated it for each classifier individually, and then took a weighted average of the predictions for a weighted_accuracy.

def t_accuracy(inp, targ, axis=-1):
    "Accuracy of the tabular head: `inp[0]` is bs * n_classes"
    pred,targ = flatten_check(inp[0].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def v_accuracy(inp, targ, axis=-1):
    "Accuracy of the vision head: `inp[1]` is bs * n_classes"
    pred,targ = flatten_check(inp[1].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def tv_accuracy(inp, targ, axis=-1):
    "Accuracy of the mixed head: `inp[2]` is bs * n_classes"
    pred,targ = flatten_check(inp[2].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def weighted_accuracy(inp, targ, axis=-1, w_t=0.333, w_v=0.333, w_tv=0.333):
    "Accuracy of the weighted sum of the three heads' outputs"
    t_inp = inp[0] * w_t
    v_inp = inp[1] * w_v
    tv_inp = inp[2] * w_tv
    # The weights already do the averaging; a further /3 would not change the argmax
    inp_fin = t_inp + v_inp + tv_inp
    pred,targ = flatten_check(inp_fin.argmax(dim=axis), targ)
    return (pred == targ).float().mean()
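
To see what the weighting does to a single prediction, here is a toy example in plain Python. `weighted_prediction` and the scores are made up for illustration; it combines per-head class scores and takes the argmax, which is the step the weighted metric relies on:

```python
def weighted_prediction(head_scores, weights):
    """Combine per-head class scores with one weight per head; return the argmax class."""
    n_classes = len(head_scores[0])
    combined = [
        sum(w * scores[c] for w, scores in zip(weights, head_scores))
        for c in range(n_classes)
    ]
    return max(range(n_classes), key=combined.__getitem__)

# One sample: (tabular, vision, mixed) scores for classes 0 and 1.
sample = ([0.9, 0.1], [0.2, 0.8], [0.3, 0.7])

uniform = weighted_prediction(sample, (1 / 3, 1 / 3, 1 / 3))  # vision and mixed outvote tabular
tab_heavy = weighted_prediction(sample, (0.8, 0.1, 0.1))      # tabular head dominates
```

With uniform weights the sample is assigned class 1, while the tabular-heavy weights flip it to class 0. Rescaling all weights by the same positive factor leaves the prediction unchanged, so only the ratios between the weights matter for this metric.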