Survival analysis in fastai

Hi all,

I have been working on some deep learning models using torch, with pycox as a wrapper to compute the Cox PH loss so that right-censored clinical data is handled properly.

The trivial example is:
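import torch
import torchtuples as tt
from pycox.models import CoxPH

# in_features, num_nodes, out_features, batch_norm, dropout and output_bias
# are assumed to be defined beforehand
net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)
model = CoxPH(net, tt.optim.Adam)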

I would like to use this Cox PH loss in the fastai framework together with the existing TabularData functionality. Does anyone have experience implementing survival analysis in fastai? It would also make a lot of sense for this to exist in the medical branch of fastai.

There are a couple of deep learning models for survival analysis that implement this kind of functionality, for example DeepSurv, which was built with Theano/Lasagne.

Thanks a lot and best regards

Ok, I did some more experimentation using the metabric dataset from the pycox example and got the loss function working (I think).

Here is the original example from pycox link
If you want to reproduce what I am doing, just run the pycox example and then my code parts directly afterwards.

Here is the version of the loss function I have adapted from the DeepSurv paper:

def negative_log_likelihood(risk_pred, *y):
    """Negative log Cox partial likelihood, averaged over observed events."""
    # Unpack the predicted log-risk, durations and event indicators
    risk_pred = risk_pred[:, 0]
    targets = y[0][:, 0]
    e = y[0][:, 1]

    # Sort by descending duration so that the cumulative sum at position i
    # covers the risk set of i (everyone with a duration >= that of i)
    sort_idx = torch.argsort(targets, descending=True)
    risk_pred = risk_pred[sort_idx]
    e = e[sort_idx]
    targets = targets[sort_idx]

    # Negative log partial likelihood; censored rows (e == 0) contribute nothing
    hazard_ratio = torch.exp(risk_pred)
    log_risk = torch.log(torch.cumsum(hazard_ratio, dim=0))
    uncensored_likelihood = risk_pred - log_risk
    censored_likelihood = uncensored_likelihood * e
    num_observed_events = torch.sum(e)
    neg_likelihood = -torch.sum(censored_likelihood) / num_observed_events
    return neg_likelihood
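
To sanity-check a loss like this, it can help to compare it against a slow, dependency-free reference on a tiny batch. The sketch below is only an illustration, not the pycox or DeepSurv implementation: the function names and the toy numbers are made up, and it assumes no tied durations. The brute-force version builds each risk set explicitly (everyone whose duration is at least as long as the subject's), and the second version gets the same value by sorting durations in descending order and keeping a running sum.

```python
import math

def cox_nll_bruteforce(risk, durations, events):
    """Negative log partial likelihood, averaged over observed events (hypothetical helper)."""
    total, n_events = 0.0, 0
    for i in range(len(risk)):
        if not events[i]:
            continue  # censored subjects only appear inside other subjects' risk sets
        # Risk set of i: every subject still under observation at durations[i]
        risk_set_sum = sum(math.exp(r) for r, t in zip(risk, durations) if t >= durations[i])
        total += risk[i] - math.log(risk_set_sum)
        n_events += 1
    return -total / n_events

def cox_nll_sorted(risk, durations, events):
    """Same quantity via a descending sort and a running (cumulative) sum."""
    order = sorted(range(len(risk)), key=lambda k: durations[k], reverse=True)
    running, total, n_events = 0.0, 0.0, 0
    for k in order:
        running += math.exp(risk[k])  # running sum now covers the risk set of k
        if events[k]:
            total += risk[k] - math.log(running)
            n_events += 1
    return -total / n_events

# Toy batch (made-up values): predicted log-risk, duration, event indicator
risk = [0.5, -0.2, 1.0, 0.0]
durations = [5.0, 3.0, 8.0, 1.0]
events = [1, 0, 1, 1]

loss_bruteforce = cox_nll_bruteforce(risk, durations, events)
loss_sorted = cox_nll_sorted(risk, durations, events)
print(loss_bruteforce, loss_sorted)
```

Feeding the same toy batch to the torch loss, as a 2D tensor of predictions and a targets tensor with duration and event columns, should give the same number if the sort direction and the risk sets are right.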

The first problem was getting the duration and event into the loss_func. I found an older post from @sgugger, but that only did half the trick for me. Providing two outcomes to the tabular learner did not allow unpacking them in the signature, as in loss_func(x, durations, events), but loss_func(risk_pred, *y) with the unpacking done inside the function did work.

This is how I set up the TabularLearner:

from fastai.tabular.all import *

to = TabularDataLoaders.from_df(
    df_train,
    procs=[Normalize],
    cont_names=cont_vars,
    y_names=["duration", "event"],
    y_block = RegressionBlock(),
    valid_idx=list(range(800,1218))
)

learn = tabular_learner(to, lr=1e-03, loss_func=negative_log_likelihood)

learn.lr_find()
learn.fit_one_cycle(100)

While searching for existing solutions, I also found some tweets from @jeremy from early 2020 about survival models in the context of COVID-19. Maybe he has done some work on this that I have missed?

On my to-do list now:

  1. Get a concordance-index metric like the one from lifelines (from lifelines.utils import concordance_index).
  2. Figure out whether I still need to add L2 regularization. It is done in the paper, but if I understood correctly, fastai does this automatically when using Adam?
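
For the first item, here is a minimal pure-Python sketch of Harrell's concordance index, mainly to make explicit what the metric counts. The function name harrell_c_index is hypothetical and this is not the lifelines implementation. It assumes that a higher predicted risk should go with a shorter duration, and it only compares pairs where the subject with the shorter duration had an observed event.

```python
def harrell_c_index(durations, risk, events):
    """Fraction of comparable pairs ordered correctly by predicted risk (hypothetical helper)."""
    concordant, tied, comparable = 0, 0, 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot be the "earlier" member of a pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1  # earlier event got the higher risk score
                elif risk[i] == risk[j]:
                    tied += 1  # ties in predicted risk count as half
    if comparable == 0:
        raise ValueError("no comparable pairs in this batch")
    return (concordant + 0.5 * tied) / comparable

# Toy data (made-up values)
durations = [1.0, 2.0, 3.0, 4.0]
events = [1, 1, 0, 1]
c_good = harrell_c_index(durations, [4.0, 3.0, 2.0, 1.0], events)  # risk falls as duration grows
c_bad = harrell_c_index(durations, [1.0, 2.0, 3.0, 4.0], events)   # risk rises with duration
print(c_good, c_bad)
```

As far as I understand, the lifelines function treats higher scores as longer survival, so a risk score would have to be negated before being passed in. That is worth checking against the lifelines docs before wiring it up as a fastai metric.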