Hey all,
Had a look at this and managed to get a model training (though training maybe != working).
NOTEBOOK HERE
Data
I started with @muellerzr's MixedDL notebook with the Kaggle Melanoma dataset. I used an upload that had 224x224 images (3GB), as the original DICOM dataset was 106GB.
I used a balanced dataset (1,168 images total), since the original dataset is heavily imbalanced, to be able to iterate quickly and also easily understand what the accuracy meant.
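For reference, a minimal sketch of that kind of balancing by undersampling the majority class (the file name and target column are assumptions from my setup, not anything canonical):

import pandas as pd

df = pd.read_csv('train.csv')  # melanoma metadata CSV (path assumed)

# Keep every positive and sample an equal number of negatives,
# then shuffle the combined frame
pos = df[df['target'] == 1]
neg = df[df['target'] == 0].sample(n=len(pos), random_state=42)
balanced = pd.concat([pos, neg]).sample(frac=1, random_state=42)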
Model
The model code is below. I took the standard fastai models from tabular_learner and cnn_learner.
For the multi-modal classifier (or "mixed" classifier) I used PyTorch hooks to grab the inputs to the last layers of the vision and tabular models during the forward pass, and then fed these to a single linear layer.
Gradient Blending
I took the same code that @muellerzr posted earlier in this thread.
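For anyone reading this without that post handy, the core idea is a weighted sum of per-head cross-entropy losses; a minimal sketch (the class name and default weights are placeholders, not the exact code from the thread):

import torch.nn as nn
import torch.nn.functional as F

class GradientBlending(nn.Module):
    def __init__(self, tab_w=0.25, vis_w=0.25, mixed_w=0.5):
        super().__init__()
        self.tab_w, self.vis_w, self.mixed_w = tab_w, vis_w, mixed_w

    def forward(self, preds, targ):
        # `preds` is the (tab_pred, vis_pred, mixed_pred) tuple from the model
        tab_pred, vis_pred, mixed_pred = preds
        return (self.tab_w * F.cross_entropy(tab_pred, targ)
                + self.vis_w * F.cross_entropy(vis_pred, targ)
                + self.mixed_w * F.cross_entropy(mixed_pred, targ))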
Weights
Right now the weights are statically calculated. I think getting the online-weight version working might be doable; otherwise we're left with another 3 hyperparameters to tune.
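For context on where the static weights come from: the Gradient-Blending paper estimates, for each head, how much the validation loss improves (ΔG) versus how much the train/validation gap grows (ΔO) between two checkpoints, and sets each weight proportional to ΔG/ΔO². A rough sketch of that calculation under my reading of the paper (all names here are mine):

def gb_weight(train_loss_1, val_loss_1, train_loss_2, val_loss_2, eps=1e-8):
    # Change in generalization: how much the validation loss dropped
    delta_G = val_loss_1 - val_loss_2
    # Change in overfitting: how much the val-train gap grew
    delta_O = (val_loss_2 - train_loss_2) - (val_loss_1 - train_loss_1)
    return delta_G / (delta_O ** 2 + eps)

# One raw weight per head (tabular, vision, mixed), normalized to sum to 1:
# raw = [gb_weight(*checkpoint_stats[k]) for k in ('tab', 'vis', 'mixed')]
# w_t, w_v, w_tv = [w / sum(raw) for w in raw]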
Multi-Modal Model
Model code:
class TabVis(nn.Module):
    def __init__(self, tab_model, vis_model, num_classes=2):
        super(TabVis, self).__init__()
        self.tab_model = tab_model
        self.vis_model = vis_model

        self.mixed_cls = nn.Linear(512 + 100, num_classes)
        self.tab_cls = nn.Linear(100, num_classes)
        self.vis_cls = nn.Linear(512, num_classes)

        #self.print_handle = self.tab_model.layers[2][0].register_forward_hook(printnorm)
        self.tab_handle = self.tab_model.layers[2][0].register_forward_hook(get_tab_logits)
        self.vis_handle = self.vis_model[-1][-1].register_forward_hook(get_vis_logits)

    def remove_my_hooks(self):
        self.tab_handle.remove()
        self.vis_handle.remove()
        #self.print_handle.remove()
        return None

    def forward(self, x_cat, x_cont, x_im):
        # Tabular Classifier
        tab_pred = self.tab_model(x_cat, x_cont)
        # Vision Classifier
        vis_pred = self.vis_model(x_im)
        # The forward hooks stash the *inputs* to each model's final linear
        # layer; the hook input is a tuple, so [0] is the activation tensor
        tab_logits = glb_tab_logits[0]
        vis_logits = glb_vis_logits[0]
        mixed = torch.cat((tab_logits, vis_logits), dim=1)
        # Mixed Classifier
        mixed_pred = self.mixed_cls(mixed)
        return (tab_pred, vis_pred, mixed_pred)
PyTorch Hooks
glb_tab_logits = None
def get_tab_logits(self, inp, out):
    # Forward hook on the tabular model's final Linear: stash its input
    # (`inp` is a tuple of the layer's positional inputs)
    global glb_tab_logits
    glb_tab_logits = inp

glb_vis_logits = None
def get_vis_logits(self, inp, out):
    # Forward hook on the vision model's final Linear: stash its input
    global glb_vis_logits
    glb_vis_logits = inp
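To make the shape assumptions explicit (100 features flowing into the tabular model's final layer, 512 into the vision head's final layer), this is roughly how the pieces get wired together; a sketch assuming fastai v2 defaults:

from fastai.tabular.all import *
from fastai.vision.all import *

# tab_dls / vis_dls are the tabular and vision DataLoaders built earlier in the notebook
tab_model = tabular_learner(tab_dls, layers=[200, 100]).model  # layers[2][0] is Linear(100, n_out)
vis_model = cnn_learner(vis_dls, resnet34).model               # head ends in Linear(512, n_out)
model = TabVis(tab_model, vis_model)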
Metrics
Not sure of the best strategy to calculate accuracy; I calculated it for each classifier individually, and then took a weighted average of the predictions for a weighted_accuracy:
def t_accuracy(inp, targ, axis=-1):
    "Compute accuracy of the tabular head with `targ` when `pred` is bs * n_classes"
    pred, targ = flatten_check(inp[0].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def v_accuracy(inp, targ, axis=-1):
    "Compute accuracy of the vision head with `targ` when `pred` is bs * n_classes"
    pred, targ = flatten_check(inp[1].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def tv_accuracy(inp, targ, axis=-1):
    "Compute accuracy of the mixed head with `targ` when `pred` is bs * n_classes"
    pred, targ = flatten_check(inp[2].argmax(dim=axis), targ)
    return (pred == targ).float().mean()

def weighted_accuracy(inp, targ, axis=-1, w_t=0.333, w_v=0.333, w_tv=0.333):
    # Weighted average of the three heads' predictions
    t_inp = inp[0] * w_t
    v_inp = inp[1] * w_v
    tv_inp = inp[2] * w_tv
    # Note: the /3 is redundant when the weights already sum to 1
    # (argmax is scale-invariant), but it doesn't change the result
    inp_fin = (t_inp + v_inp + tv_inp) / 3
    pred, targ = flatten_check(inp_fin.argmax(dim=axis), targ)
    return (pred == targ).float().mean()
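And for completeness, roughly how it all plugs into a Learner, using the GradientBlending sketch from above (mixed_dls is assumed to be the combined dataloader from @muellerzr's MixedDL):

learn = Learner(mixed_dls, model,
                loss_func=GradientBlending(),
                metrics=[t_accuracy, v_accuracy, tv_accuracy, weighted_accuracy])
learn.fit_one_cycle(5)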