-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new file: dsm/datasets.py new file: dsm/dsm_api.py new file: dsm/dsm_torch.py new file: dsm/losses.py new file: dsm/utilities.py modified: METABRIC.ipynb deleted: dsm.py deleted: dsm_loss.py deleted: dsm_utilites.py
- Loading branch information
1 parent
3a9d013
commit 115b2ce
Showing
6 changed files
with
1,120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
""" | ||
Python package `dsm` provides an API to train the Deep Survival Machines | ||
and associated models for problems in survival analysis. The underlying model | ||
is implemented in `pytorch`. | ||
What is Survival Analysis? | ||
------------------------ | ||
**Survival Analysis** involves estimating when an event of interest, \( T \) | ||
would take places given some features or covariates \( X \). In statistics | ||
and ML these scenarious are modelled as regression to estimate the conditional | ||
survival distribution, \( \mathbb{P}(T>t|X) \). As compared to typical | ||
regression problems, Survival Analysis differs in two major ways: | ||
* The Event distribution, \( T \) has positive support ie. | ||
\( T \in [0, \infty) \). | ||
* There is presence of censoring ie. a large number of instances of data are | ||
lost to follow up. | ||
Deep Survival Machines | ||
---------------------- | ||
.. figure:: https://ndownloader.figshare.com/files/25259852 | ||
:figwidth: 20 % | ||
:alt: map to buried treasure | ||
This is the caption of the figure (a simple paragraph). | ||
**Deep Survival Machines (DSM)** is a fully parametric approach to model | ||
Time-to-Event outcomes in the presence of Censoring first introduced in | ||
[\[1\]](https://arxiv.org/abs/2003.01176). | ||
In the context of Healthcare ML and Biostatistics, this is known as 'Survival | ||
Analysis'. The key idea behind Deep Survival Machines is to model the | ||
underlying event outcome distribution as a mixure of some fixed \( k \) | ||
parametric distributions. The parameters of these mixture distributions as | ||
well as the mixing weights are modelled using Neural Networks. | ||
#### Usage Example | ||
>>> from dsm import DeepSurvivalMachines | ||
>>> model = DeepSurvivalMachines() | ||
>>> model.fit() | ||
>>> model.predict_risk() | ||
Deep Recurrent Survival Machines | ||
-------------------------------- | ||
**Deep Recurrent Survival Machines (DRSM)** builds on the original **DSM** | ||
model and allows for learning of representations of the input covariates using | ||
**Recurrent Neural Networks** like **LSTMs, GRUs**. Deep Recurrent Survival | ||
Machines is a natural fit to model problems where there are time dependendent | ||
covariates. | ||
..warning:: Not Implemented Yet! | ||
Deep Convolutional Survival Machines | ||
------------------------------------ | ||
Predictive maintenance and medical imaging sometimes requires to work with | ||
image streams. Deep Convolutional Survival Machines extends **DSM** and | ||
**DRSM** to learn representations of the input image data using | ||
convolutional layers. If working with streaming data, the learnt | ||
representations are then passed through an LSTM to model temporal dependencies | ||
before determining the underlying survival distributions. | ||
..warning:: Not Implemented Yet! | ||
References | ||
---------- | ||
Please cite the following papers if you are using the `dsm` package. | ||
[1] [Deep Survival Machines: | ||
Fully Parametric Survival Regression and | ||
Representation Learning for Censored Data with Competing Risks." | ||
arXiv preprint arXiv:2003.01176 (2020)](https://arxiv.org/abs/2003.01176)</a> | ||
``` | ||
@article{nagpal2020deep, | ||
title={Deep Survival Machines: Fully Parametric Survival Regression and\ | ||
Representation Learning for Censored Data with Competing Risks}, | ||
author={Nagpal, Chirag and Li, Xinyu and Dubrawski, Artur}, | ||
journal={arXiv preprint arXiv:2003.01176}, | ||
year={2020} | ||
} | ||
``` | ||
Compatibility | ||
------------- | ||
`dsm` requires `python` 3.5+ and `pytorch` 1.1+. | ||
To evaluate performance using standard metrics | ||
`dsm` requires `scikit-survival`. | ||
Contributing | ||
------------ | ||
`dsm` is [on GitHub]. Bug reports and pull requests are welcome. | ||
[on GitHub]: https://github.com/chiragnagpal/deepsurvivalmachines | ||
License | ||
------- | ||
Copyright 2020 [Chirag Nagpal](http://cs.cmu.edu/~chiragn), | ||
[Auton Lab](http://www.autonlab.org). | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
<img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/downloads/assets/images/wordmarks-600x600-min.jpg"> | ||
<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png"> | ||
<br><br><br><br><br> | ||
<br><br><br><br><br> | ||
""" | ||
|
||
from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Chirag Nagpal, Auton Lab. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Utility functions to load standard datasets to train and evaluate the | ||
Deep Survival Machines models. | ||
""" | ||
|
||
|
||
import io | ||
import pkgutil | ||
|
||
import pandas as pd | ||
import numpy as np | ||
|
||
from sklearn.impute import SimpleImputer | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
def increase_censoring(e, t, p): | ||
|
||
uncens = np.where(e == 1)[0] | ||
mask = np.random.choice([False, True], len(uncens), p=[1-p, p]) | ||
toswitch = uncens[mask] | ||
|
||
e[toswitch] = 0 | ||
t_ = t[toswitch] | ||
|
||
newt = [] | ||
for t__ in t_: | ||
newt.append(np.random.uniform(1, t__)) | ||
t[toswitch] = newt | ||
|
||
return e, t | ||
|
||
def _load_pbc_dataset(sequential): | ||
"""Helper function to load and preprocess the PBC dataset | ||
The Primary biliary cirrhosis (PBC) Dataset [1] is well known | ||
dataset for evaluating survival analysis models with time | ||
dependent covariates. | ||
Parameters | ||
---------- | ||
sequential: bool | ||
If True returns a list of np.arrays for each individual. | ||
else, returns collapsed results for each time step. To train | ||
recurrent neural models you would typically use True. | ||
References | ||
---------- | ||
[1] Fleming, Thomas R., and David P. Harrington. Counting processes and | ||
survival analysis. Vol. 169. John Wiley & Sons, 2011. | ||
""" | ||
|
||
data = pkgutil.get_data(__name__, 'datasets/pbc2.csv') | ||
data = pd.read_csv(io.BytesIO(data)) | ||
|
||
data['histologic'] = data['histologic'].astype(str) | ||
dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly', | ||
'spiders', 'edema', 'histologic']] | ||
dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline', | ||
'SGOT', 'platelets', 'prothrombin']] | ||
age = data['age'] + data['years'] | ||
|
||
x1 = pd.get_dummies(dat_cat).values | ||
x2 = dat_num.values | ||
x3 = age.values.reshape(-1, 1) | ||
x = np.hstack([x1, x2, x3]) | ||
|
||
time = (data['years'] - data['year']).values | ||
event = data['status2'].values | ||
|
||
x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x) | ||
x_ = StandardScaler().fit_transform(x) | ||
|
||
if not sequential: | ||
return x_, time, event | ||
else: | ||
x, t, e = [], [], [] | ||
for id_ in sorted(list(set(data['id']))): | ||
x.append(x_[data['id'] == id_]) | ||
t.append(time[data['id'] == id_]) | ||
e.append(event[data['id'] == id_]) | ||
return x, t, e | ||
|
||
def _load_support_dataset(): | ||
"""Helper function to load and preprocess the SUPPORT dataset. | ||
The SUPPORT Dataset comes from the Vanderbilt University study | ||
to estimate survival for seriously ill hospitalized adults [1]. | ||
Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc. | ||
for the original datasource. | ||
References | ||
---------- | ||
[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic | ||
model: Objective estimates of survival for seriously ill hospitalized | ||
adults. Annals of Internal Medicine 122:191-203. | ||
""" | ||
|
||
data = pkgutil.get_data(__name__, 'datasets/support2.csv') | ||
data = pd.read_csv(io.BytesIO(data)) | ||
x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', | ||
'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', | ||
'urine', 'adlp', 'adls']] | ||
|
||
catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca'] | ||
x2 = pd.get_dummies(data[catfeats]) | ||
|
||
x = np.concatenate([x1, x2], axis=1) | ||
t = data['d.time'].values | ||
e = data['death'].values | ||
|
||
x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x) | ||
x = StandardScaler().fit_transform(x) | ||
|
||
remove = ~np.isnan(t) | ||
return x[remove], t[remove], e[remove] | ||
|
||
|
||
def load_dataset(dataset='SUPPORT', **kwargs): | ||
"""Helper function to load datasets to test Survival Analysis models. | ||
Parameters | ||
---------- | ||
dataset: str | ||
The choice of dataset to load. Currently implemented is 'SUPPORT'. | ||
**kwargs: dict | ||
Dataset specific keyword arguments. | ||
Returns | ||
---------- | ||
tuple: (np.ndarray, np.ndarray, np.ndarray) | ||
A tuple of the form of (x, t, e) where x, t, e are the input covariates, | ||
event times and the censoring indicators respectively. | ||
""" | ||
|
||
if dataset == 'SUPPORT': | ||
return _load_support_dataset() | ||
if dataset == 'PBC': | ||
sequential = kwargs.get('sequential', False) | ||
return _load_pbc_dataset(sequential) | ||
else: | ||
return NotImplementedError('Dataset '+dataset+' not implemented.') |
Oops, something went wrong.