
modified: __init__.py
	modified:   datasets.py
	modified:   datasets_biolincc.py
	modified:   experiments.py
	modified:   explainers.py
	modified:   metrics.py
chiragnagpal committed Feb 16, 2022
1 parent e65ced8 commit 77a615a
Showing 6 changed files with 241 additions and 76 deletions.
78 changes: 73 additions & 5 deletions auton_survival/__init__.py
@@ -1,4 +1,32 @@
'''
[![Build Status](https://travis-ci.org/autonlab/DeepSurvivalMachines.svg?branch=master)](https://travis-ci.org/autonlab/DeepSurvivalMachines)
   
[![codecov](https://codecov.io/gh/autonlab/DeepSurvivalMachines/branch/master/graph/badge.svg?token=FU1HB5O92D)](https://codecov.io/gh/autonlab/DeepSurvivalMachines)
   
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
   
[![GitHub Repo stars](https://img.shields.io/github/stars/autonlab/DeepSurvivalMachines?style=social)](https://github.com/autonlab/DeepSurvivalMachines)
The Python package `auton_survival` provides a flexible API for various problems
in survival analysis, including regression, counterfactual estimation,
and phenotyping.
What is Survival Analysis?
------------------------
**Survival Analysis** involves estimating when an event of interest, \( T \),
would take place given some features or covariates \( X \). In statistics
and ML these scenarios are modelled as regression to estimate the conditional
survival distribution, \( \mathbb{P}(T>t|X) \). Compared to typical
regression problems, Survival Analysis differs in two major ways:
* The event distribution \( T \) has positive support, i.e.
\( T \in [0, \infty) \).
* Censoring is present, i.e. a large number of instances are
lost to follow-up (see the illustrative sketch below).
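Below is a small, purely illustrative sketch, not part of the `auton_survival` API, of how
censored time-to-event data is commonly represented and how the marginal survival curve
\( \mathbb{P}(T>t) \) can be estimated with `lifelines` (one of the listed requirements):
```
import pandas as pd
from lifelines import KaplanMeierFitter

# Toy censored outcomes: 'time' is the follow-up time, 'event' marks whether
# the event was observed (1) or the instance was lost to follow-up (0).
outcomes = pd.DataFrame({'time':  [5., 8., 8., 12., 15., 21.],
                         'event': [1,  0,  1,  1,   0,   1]})

# Non-parametric (Kaplan-Meier) estimate of the survival function P(T > t).
km = KaplanMeierFitter()
km.fit(outcomes['time'], event_observed=outcomes['event'])
print(km.survival_function_)
```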
# Auton Survival
Repository of reusable code utilities for Survival Analysis projects.
@@ -90,9 +118,6 @@ class is a composite transform that does both Imputing ***and*** Scaling.
print(scores)
```
## `auton_survival.reporting`
Helper functions to generate standard reports for popular Survival Analysis problems.
@@ -104,9 +129,52 @@ class is a composite transform that does both Imputing ***and*** Scaling.
foo@bar:~$ pip install -r requirements.txt
```
## Requirements
Compatibility
-------------
`dsm` requires `python` 3.5+ and `pytorch` 1.1+.
To evaluate performance using standard metrics,
`dsm` requires `scikit-survival`.
Contributing
------------
`dsm` is [on GitHub]. Bug reports and pull requests are welcome.
[on GitHub]: https://github.com/chiragnagpal/deepsurvivalmachines
License
-------
MIT License
Copyright (c) 2020 Carnegie Mellon University, [Auton Lab](http://autonlab.org)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
<img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/\
downloads/assets/images/wordmarks-600x600-min.jpg">
<img style="float: right;padding-top:50px" src="https://www.autonlab.org/\
user/themes/auton/images/AutonLogo.png">
<br><br><br><br><br>
<br><br><br><br><br>
`scikit-learn`, `scikit-survival`, `lifelines`, `matplotlib`, `pandas`, `numpy`, `missingpy`
'''

from .models.dsm import DeepSurvivalMachines
22 changes: 11 additions & 11 deletions auton_survival/datasets.py
@@ -157,6 +157,7 @@ def load_support():
to estimate survival for seriously ill hospitalized adults [1].
Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
for the original datasource.
References
----------
[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
@@ -175,8 +176,8 @@ def load_support():
outcomes = outcomes[['event', 'time']]

cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
'glucose', 'bun', 'urine', 'adlp', 'adls']

return outcomes, data[cat_feats+num_feats]
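As a point of reference (this snippet is illustrative and not part of the commit), the loader
above would typically be driven as follows, assuming the package is imported as `auton_survival`:
```
from auton_survival import datasets

# load_support() returns an outcomes DataFrame with 'event' and 'time'
# columns and a covariates DataFrame with the categorical + numeric
# features listed above.
outcomes, features = datasets.load_support()
print(outcomes.head(), features.head())
```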
@@ -255,19 +256,18 @@ def load_synthetic_cf_phenotyping():

def load_dataset(dataset='SUPPORT', **kwargs):
"""Helper function to load datasets to test Survival Analysis models.
Currently implemented datasets include:
Currently implemented datasets include:\n
**SUPPORT**: This dataset comes from the Vanderbilt University study
to estimate survival for seriously ill hospitalized adults [1].
(Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
for the original datasource.)
for the original datasource.)\n
**PBC**: The Primary biliary cirrhosis dataset [2] is a well-known
dataset for evaluating survival analysis models with time
dependent covariates.
dependent covariates.\n
**FRAMINGHAM**: This dataset is a subset of 4,434 participants of the well
known, ongoing Framingham Heart study [3] for studying the epidemiology of
hypertensive and arteriosclerotic cardiovascular disease. It is a popular
dataset for longitudinal survival analysis with time dependent covariates.
References
dataset for longitudinal survival analysis with time dependent covariates.\n
**SYNTHETIC**: This is a non-linear censored dataset for counterfactual
time-to-event phenotyping. Introduced in [4], the dataset is generated
such that the treatment effect is heterogeneous conditioned on the covariates.
@@ -276,16 +276,16 @@ def load_dataset(dataset='SUPPORT', **kwargs):
-----------
[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
model: Objective estimates of survival for seriously ill hospitalized
adults. Annals of Internal Medicine 122:191-203.
adults. Annals of Internal Medicine 122:191-203.\n
[2] Fleming, Thomas R., and David P. Harrington. Counting processes and
survival analysis. Vol. 169. John Wiley & Sons, 2011.
survival analysis. Vol. 169. John Wiley & Sons, 2011.\n
[3] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr.
"Epidemiological approaches to heart disease: the Framingham Study."
American Journal of Public Health and the Nations Health 41.3 (1951).
Parameters
American Journal of Public Health and the Nations Health 41.3 (1951).\n
[4] Nagpal, C., Goswami M., Dufendach K., and Artur Dubrawski.
"Counterfactual phenotyping for censored Time-to-Events" (2022).
Parameters
----------
dataset: str
The choice of dataset to load. Currently implemented is 'SUPPORT',
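A hypothetical usage sketch of the `load_dataset` helper documented above (not part of this
diff; the return value is assumed to mirror the individual loaders, and 'SUPPORT' is the
documented default):
```
from auton_survival import datasets

# 'SUPPORT' is the default; 'PBC', 'FRAMINGHAM' and 'SYNTHETIC' are the other
# options described in the docstring above. The return value is assumed to be
# an (outcomes, features) pair, as with load_support().
outcomes, features = datasets.load_dataset(dataset='SUPPORT')
```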
36 changes: 18 additions & 18 deletions auton_survival/datasets_biolincc.py
@@ -270,11 +270,11 @@ def load_proud(endpoint=None, features=None, location=''):

print("No Features Specified!! using default baseline features.")

categorical_features = ['RZGROUP', 'RACE', 'HISPANIC', 'ETHNIC',
'SEX', 'ESTROGEN', 'BLMEDS', 'MISTROKE',
'HXCABG', 'STDEPR', 'OASCVD', 'DIABETES',
'HDLLT35', 'LVHECG', 'WALL25', 'LCHD',
'CURSMOKE', 'ASPIRIN', 'LLT', 'RACE2',
categorical_features = ['RZGROUP', 'RACE', 'HISPANIC', 'ETHNIC',
'SEX', 'ESTROGEN', 'BLMEDS', 'MISTROKE',
'HXCABG', 'STDEPR', 'OASCVD', 'DIABETES',
'HDLLT35', 'LVHECG', 'WALL25', 'LCHD',
'CURSMOKE', 'ASPIRIN', 'LLT', 'RACE2',
'BLMEDS2', 'GEOREGN']

numeric_features = ['AGE', 'BLWGT', 'BLHGT', 'BLBMI', 'BV2SBP',
@@ -333,28 +333,28 @@ def load_sprint_pop():
def load_stich():
raise NotImplementedError()


def load_accord(endpoint=None, features=None, location=''):

# Default Baseline Features to include:
def load_accord(endpoint=None, features=None, location=''):

# Default Baseline Features to include:
if features is None:

print("No Features Specified!! using default baseline features.")

features = {

'ACCORD/3-Data Sets - Analysis/3a-Analysis Data Sets/accord_key.sas7bdat': ['female', 'baseline_age', 'arm',
'cvd_hx_baseline', 'raceclass',
'treatment'],

'ACCORD/3-Data Sets - Analysis/3a-Analysis Data Sets/bloodpressure.sas7bdat': ['sbp', 'dbp', 'hr'],

'ACCORD/4-Data Sets - CRFs/4a-CRF Data Sets/f01_inclusionexclusionsummary.sas7bdat': ['x1diab', 'x2mi',
'x2stroke', 'x2angina','cabg','ptci','cvdhist','orevasc','x2hbac11','x2hbac9','x3malb','x3lvh','x3sten','x4llmeds',
'x4gender','x4hdlf', 'x4hdlm','x4bpmeds','x4notmed','x4smoke','x4bmi'],

'ACCORD/3-Data Sets - Analysis/3a-Analysis Data Sets/lipids.sas7bdat': ['chol', 'trig', 'vldl', 'ldl', 'hdl'],

'ACCORD/3-Data Sets - Analysis/3a-Analysis Data Sets/otherlabs.sas7bdat': ['fpg', 'alt', 'cpk',
'potassium', 'screat', 'gfr',
'ualb', 'ucreat', 'uacr'],
@@ -380,20 +380,20 @@ def load_accord(endpoint=None, features=None, location=''):

outcome_tbl = 'ACCORD/3-Data Sets - Analysis/3a-Analysis Data Sets/cvdoutcomes.sas7bdat'

outcomes, features = _load_generic_biolincc_dataset(outcome_tbl=outcome_tbl,
time_col=time,
outcomes, features = _load_generic_biolincc_dataset(outcome_tbl=outcome_tbl,
time_col=time,
event_col=event,
features=features,
id_col='MaskID',
location=location,
visit_col='Visit',
location=location,
visit_col='Visit',
baseline_visit=(b'BLR', b'S01'))
outcomes['event'] = 1-outcomes['event']
outcomes['event'] = 1-outcomes['event']
outcomes['time'] = outcomes['time']

outcomes = outcomes.loc[outcomes['time']>1.0]
features = features.loc[outcomes.index]

outcomes['time'] = outcomes['time']-1
outcomes['time'] = outcomes['time']-1

return outcomes, features
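For orientation only (not part of the change set): a hypothetical call to this loader, assuming
the restricted BioLINCC ACCORD study files have been extracted under a local `./data/`
directory, which is a placeholder path:
```
from auton_survival import datasets_biolincc

# 'location' must point at the locally extracted BioLINCC files; './data/'
# is a made-up placeholder. The default endpoint and baseline features are used.
outcomes, features = datasets_biolincc.load_accord(location='./data/')
```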
25 changes: 17 additions & 8 deletions auton_survival/experiments.py
@@ -34,10 +34,10 @@ def fit(self, features, outcomes, ret_trained_model=True):

for fold in range(self.cv_folds):

fold_outcomes = outcomes.loc[folds==fold, 'time']
fold_outcomes = outcomes.loc[folds==fold, 'time']

if fold_outcomes.min() > time_min: time_min = fold_outcomes.min()
if fold_outcomes.max() < time_max: time_max = fold_outcomes.max()
if fold_outcomes.min() > time_min: time_min = fold_outcomes.min()
if fold_outcomes.max() < time_max: time_max = fold_outcomes.max()

unique_times = unique_times[unique_times>=time_min]
unique_times = unique_times[unique_times<time_max]
@@ -120,13 +120,22 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):
self.random_seed = random_seed
self.cv_folds = cv_folds

self.treated_experiment = SurvivalRegressionCV(model=model, cv_folds=cv_folds, random_seed=random_seed, hyperparam_grid=hyperparam_grid)
self.control_experiment = SurvivalRegressionCV(model=model, cv_folds=cv_folds, random_seed=random_seed, hyperparam_grid=hyperparam_grid)
self.treated_experiment = SurvivalRegressionCV(model=model,
cv_folds=cv_folds,
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)

self.control_experiment = SurvivalRegressionCV(model=model,
cv_folds=cv_folds,
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)

def fit(self, features, outcomes, interventions):

treated, control = interventions==1, interventions!=1
treated_model = self.treated_experiment.fit(features.loc[treated], outcomes.loc[treated])
control_model = self.control_experiment.fit(features.loc[control], outcomes.loc[control])
treated_model = self.treated_experiment.fit(features.loc[treated],
outcomes.loc[treated])
control_model = self.control_experiment.fit(features.loc[control],
outcomes.loc[control])

return CounterfactualSurvivalModel(treated_model, control_model)
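Not taken from the repository — a rough sketch of how this counterfactual cross-validation
wrapper might be used. The class name `CounterfactualSurvivalRegressionCV` and the `model`
identifier are assumptions for illustration, and `features`, `outcomes` and `interventions`
stand for pandas objects with a shared index, as in `fit()` above:
```
# Hypothetical driver code; see the caveats above.
from auton_survival.experiments import CounterfactualSurvivalRegressionCV

experiment = CounterfactualSurvivalRegressionCV(model='dsm',        # assumed identifier
                                                cv_folds=5,
                                                random_seed=0,
                                                hyperparam_grid={})  # empty grid, as in the default

# Splits the data by the binary 'interventions' indicator, fits one
# SurvivalRegressionCV per arm, and returns a CounterfactualSurvivalModel.
counterfactual_model = experiment.fit(features, outcomes, interventions)
```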
8 changes: 4 additions & 4 deletions auton_survival/explainers.py
@@ -5,22 +5,22 @@ class Explainer:

def __init__(self):
return


class DecisionTreeExplainer(Explainer):

def __init__(self, random_seed=0, **kwargs):

self.random_seed = random_seed
self.random_seed = random_seed
self.fitted = False
self.kwargs = kwargs

def fit(self, features, phenotype):

model = DecisionTreeClassifier(random_state=self.random_seed, **self.kwargs)
self._model = model.fit(features, phenotype)
self.fitted = True
self.feature_names = features.columns
self.feature_names = features.columns
return self

def phenotype(self, features):
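Finally, a small hypothetical example of the `DecisionTreeExplainer` touched above; the toy
covariates, phenotype labels, and the `max_depth` keyword (forwarded to scikit-learn's
`DecisionTreeClassifier` via `**kwargs`) are made up for illustration:
```
import pandas as pd
from auton_survival.explainers import DecisionTreeExplainer

# Toy covariates and phenotype assignments (e.g. output of a phenotyper).
features = pd.DataFrame({'age': [54, 61, 47, 70], 'sbp': [120, 140, 115, 155]})
phenotypes = [0, 1, 0, 1]

explainer = DecisionTreeExplainer(random_seed=0, max_depth=2)
explainer = explainer.fit(features, phenotypes)   # fit() returns self
print(explainer.feature_names)
```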