Commit
modified: ../auton_survival/__init__.py
	modified:   ../auton_survival/datasets.py
	new file:   ../auton_survival/datasets/synthetic_dataset.csv
	modified:   ../auton_survival/models/cmhe/__init__.py
	modified:   ../auton_survival/models/cmhe/cmhe_api.py
	modified:   ../auton_survival/models/cmhe/cmhe_torch.py
	modified:   ../auton_survival/models/cph/__init__.py
	modified:   ../auton_survival/models/cph/dcph_api.py
	modified:   ../auton_survival/models/dcm/__init__.py
chiragnagpal committed Feb 13, 2022
1 parent 3ce2edf commit e3339dc
Showing 11 changed files with 5,611 additions and 67 deletions.
50 changes: 34 additions & 16 deletions auton_survival/__init__.py
@@ -6,37 +6,45 @@
## `auton_survival.datasets`
Helper functions to load various trial data like `TOPCAT`, `BARI2D` and `ALLHAT`.
```python
# Load the TOPCAT Dataset
from auton_survival import datasets
features, outcomes = datasets.load_topcat()
```
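
The individual loaders are also exposed through a generic dispatcher (see the `datasets.py` diff further down); a minimal sketch, assuming the dataset name matches one of the implemented choices:

```python
from auton_survival import datasets

# Dispatch by name; dataset-specific options are passed as keyword arguments.
data = datasets.load_dataset(dataset='SUPPORT')
```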
## `auton_survival.preprocessing`
This module provides a flexible API to perform imputation and data
normalization for downstream machine learning models. The module has
3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
class is a composite transform that does both Imputing ***and*** Scaling.
```python
# Preprocessing loaded Datasets
from auton_survival import datasets
features, outcomes = datasets.load_topcat()
from auton_survival.preprocessing import Preprocessor
features = Preprocessor().fit_transform(features,
                                        cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
                                        num_feats=['height', 'weight'])
# The `cat_feats` and `num_feats` lists would contain all the categorical and numerical features in the dataset.
```
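
The two underlying transforms can also be used on their own. A minimal sketch, assuming `Imputer` and `Scaler` expose the same `fit_transform` convention as `Preprocessor` (the constructor arguments here are illustrative, not confirmed by this diff):

```python
from auton_survival.preprocessing import Imputer, Scaler

# Impute missing values first, then normalize the features.
features = Imputer().fit_transform(features,
                                   cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
                                   num_feats=['height', 'weight'])
features = Scaler().fit_transform(features)
```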
## `auton_survival.estimators`
This module provides a wrapper to model BioLINNC datasets with standard
survival (time-to-event) analysis methods.
The use of the wrapper allows a simple, standard interface for multiple
different survival models, and also helps standardize experiments across
various research areas.
Currently supported Survival Models are:

- Cox Proportional Hazards Model (`lifelines`)
- Random Survival Forests (`pysurvival`)
- Weibull Accelerated Failure Time (`lifelines`)
- Deep Survival Machines: **Not Implemented Yet**
- Deep Cox Mixtures: **Not Implemented Yet**
@@ -51,11 +59,16 @@
```
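
A minimal sketch of the wrapper interface, assuming an estimator class `SurvivalModel` with a scikit-learn style `fit`/`predict` API (the class name and arguments are assumptions, not confirmed by this diff):

```python
from auton_survival.estimators import SurvivalModel

# Fit a Cox Proportional Hazards model through the common interface.
model = SurvivalModel(model='cph')
model.fit(features, outcomes)

# Survival probabilities at fixed event horizons (here, 1 and 2 years).
predictions = model.predict_survival(features, times=[365, 730])
```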
## `auton_survival.experiments`

Modules to perform standard survival analysis experiments. This module
provides a top-level interface to run `auton_survival` style experiments
of survival analysis, involving cross-validation experiments with multiple
survival analysis models at different event-time horizons.
The module further eases evaluation by automatically computing the
*censoring adjusted* estimates of the metrics of interest, like the
**Time Dependent Concordance Index** and **Brier Score** with **IPCW** adjustment.
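
For intuition, the same censoring-adjusted estimates can be computed directly with `scikit-survival` (listed in the requirements below); a sketch, assuming `risk_scores` and `surv_probs` come from an already fitted model:

```python
import numpy as np
from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw, brier_score

# Structured (event, time) arrays; the train set supplies the censoring
# distribution used for the IPCW weights.
y_train = Surv.from_arrays(event=e_train.astype(bool), time=t_train)
y_test = Surv.from_arrays(event=e_test.astype(bool), time=t_test)

times = np.array([365., 730.])  # evaluation horizons
ctd = concordance_index_ipcw(y_train, y_test, risk_scores)[0]
_, bs = brier_score(y_train, y_test, surv_probs, times)  # surv_probs: (n, len(times))
```

The `SurvivalCVRegressionExperiment` interface below wraps this bookkeeping up: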
```python
# auton_survival style cross-validation experiment.
@@ -64,7 +77,7 @@
from auton_survival.experiments import SurvivalCVRegressionExperiment
# instantiate an auton_survival Experiment by
# specifying the features and outcomes to use.
experiment = SurvivalCVRegressionExperiment(features, outcomes)
@@ -82,16 +95,21 @@
## `auton_survival.reporting`

Helper functions to generate standard reports for popular Survival Analysis problems.
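
A minimal sketch of the kind of report these helpers standardize, built directly with `lifelines` (also in the requirements); the variable names are assumptions:

```python
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

# Kaplan-Meier estimate of the survival function for the full cohort.
kmf = KaplanMeierFitter()
kmf.fit(outcomes['time'], event_observed=outcomes['event'])
kmf.plot_survival_function()
plt.title('Kaplan-Meier Survival Estimate')
plt.show()
```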
## Installation
```console
foo@bar:~$ git clone https://github.com/autonlab/auton_survival
foo@bar:~$ pip install -r requirements.txt
```
## Requirements
`scikit-learn`, `scikit-survival`, `lifelines`, `matplotlib`, `pandas`, `numpy`, `missingpy`
'''

from .models.dsm import DeepSurvivalMachines
from .models.dcm import DeepCoxMixtures
from .models.cph import DeepCoxPH, DeepRecurrentCoxPH
from .models.cmhe import DeepCoxMixturesHeterogenousEffects
30 changes: 27 additions & 3 deletions auton_survival/datasets.py
@@ -173,7 +173,7 @@ def load_support():
outcomes['event'] = data['death']
outcomes['time'] = data['d.time']
outcomes = outcomes[['event', 'time']]

cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
@@ -240,6 +240,19 @@ def _load_mnist():

return x, t, e

def load_synthetic_cf_phenotyping():

data = pkgutil.get_data(__name__, 'datasets/synthetic_dataset.csv')
data = pd.read_csv(io.BytesIO(data))

outcomes = data[['event', 'time', 'uncensored time treated',
                 'uncensored time control', 'Z', 'Zeta']]

features = data[['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8']]
interventions = data['intervention']

return outcomes, features, interventions

def load_dataset(dataset='SUPPORT', **kwargs):
"""Helper function to load datasets to test Survival Analysis models.
Currently implemented datasets include:
@@ -254,6 +267,11 @@ def load_dataset(dataset='SUPPORT', **kwargs):
known, ongoing Framingham Heart study [3] for studying epidemiology for
hypertensive and arteriosclerotic cardiovascular disease. It is a popular
dataset for longitudinal survival analysis with time dependent covariates.
**SYNTHETIC**: This is a non-linear censored dataset for counterfactual
time-to-event phenotyping. Introduced in [4], the dataset is generated
such that the treatment effect is heterogeneous conditioned on the covariates.

References
-----------
[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
@@ -265,12 +283,16 @@
"Epidemiological approaches to heart disease: the Framingham Study."
American Journal of Public Health and the Nations Health 41.3 (1951).
[4] Nagpal, C., Goswami, M., Dufendach, K., and Artur Dubrawski.
"Counterfactual phenotyping for censored Time-to-Events" (2022).

Parameters
----------
dataset: str
The choice of dataset to load. Currently implemented are 'SUPPORT',
'PBC', 'FRAMINGHAM' and 'SYNTHETIC'.
**kwargs: dict
Dataset specific keyword arguments.
Returns
----------
tuple: (np.ndarray, np.ndarray, np.ndarray)
@@ -280,12 +302,14 @@
sequential = kwargs.get('sequential', False)

if dataset == 'SUPPORT':
return load_support()
if dataset == 'PBC':
return _load_pbc_dataset(sequential)
if dataset == 'FRAMINGHAM':
return _load_framingham_dataset(sequential)
if dataset == 'MNIST':
return _load_mnist()
if dataset == 'SYNTHETIC':
return load_synthetic_cf_phenotyping()
else:
raise NotImplementedError('Dataset '+dataset+' not implemented.')
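
A quick usage sketch for the new synthetic loader, following the return signature shown in the diff above:

```python
from auton_survival import datasets

# Load the counterfactual phenotyping dataset via the dispatcher.
outcomes, features, interventions = datasets.load_dataset(dataset='SYNTHETIC')

# `outcomes` carries observed and uncensored event times plus the latent
# phenotype columns; `interventions` is the treatment indicator.
print(features.shape, interventions.value_counts())
```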
