Deep Learning for Survival Prediction

Andy Reagan, Sara Saperstein, Jasmine Geng, Bisakha Peskin, and Xiangdong Gu
Saturday 08 September 2018

Introduction

Survival analysis is fundamental to the work we do here at MassMutual, underlying our algorithmic underwriting and LifeScore360 products. As a team we like to stay on the cutting edge of technologies and are always looking out for new ways to improve our models as well as learn new techniques. In February of this year a new paper was published demonstrating the first application of modern deep learning techniques to survival analysis. Katzman et al. achieved good results, outperforming Cox proportional hazards in most cases and even outperforming random survival forest in some cases with their new software, DeepSurv. We were curious how well DeepSurv would perform predicting mortality risk from lab data.

Survival modeling is different from typical machine learning approaches. Usually in supervised learning you have an array of features x and a single outcome variable y for each instance. In survival analysis there are two outcome variables: an event indicator E that records whether the event, such as a death or the failure of a part, occurred, and a time T to either the event or the end of the study. If the event hasn't yet occurred for an instance by the time the study period ends, that observation is said to be right censored. That is, the true death or failure time is obscured, but we know it falls some time after the censoring time. The goal is then to predict the cumulative hazard function (CHF), equivalent to the negative log of the survival function. The survival function S(t) = P(T > t) is the probability that the event occurs after time t. Typically a nonparametric estimator such as the Kaplan-Meier or Nelson-Aalen estimator is used to compute the survival function or cumulative hazard function, respectively.
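
As a concrete illustration, the lifelines package computes both estimators in a few lines; the right-censored data below are synthetic, purely for the example.

```python
import numpy as np
from lifelines import KaplanMeierFitter, NelsonAalenFitter

rng = np.random.default_rng(0)

# Synthetic right-censored data: a true event time and an administrative
# censoring time; we observe the minimum of the two plus an event indicator.
true_time = rng.exponential(10, size=500)
censor_time = rng.uniform(0, 20, size=500)
observed_time = np.minimum(true_time, censor_time)
event = (true_time <= censor_time).astype(int)  # 1 = event observed, 0 = censored

# Kaplan-Meier estimate of the survival function S(t)
kmf = KaplanMeierFitter()
kmf.fit(observed_time, event_observed=event)

# Nelson-Aalen estimate of the cumulative hazard function
naf = NelsonAalenFitter()
naf.fit(observed_time, event_observed=event)

print(kmf.survival_function_.tail())
print(naf.cumulative_hazard_.tail())
```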


Figure 1: There are many approaches to survival modeling. For a review, see Machine Learning for Survival Analysis by Wang et al. (2017) [2].


A standard method for modeling survival is the Cox proportional hazards model (CPH). CPH is a semiparametric model, meaning it does not require the underlying distribution of survival times to be specified, and it is popular in part for this reason. It estimates the effects of covariates on risk assuming they contribute to the log-risk linearly. However, many covariates interact in nonlinear ways, and those interactions are missed by CPH. Nonlinear approaches such as deep learning could be powerful ways to model risk when linear methods don't suffice. To our knowledge, DeepSurv is the first modern deep learning approach to survival analysis, building off the Faraggi-Simon method, which used a single hidden layer but did not outperform Cox proportional hazards.
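
For reference, fitting a CPH model in Python takes only a few lines with lifelines; the dataframe, covariate names, and censoring scheme below are synthetic placeholders.

```python
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

rng = np.random.default_rng(1)
n = 1000

# Synthetic covariates whose linear effect on the log-risk CPH will estimate
df = pd.DataFrame({"age": rng.normal(50, 10, n), "bmi": rng.normal(27, 4, n)})
log_risk = 0.03 * df["age"] + 0.05 * df["bmi"]
df["duration"] = rng.exponential(1.0 / (0.01 * np.exp(log_risk)))
df["event"] = (df["duration"] < 30).astype(int)   # long survivors are censored
df.loc[df["event"] == 0, "duration"] = 30         # administrative censoring at t = 30

cph = CoxPHFitter()
cph.fit(df, duration_col="duration", event_col="event")
cph.print_summary()  # estimated hazard ratios for each covariate
```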

Another popular method for modeling survival is the random survival forest (RSF). RSFs draw bootstrap samples from the original data and grow a separate survival tree for each sample. At each node, feature bagging is performed, and the node is split on the variable that maximizes the survival difference between the resulting daughter nodes. Each tree is grown to full size under the constraint that every terminal node contains no fewer than a prespecified number of deaths. A CHF is computed for each tree and then averaged to obtain the ensemble CHF. The out-of-bag samples are then used to compute the prediction error of the ensemble CHF.
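
As an illustration (not how our production RSF model is actually built), scikit-survival's RandomSurvivalForest fits this kind of ensemble in Python; the data and hyperparameter values below are synthetic placeholders.

```python
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

rng = np.random.default_rng(2)
n, p = 1000, 10
X = rng.normal(size=(n, p))

# Synthetic survival outcome: risk driven by the first two covariates
risk = np.exp(X[:, 0] + 0.5 * X[:, 1])
time = rng.exponential(1.0 / (0.05 * risk))
censor = rng.uniform(0, 40, size=n)
event = time <= censor
y = Surv.from_arrays(event=event, time=np.minimum(time, censor))

rsf = RandomSurvivalForest(
    n_estimators=200,     # number of bootstrapped survival trees
    min_samples_leaf=15,  # scikit-survival's stopping rule counts samples, not deaths
    max_features="sqrt",  # feature bagging at each node
    n_jobs=-1,
    random_state=0,
)
rsf.fit(X[:800], y[:800])

# Ensemble CHF for new observations and c-index on held-out data
chf = rsf.predict_cumulative_hazard_function(X[800:])
print("c-index:", rsf.score(X[800:], y[800:]))
```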

Approach

We applied the existing RSF and DeepSurv pipelines to simulated datasets as well as to MassMutual's underwriting data, which cover over 958,000 individuals who purchased policies from MassMutual between 1999 and 2014. These underwriting data include 54 features drawn from the insurance application, labwork, and motor vehicle records. Our goals were to learn DeepSurv and to compare model performance to assess whether DeepSurv improved our mortality modeling.

DeepSurv is a deep feed-forward network. The loss function used in DeepSurv is the average negative Cox partial log-likelihood, l(θ) = −(1/N_{E=1}) Σ_{i: Eᵢ=1} [ ĥθ(xᵢ) − log Σ_{j ∈ R(Tᵢ)} exp(ĥθ(xⱼ)) ], where Eᵢ and xᵢ are the respective event indicator and baseline data for the i-th observation, N_{E=1} is the number of observed events, and R(Tᵢ) is the set of individuals still at risk at time Tᵢ. We performed a thorough hyperparameter optimization using grid search and experimented with several state-of-the-art techniques such as rectified linear units, dropout, learning rate scheduling, and gradient descent optimization algorithms, as described in [1]. For our purposes, we did not implement the treatment recommender system suggested in [1]. We compared the accuracy of the risk predicted by the CPH, RSF, and DeepSurv models using the concordance index (c-index).
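
A minimal numpy sketch of this loss is below; it assumes no tied event times and omits the regularization terms used in the reference implementation.

```python
import numpy as np

def neg_cox_partial_log_likelihood(log_risk, time, event):
    """Average negative Cox partial log-likelihood (assumes no tied event times).

    log_risk : predicted log-risk for each observation
    time     : observed time (event or censoring time)
    event    : 1 if the event was observed, 0 if right censored
    """
    # Sort by decreasing time so the risk set of observation i is
    # everything at or before position i in the sorted order.
    order = np.argsort(-time)
    log_risk, event = log_risk[order], event[order]

    # log of the running sum of exp(log_risk) over each risk set
    log_risk_set = np.logaddexp.accumulate(log_risk)

    observed = event == 1
    return -np.mean(log_risk[observed] - log_risk_set[observed])
```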


Figure 2: Diagram of DeepSurv. "DeepSurv is a configurable feed-forward deep neural network. The input to the network is the baseline data x. The network propagates the inputs through a number of hidden layers with weights θ. The hidden layers consist of fully-connected nonlinear activation functions followed by dropout. The final layer is a single node which performs a linear combination of the hidden features. The output of the network is taken as the predicted log-risk function ĥθ(x)." [1]
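
This architecture can be sketched in a few lines of PyTorch (the reference implementation in [1] uses Theano/Lasagne); the layer widths and dropout rate below are placeholders rather than the tuned values reported in Table 2.

```python
import torch.nn as nn

def deepsurv_net(n_features, hidden=(32, 32), dropout=0.2):
    """Illustrative DeepSurv-style network: stacked fully-connected layers
    with nonlinear activations and dropout, ending in a single linear node."""
    layers, n_in = [], n_features
    for n_out in hidden:
        layers += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(dropout)]
        n_in = n_out
    layers.append(nn.Linear(n_in, 1, bias=False))  # single node: predicted log-risk
    return nn.Sequential(*layers)

net = deepsurv_net(n_features=10)
```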


Data

The simulated datasets are described in detail in the DeepSurv paper [1]. In the experiments described in [1], the risk of failure depends on only two of the covariates: the linear experiments use a linear risk function of those two covariates, while the non-linear experiments use a Gaussian risk function with a maximum hazard λ_max and a scale factor r. The training, validation, and test sets contain 4,000, 1,000, and 1,000 samples, respectively. Each sample is composed of ten covariates drawn from a uniform distribution, and the death time is generated from an exponential Cox model with hazard λ(t; x) = λ₀ · exp(h(x)).
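
A rough numpy sketch of this simulation for the linear case is below; the coefficients of the risk function, the uniform range, the baseline hazard λ₀, and the censoring scheme are illustrative placeholders rather than the paper's exact choices.

```python
import numpy as np

rng = np.random.default_rng(3)

n, p = 6000, 10
x = rng.uniform(-1, 1, size=(n, p))

# Risk depends only on the first two covariates (coefficients are illustrative)
h_linear = x[:, 0] + 2 * x[:, 1]

# Death time T ~ Exp(lambda_0 * exp(h(x))); lambda_0 = 0.1 is a placeholder
death_time = rng.exponential(1.0 / (0.1 * np.exp(h_linear)))
censor_time = rng.uniform(0, np.percentile(death_time, 90), size=n)
time = np.minimum(death_time, censor_time)
event = (death_time <= censor_time).astype(int)

# 4000 / 1000 / 1000 train / validation / test split
train_idx, valid_idx, test_idx = np.split(np.arange(n), [4000, 5000])
```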

For the MM data, we utilize the fully preprocessed dataset that is used to train our RSF production model, M3S. The dataset consists of 54 covariates; we load the CSV into Python and feed it directly into DeepSurv.
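
Below is a sketch of how that hand-off can look; the file name and the "duration"/"event" column names are placeholders, and the dictionary-of-arrays format with keys 'x', 't', and 'e' reflects our reading of the DeepSurv repository's examples rather than a documented API.

```python
import numpy as np
import pandas as pd

# Placeholder file and column names for the preprocessed M3S dataset
df = pd.read_csv("m3s_preprocessed.csv")
covariates = [c for c in df.columns if c not in ("duration", "event")]

# Assumed DeepSurv input format: a dict of numpy arrays per split
# ('x' = covariates, 't' = observed time, 'e' = event indicator)
train_data = {
    "x": df[covariates].to_numpy(dtype=np.float32),
    "t": df["duration"].to_numpy(dtype=np.float32),
    "e": df["event"].to_numpy(dtype=np.int32),
}
```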

Results

DeepSurv performed better than CPH on the non-linear simulated data but not on the linear simulated data. On the MassMutual data, after some hyperparameter tuning, we saw a c-index of 0.76 for DeepSurv, compared with 0.83 for RSF and 0.81 for CPH.
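
As a reminder, the c-index measures how well the ranking of predicted risks agrees with the observed outcomes. The sketch below shows one way to compute it with lifelines on placeholder predictions; lifelines expects scores that increase with expected survival, so the predicted log-risk is negated.

```python
import numpy as np
from lifelines.utils import concordance_index

rng = np.random.default_rng(4)

# Placeholder test outcomes and model predictions
test_time = rng.exponential(10, size=200)
test_event = rng.integers(0, 2, size=200)
predicted_log_risk = -0.1 * test_time + rng.normal(0, 1, size=200)

# Higher c-index = better concordance between predicted risk and outcome
cindex = concordance_index(test_time, -predicted_log_risk, event_observed=test_event)
print(round(cindex, 3))
```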

Table 1a shows the performance of the different methods on a simple linear function of 10 variables, a Gaussian non-linear function, and the MM data. Table 1b can be used for comparison with the results reported in the paper. We used the parameters specified in the paper, shown in Table 2, with a total of 10,000 training iterations for the linear function and 1,000 training iterations for the non-linear function.

Table 1a. Datasets and C-Index Results


Table 1b. Comparison with the values reported in the DeepSurv paper for simulated linear and nonlinear data.


Table 2. Hyperparameters used in the DeepSurv experiments. ReLU: rectified linear unit; SELU: scaled exponential linear unit.


Conclusion

With moderate hyperparameter tuning, DeepSurv did not outperform the random survival forest. With 54 features, we had more features than some of the datasets used in the DeepSurv paper. However, deep learning is typically reserved for more complicated data such as images or text, which have deeply nonlinear interactions and would otherwise require extensive feature engineering. Our reported results could likely be improved with more hyperparameter tuning, but that tuning is time intensive. The MM data, comprising application data, lab values, driving records, and the like, do not appear to have such deeply complicated interactions. Deep learning networks might be worth revisiting for data with more complicated interactions, such as multiple interacting genes in genomics data.

References

  1. Katzman, Jared L., Shaham, Uri, Cloninger, Alexander, Bates, Jonathan, Jiang, Tingting, and Kluger, Yuval. 2018. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology 18:24.

  2. Wang, Ping, Li, Yan, and Reddy, Chandan K. 2017. Machine learning for survival analysis: A survey. arXiv preprint arXiv:1708.04649.