new file: dsm/__init__.py
	new file:   dsm/datasets.py
	new file:   dsm/dsm_api.py
	new file:   dsm/dsm_torch.py
	new file:   dsm/losses.py
	new file:   dsm/utilities.py
	modified:   METABRIC.ipynb
	deleted:    dsm.py
	deleted:    dsm_loss.py
	deleted:    dsm_utilites.py
chiragnagpal committed Oct 26, 2020
1 parent 3a9d013 commit 115b2ce
Showing 6 changed files with 1,120 additions and 0 deletions.
128 changes: 128 additions & 0 deletions dsm/__init__.py
@@ -0,0 +1,128 @@
"""
Python package `dsm` provides an API to train the Deep Survival Machines
and associated models for problems in survival analysis. The underlying model
is implemented in `pytorch`.
What is Survival Analysis?
--------------------------
**Survival Analysis** involves estimating when an event of interest, \( T \),
will take place given some features or covariates \( X \). In statistics
and ML, these scenarios are modelled as regression problems that estimate the
conditional survival distribution, \( \mathbb{P}(T>t|X) \). Compared to
typical regression problems, Survival Analysis differs in two major ways:
* The event distribution \( T \) has positive support, i.e.
\( T \in [0, \infty) \).
* Censoring is present, i.e. a large number of instances are lost to
follow-up (for example, a patient may leave a clinical study before the
event of interest is observed).
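Under right censoring, an instance whose event is observed at time \( t \)
contributes the density \( f(t|X) \) to the likelihood, whereas a censored
instance is only known to have survived up to \( t \) and contributes the
survival probability \( \mathbb{P}(T>t|X) \).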
Deep Survival Machines
----------------------
.. figure:: https://ndownloader.figshare.com/files/25259852
   :figwidth: 20 %
   :alt: Schematic of Deep Survival Machines

   Schematic of the Deep Survival Machines model.
**Deep Survival Machines (DSM)** is a fully parametric approach for modelling
time-to-event outcomes in the presence of censoring, first introduced in
[\[1\]](https://arxiv.org/abs/2003.01176).
In the context of Healthcare ML and Biostatistics, this problem is known as
'Survival Analysis'. The key idea behind Deep Survival Machines is to model
the underlying event outcome distribution as a mixture of a fixed number
\( k \) of parametric distributions. The parameters of these mixture
distributions, as well as the mixing weights, are modelled using Neural
Networks.
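Schematically, the conditional survival function takes the mixture form
\( \mathbb{P}(T>t|X) = \sum_{j=1}^{k} w_j(X) \, \mathbb{P}_j(T>t|X) \),
where the mixing weights \( w_j(X) \) and the parameters of each primitive
distribution are outputs of a neural network acting on the covariates.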
#### Usage Example
>>> from dsm import DeepSurvivalMachines
>>> model = DeepSurvivalMachines()
>>> model.fit(x, t, e)
>>> model.predict_risk(x, t)
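A fuller end-to-end sketch is shown below, where `x` are the covariates,
`t` the event times and `e` the event indicators. The constructor and `fit`
arguments here are illustrative rather than definitive; consult
`dsm.dsm_api` for the exact signatures.
>>> from dsm import datasets
>>> x, t, e = datasets.load_dataset('SUPPORT')
>>> model = DeepSurvivalMachines(k=3, distribution='Weibull')
>>> model.fit(x, t, e, iters=100)
>>> risk = model.predict_risk(x, 365)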
Deep Recurrent Survival Machines
--------------------------------
**Deep Recurrent Survival Machines (DRSM)** builds on the original **DSM**
model and allows representations of the input covariates to be learned using
**Recurrent Neural Networks** such as **LSTMs and GRUs**. Deep Recurrent
Survival Machines is a natural fit for modelling problems with
time-dependent covariates.
.. warning:: Not Implemented Yet!
Deep Convolutional Survival Machines
------------------------------------
Predictive maintenance and medical imaging sometimes require working with
image streams. Deep Convolutional Survival Machines extends **DSM** and
**DRSM** to learn representations of the input image data using
convolutional layers. When working with streaming data, the learnt
representations are then passed through an LSTM to model temporal
dependencies before the underlying survival distributions are determined.
.. warning:: Not Implemented Yet!
References
----------
Please cite the following papers if you are using the `dsm` package.
[1] [Deep Survival Machines:
Fully Parametric Survival Regression and
Representation Learning for Censored Data with Competing Risks.
arXiv preprint arXiv:2003.01176 (2020)](https://arxiv.org/abs/2003.01176)
```
@article{nagpal2020deep,
  title={Deep Survival Machines: Fully Parametric Survival Regression and
         Representation Learning for Censored Data with Competing Risks},
  author={Nagpal, Chirag and Li, Xinyu and Dubrawski, Artur},
  journal={arXiv preprint arXiv:2003.01176},
  year={2020}
}
```
Compatibility
-------------
`dsm` requires `python` 3.5+ and `pytorch` 1.1+.
To evaluate performance using standard metrics,
`dsm` requires `scikit-survival`.
Contributing
------------
`dsm` is [on GitHub]. Bug reports and pull requests are welcome.
[on GitHub]: https://github.com/chiragnagpal/deepsurvivalmachines
License
-------
Copyright 2020 [Chirag Nagpal](http://cs.cmu.edu/~chiragn),
[Auton Lab](http://www.autonlab.org).
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
<img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/downloads/assets/images/wordmarks-600x600-min.jpg">
<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">
<br><br><br><br><br>
<br><br><br><br><br>
"""

from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines
160 changes: 160 additions & 0 deletions dsm/datasets.py
@@ -0,0 +1,160 @@
# coding=utf-8
# Copyright 2020 Chirag Nagpal, Auton Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions to load standard datasets to train and evaluate the
Deep Survival Machines models.
"""


import io
import pkgutil

import pandas as pd
import numpy as np

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

def increase_censoring(e, t, p):
  """Artificially censors a random fraction `p` of the uncensored instances
  in the event indicators `e` and event times `t`. Each newly censored
  instance is assigned a censoring time drawn uniformly between 1 and its
  original event time. Note that `e` and `t` are modified in place.
  """

  uncens = np.where(e == 1)[0]
  mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
  toswitch = uncens[mask]

  e[toswitch] = 0
  t_ = t[toswitch]

  newt = []
  for t__ in t_:
    newt.append(np.random.uniform(1, t__))
  t[toswitch] = newt

  return e, t
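# A minimal usage sketch for `increase_censoring` (illustrative only: the
# synthetic cohort and the censoring fraction p=0.3 below are arbitrary
# choices, not part of the package). Left as a comment so that importing
# this module stays side-effect free.
#
#   e = np.ones(100)                     # every event initially observed
#   t = np.random.exponential(10., 100)  # synthetic event times
#   e, t = increase_censoring(e, t, p=0.3)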

def _load_pbc_dataset(sequential):
  """Helper function to load and preprocess the PBC dataset.

  The Primary Biliary Cirrhosis (PBC) dataset [1] is a well-known
  dataset for evaluating survival analysis models with time-dependent
  covariates.

  Parameters
  ----------
  sequential: bool
    If True, returns a list of np.arrays for each individual.
    Otherwise, returns collapsed results for each time step. To train
    recurrent neural models you would typically use True.

  References
  ----------
  [1] Fleming, Thomas R., and David P. Harrington. Counting processes and
  survival analysis. Vol. 169. John Wiley & Sons, 2011.
  """

  data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
  data = pd.read_csv(io.BytesIO(data))

  data['histologic'] = data['histologic'].astype(str)
  dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
                  'spiders', 'edema', 'histologic']]
  dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline',
                  'SGOT', 'platelets', 'prothrombin']]
  age = data['age'] + data['years']

  x1 = pd.get_dummies(dat_cat).values
  x2 = dat_num.values
  x3 = age.values.reshape(-1, 1)
  x = np.hstack([x1, x2, x3])

  time = (data['years'] - data['year']).values
  event = data['status2'].values

  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
  x_ = StandardScaler().fit_transform(x)

  if not sequential:
    return x_, time, event
  else:
    x, t, e = [], [], []
    for id_ in sorted(list(set(data['id']))):
      x.append(x_[data['id'] == id_])
      t.append(time[data['id'] == id_])
      e.append(event[data['id'] == id_])
    return x, t, e

def _load_support_dataset():
  """Helper function to load and preprocess the SUPPORT dataset.

  The SUPPORT dataset comes from the Vanderbilt University study
  to estimate survival for seriously ill hospitalized adults [1].
  Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
  for the original data source.

  References
  ----------
  [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
  model: Objective estimates of survival for seriously ill hospitalized
  adults. Annals of Internal Medicine 122:191-203.
  """

  data = pkgutil.get_data(__name__, 'datasets/support2.csv')
  data = pd.read_csv(io.BytesIO(data))
  x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
             'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
             'urine', 'adlp', 'adls']]

  catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
  x2 = pd.get_dummies(data[catfeats])

  x = np.concatenate([x1, x2], axis=1)
  t = data['d.time'].values
  e = data['death'].values

  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
  x = StandardScaler().fit_transform(x)

  # Drop instances with missing event times.
  remove = ~np.isnan(t)
  return x[remove], t[remove], e[remove]


def load_dataset(dataset='SUPPORT', **kwargs):
  """Helper function to load datasets to test Survival Analysis models.

  Parameters
  ----------
  dataset: str
    The choice of dataset to load. Currently implemented are 'SUPPORT'
    and 'PBC'.
  **kwargs: dict
    Dataset specific keyword arguments, e.g. `sequential` for 'PBC'.

  Returns
  -------
  tuple: (np.ndarray, np.ndarray, np.ndarray)
    A tuple of the form (x, t, e), where x, t and e are the input
    covariates, event times and censoring indicators respectively.
  """

  if dataset == 'SUPPORT':
    return _load_support_dataset()
  elif dataset == 'PBC':
    sequential = kwargs.get('sequential', False)
    return _load_pbc_dataset(sequential)
  else:
    raise NotImplementedError('Dataset '+dataset+' not implemented.')
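# Example usage (illustrative):
#
#   x, t, e = load_dataset('SUPPORT')
#   x, t, e = load_dataset('PBC', sequential=True)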