Hey all,
Had a look at this and managed to get a model training (training maybe != working).
I started with @muellerzr's MixedDL notebook and the Kaggle Melanoma dataset. I used an upload with 224x224 images (3 GB), as the original DICOM dataset was 106 GB.
I used a balanced dataset (1168 images total) because the original dataset is heavily unbalanced. This let me iterate quickly and made it easy to understand what the accuracy meant.
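One way to build a balanced subset like this is to downsample every class to the size of the smallest one. A minimal sketch, assuming the labels sit in a pandas DataFrame with a binary `target` column (the column name, the toy data and the helper `balance_by_downsampling` are assumptions for illustration, not necessarily how the subset above was made):

```python
import pandas as pd

def balance_by_downsampling(df: pd.DataFrame, label_col: str = "target", seed: int = 42) -> pd.DataFrame:
    """Hypothetical helper: downsample every class to the size of the smallest class."""
    n_min = df[label_col].value_counts().min()
    # GroupBy.sample draws n rows from each group (needs pandas >= 1.1)
    balanced = df.groupby(label_col).sample(n=n_min, random_state=seed)
    # Shuffle so the classes are interleaved, then reset the index
    return balanced.sample(frac=1, random_state=seed).reset_index(drop=True)

# Toy example: 10 benign (0) vs 3 malignant (1) rows
toy = pd.DataFrame({"image_name": [f"img_{i}" for i in range(13)],
                    "target": [0] * 10 + [1] * 3})
balanced = balance_by_downsampling(toy)
print(balanced["target"].value_counts())  # both classes now have 3 rows
```

Downsampling throws data away, which is fine for quick iteration but is worth revisiting once the model works.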
The model code is below. I took the standard fastai models from tabular_learner and cnn_learner.
For the multi-modal classifier (or "mixed" classifier), I used PyTorch hooks to grab the inputs to the last layer of the vision and tabular models during the forward pass, and then fed these to a single linear layer.
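To show the hook idea in isolation, here is a small self-contained sketch with toy `nn.Sequential` stand-ins for the fastai models (the models, input sizes and the `captured` dict are hypothetical; only the 100 and 512 feature sizes mirror the model code in this post). A forward hook on each model's final `Linear` records that layer's input, the two feature tensors are concatenated, and a single linear layer classifies them:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical): each model ends in a Linear "head"
tab_model = nn.Sequential(nn.Linear(8, 100), nn.ReLU(), nn.Linear(100, 2))
vis_model = nn.Sequential(nn.Linear(32, 512), nn.ReLU(), nn.Linear(512, 2))

captured = {}

def save_input(name):
    # Forward hook that stores the *input* of the hooked layer (the penultimate features)
    def hook(module, inp, out):
        captured[name] = inp[0]  # inp is a tuple of positional inputs
    return hook

h_tab = tab_model[-1].register_forward_hook(save_input("tab"))
h_vis = vis_model[-1].register_forward_hook(save_input("vis"))

mixed_cls = nn.Linear(100 + 512, 2)

x_tab, x_im = torch.randn(4, 8), torch.randn(4, 32)
tab_pred, vis_pred = tab_model(x_tab), vis_model(x_im)  # the hooks fire here
mixed = torch.cat((captured["tab"], captured["vis"]), dim=1)
mixed_pred = mixed_cls(mixed)
print(mixed.shape, mixed_pred.shape)  # torch.Size([4, 612]) torch.Size([4, 2])

# Detach the hooks once they are no longer needed
h_tab.remove()
h_vis.remove()
```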
Gradient Blending
I took the same code that @muellerzr posted earlier in this thread.
Right now the weights are calculated statically. I think getting the online weight version working might be doable; otherwise we're left with another 3 hyperparameters to tune.
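For reference, here is a rough sketch of how static (offline) blend weights can be estimated, following my reading of the Gradient Blending paper (Wang et al.): each head's weight is proportional to the generalization it gained between two checkpoints divided by the square of the overfitting it picked up. This is not @muellerzr's code; the function name, the `eps` guard and the loss numbers are made up for illustration:

```python
def gradient_blend_weights(train_losses, val_losses, eps=1e-8):
    """Hypothetical helper: offline blend weights from losses at two checkpoints.

    train_losses / val_losses map head name -> (loss at epoch N, loss at epoch N+n).
    Assumes w_k is proportional to dG_k / dO_k**2, where
      O  = val_loss - train_loss           (overfitting)
      dG = val_loss(N) - val_loss(N+n)     (generalization gained)
    """
    raw = {}
    for k in train_losses:
        (t0, t1), (v0, v1) = train_losses[k], val_losses[k]
        d_g = v0 - v1
        d_o = (v1 - t1) - (v0 - t0)
        # eps only guards against d_o == 0; a negative d_g would give a negative weight
        raw[k] = d_g / (d_o ** 2 + eps)
    z = sum(raw.values())
    return {k: w / z for k, w in raw.items()}

# Made-up losses for the three heads
train_losses = {"tab": (0.60, 0.50), "vis": (0.60, 0.40), "mixed": (0.55, 0.35)}
val_losses = {"tab": (0.62, 0.55), "vis": (0.65, 0.60), "mixed": (0.60, 0.52)}
w = gradient_blend_weights(train_losses, val_losses)
print({k: round(v, 3) for k, v in w.items()})  # {'tab': 0.909, 'vis': 0.026, 'mixed': 0.065}
```

In this toy case the vision head overfits fastest, so it gets the smallest weight. The online version would re-run this estimate periodically during training instead of once.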
Multi-Modal Model
Model code:
import torch
from torch import nn

class TabVis(nn.Module):
    def __init__(self, tab_model, vis_model, num_classes=2):
        super().__init__()
        self.tab_model = tab_model
        self.vis_model = vis_model
        # 512 and 100 are the in_features of the hooked final Linear layers (vision, tabular)
        self.mixed_cls = nn.Linear(512+100, num_classes)
        self.tab_cls = nn.Linear(100, num_classes)  # not used in forward
        self.vis_cls = nn.Linear(512, num_classes)  # not used in forward
        #self.print_handle = self.tab_model.layers[2][0].register_forward_hook(printnorm)
        # Hook each model's final Linear layer to capture the features fed into it
        self.tab_handle = self.tab_model.layers[2][0].register_forward_hook(get_tab_logits)
        self.vis_handle = self.vis_model[-1][-1].register_forward_hook(get_vis_logits)

    def remove_my_hooks(self):
        # Detach both forward hooks so they stop firing
        self.tab_handle.remove()
        self.vis_handle.remove()

    def forward(self, x_cat, x_cont, x_im):
        # Tabular classifier (this forward pass also fires get_tab_logits)
        tab_pred = self.tab_model(x_cat, x_cont)
        # Vision classifier (this forward pass also fires get_vis_logits)
        vis_pred = self.vis_model(x_im)
        # Captured features: the hook's `inp` is a tuple, [0] is the input activation tensor
        # (these are the features fed to each final layer, not weights or logits)
        tab_logits = glb_tab_logits[0]
        vis_logits = glb_vis_logits[0]
        mixed = torch.cat((tab_logits, vis_logits), dim=1)
        # Mixed classifier
        mixed_pred = self.mixed_cls(mixed)
        return (tab_pred, vis_pred, mixed_pred)
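Because `forward` returns a tuple of three predictions, the loss has to combine them. A minimal sketch of a blended loss, assuming the blend is a weighted sum of per-head cross-entropy losses (the helper name `blended_loss` and its default equal weights are hypothetical; the weights would come from the Gradient Blending estimate):

```python
import torch
import torch.nn.functional as F

def blended_loss(preds, targ, w_t=1/3, w_v=1/3, w_tv=1/3):
    """Hypothetical helper: weighted sum of per-head cross-entropy losses.

    preds is the (tab_pred, vis_pred, mixed_pred) tuple returned by TabVis.forward.
    """
    tab_pred, vis_pred, mixed_pred = preds
    return (w_t * F.cross_entropy(tab_pred, targ)
            + w_v * F.cross_entropy(vis_pred, targ)
            + w_tv * F.cross_entropy(mixed_pred, targ))

torch.manual_seed(0)
targ = torch.tensor([0, 1, 1, 0])
preds = tuple(torch.randn(4, 2, requires_grad=True) for _ in range(3))
loss = blended_loss(preds, targ)
loss.backward()  # gradients reach all three heads
print(loss.item())
```

Since fastai calls the loss as `loss_func(pred, *yb)`, a function like this should be usable as the `Learner`'s `loss_func`, though I haven't checked how it interacts with `get_preds`.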
PyTorch Hooks
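# These must be defined before TabVis is instantiated
glb_tab_logits = None  # filled in by get_tab_logits on every forward pass
def get_tab_logits(module, inp, out):
    global glb_tab_logits
    glb_tab_logits = inp  # tuple of the hooked layer's inputs
    # returning None leaves the layer's output unchanged

glb_vis_logits = None  # filled in by get_vis_logits on every forward pass
def get_vis_logits(module, inp, out):
    global glb_vis_logits
    glb_vis_logits = inp  # tuple of the hooked layer's inputs
    # returning None leaves the layer's output unchanged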
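An alternative to module-level globals is to keep the captured tensor on a small hook object, so each hook owns its own storage and handle. A sketch, with a hypothetical `InputHook` class (not a fastai or PyTorch API):

```python
import torch
from torch import nn

class InputHook:
    """Hypothetical helper: stores a module's input on every forward pass."""
    def __init__(self, module: nn.Module):
        self.stored = None
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        # inp is a tuple of the module's positional inputs; keep the first tensor
        self.stored = inp[0]

    def remove(self):
        self.handle.remove()

layer = nn.Linear(100, 2)
h = InputHook(layer)
x = torch.randn(3, 100)
layer(x)
print(torch.equal(h.stored, x))  # True
h.remove()
```

Inside `TabVis.__init__` this would become e.g. `self.tab_hook = InputHook(self.tab_model.layers[2][0])`, with `self.tab_hook.stored` used in `forward`. The stored tensor is not detached, so the mixed head's loss still backpropagates into both backbones.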
Not sure of the best strategy to calculate accuracy. I calculated it for each classifier individually, and then took a weighted average of the predictions for a weighted_accuracy.
from fastai.torch_core import flatten_check

def t_accuracy(inp, targ, axis=-1):
    "Accuracy of the tabular classifier; `inp` is the (tab, vis, mixed) tuple of bs * n_classes predictions"
    pred,targ = flatten_check(inp[0].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def v_accuracy(inp, targ, axis=-1):
    "Accuracy of the vision classifier; `inp` is the (tab, vis, mixed) tuple of bs * n_classes predictions"
    pred,targ = flatten_check(inp[1].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def tv_accuracy(inp, targ, axis=-1):
    "Accuracy of the mixed classifier; `inp` is the (tab, vis, mixed) tuple of bs * n_classes predictions"
    pred,targ = flatten_check(inp[2].argmax(dim=axis), targ)
    return (pred == targ).float().mean()
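def weighted_accuracy(inp, targ, axis=-1, w_t=0.333, w_v=0.333, w_tv=0.333):
    "Accuracy of the weighted average of the three classifiers' outputs"
    t_inp = inp[0] * w_t
    v_inp = inp[1] * w_v
    tv_inp = inp[2] * w_tv
    # The weights already sum to ~1, so no extra /3 is needed (it wouldn't change the argmax anyway)
    inp_fin = t_inp + v_inp + tv_inp
    pred,targ = flatten_check(inp_fin.argmax(dim=axis), targ)
    return (pred == targ).float().mean()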
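`weighted_accuracy` averages the raw outputs (logits), so a head that produces large-magnitude logits can dominate the vote. One variant, which I haven't compared against the version above, is to average softmax probabilities instead. A sketch with a hypothetical `weighted_prob_accuracy` and made-up logits:

```python
import torch

def weighted_prob_accuracy(preds, targ, weights=(1/3, 1/3, 1/3)):
    """Hypothetical variant: accuracy of a weighted average of per-head softmax probabilities."""
    probs = sum(w * p.softmax(dim=-1) for w, p in zip(weights, preds))
    return (probs.argmax(dim=-1) == targ).float().mean()

# Made-up logits for two samples from the tab, vis and mixed heads
tab = torch.tensor([[2.0, 0.0], [0.0, 1.0]])
vis = torch.tensor([[0.0, 1.0], [0.0, 3.0]])
mixed = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
targ = torch.tensor([0, 1])
print(weighted_prob_accuracy((tab, vis, mixed), targ))  # tensor(1.)
```

Here each sample has one head voting the wrong way, but the averaged probabilities still pick the right class for both.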