Ok, I did some more experimentation using the metabric dataset from the pycox example and got the loss function working (I think).
Here is the original example from pycox (link).
If you want to reproduce what I am doing, just run the pycox example and then run my code directly afterwards.
Here is the version of the loss function I have adapted from the DeepSurv paper:
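```python
import torch

def negative_log_likelihood(risk_pred, *y):
    # fastai calls loss_func(pred, *yb); with two y_names and one RegressionBlock,
    # yb holds a single (batch_size, 2) tensor of [duration, event]
    y = y[0]
    # unpack predicted log-risk, durations and event indicators
    risk_pred = risk_pred[:, 0]
    targets = y[:, 0]
    e = y[:, 1]
    # sort by DESCENDING duration so that the cumulative sum at position i
    # covers the risk set {j : T_j >= T_i} (DeepSurv sorts its data this way too)
    sort_idx = torch.argsort(targets, descending=True)
    risk_pred = risk_pred[sort_idx]
    e = e[sort_idx]
    # negative log partial likelihood; logcumsumexp equals
    # log(cumsum(exp(risk_pred))) but does not overflow for large risk values
    log_risk = torch.logcumsumexp(risk_pred, dim=0)
    uncensored_likelihood = risk_pred - log_risk
    # only subjects with an observed event contribute a term
    censored_likelihood = uncensored_likelihood * e
    # clamp avoids 0/0 in a batch without any observed events
    num_observed_events = torch.sum(e).clamp(min=1)
    neg_likelihood = -torch.sum(censored_likelihood) / num_observed_events
    return neg_likelihood
```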
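As a sanity check for the vectorized loss, here is a minimal pure-Python sketch that computes the same quantity directly from the definition: for every subject with an observed event, its log-risk minus the log-sum-exp of the log-risks of everyone still at risk at that time (all j with T_j >= T_i), averaged over the number of events. The function name `cox_neg_log_partial_likelihood` is my own, and tied durations are handled Breslow-style here (all tied subjects are in each other's risk set), which the sorted-cumsum version only reproduces when there are no ties.

```python
import math

def cox_neg_log_partial_likelihood(risk, durations, events):
    """Negative log Cox partial likelihood, averaged over observed events.

    Hypothetical reference helper. For each subject i with an observed event,
    the risk set is every j with durations[j] >= durations[i].
    """
    n_events = sum(1 for e in events if e)
    if n_events == 0:
        raise ValueError("need at least one observed event")
    total = 0.0
    for i in range(len(risk)):
        if not events[i]:
            continue  # censored subjects only appear inside other risk sets
        risk_set = [risk[j] for j in range(len(risk)) if durations[j] >= durations[i]]
        # log-sum-exp over the risk set, shifted by the max for stability
        m = max(risk_set)
        log_denominator = m + math.log(sum(math.exp(r - m) for r in risk_set))
        total += risk[i] - log_denominator
    return -total / n_events
```

Comparing this with the torch version on a small batch without tied durations is a quick way to check the sort direction: in general the two only agree when the batch is sorted by descending duration.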
The first problem was getting the durations and events into the `loss_func`. I found an older post from @sgugger, but it only got me halfway. With two target columns passed to the tabular learner, a signature like `loss_func(x, durations, events)` did not work: as far as I can tell, fastai calls the loss as `loss_func(pred, *yb)`, and both targets arrive together as a single tensor of shape `(batch_size, 2)`. Accepting `loss_func(risk_pred, *y)` and unpacking the columns inside the function did work.
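To illustrate why the three-argument signature fails, here is a minimal pure-Python sketch of that calling convention. `fake_fastai_loss_call`, `loss_separate_args` and `loss_star_args` are hypothetical names for illustration only, and plain nested lists stand in for the `(batch_size, 2)` target tensor.

```python
def fake_fastai_loss_call(loss_func, pred, yb):
    # hypothetical stand-in for fastai's Learner, which calls loss_func(pred, *yb);
    # yb is the tuple of targets, here holding ONE (batch_size, 2) table
    return loss_func(pred, *yb)


def loss_separate_args(pred, durations, events):
    # expects durations and events as two separate targets
    return durations, events


def loss_star_args(pred, *y):
    # y is a tuple; its only element is the [duration, event] table
    targets = y[0]
    durations = [row[0] for row in targets]
    events = [row[1] for row in targets]
    return durations, events
```

With `yb = ([[5.0, 1.0], [3.0, 0.0]],)`, `loss_star_args` receives the whole table and can split it, while `loss_separate_args` raises a `TypeError` because only one target is passed and `events` is missing.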
This is how I set up the TabularLearner:
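```python
from fastai.tabular.all import *

to = TabularDataLoaders.from_df(
    # ... DataFrame, column names and remaining arguments not shown here ...
    y_block=RegressionBlock(),
)
# n_out=1: a single log-risk output; otherwise fastai infers one output per target column
learn = tabular_learner(to, n_out=1, lr=1e-03, loss_func=negative_log_likelihood)
```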
While researching existing solutions, I also found some tweets from @jeremy from early 2020 about survival models in the context of COVID-19. Maybe he has done some work on this that I have missed?
On my todo list now is:
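- Get a concordance-index metric like the one in lifelines (`from lifelines.utils import concordance_index`).
- Figure out whether I still need L2 regularization. The paper uses it, but if I understand correctly, fastai already applies weight decay by default with Adam (`wd=0.01`, decoupled as in AdamW, which is not exactly the same as adding an L2 penalty to the loss). Is an extra L2 term still needed?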
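For the concordance-index item, here is a minimal pure-Python sketch of Harrell's C for risk scores (higher score = higher predicted risk, as produced by the loss above). `harrell_c_index` is a hypothetical name, and pairs with tied durations are simply skipped, so results can differ slightly from lifelines when durations are tied.

```python
from itertools import combinations

def harrell_c_index(durations, events, risk_scores):
    """Harrell's concordance index where a higher score means higher risk.

    A pair is comparable when the subject with the shorter duration had an
    observed event; it is concordant when that subject also has the higher
    risk score. Tied risk scores count as half a concordant pair.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(durations)), 2):
        if durations[i] == durations[j]:
            continue  # simplification: pairs with tied durations are skipped
        # a = subject with the shorter duration, b = the other one
        a, b = (i, j) if durations[i] < durations[j] else (j, i)
        if not events[a]:
            continue  # shorter duration was censored: pair is not comparable
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

For example, `harrell_c_index([2, 4, 6, 8], [1, 1, 0, 1], [0.9, 0.5, 0.7, 0.1])` finds 5 comparable pairs, 4 of them concordant, so it returns 0.8. Note that lifelines' `concordance_index(event_times, predicted_scores, event_observed)` treats higher scores as longer survival, so risk scores from this model would need to be negated there, e.g. `concordance_index(durations, -risk, events)`. Wrapping either version as a fastai metric (for instance with `AccumMetric` and `flatten=False`, so the two target columns stay separate) is an untested assumption.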