diff --git a/auton_survival/__init__.py b/auton_survival/__init__.py
index 84c4a74..49fce29 100644
--- a/auton_survival/__init__.py
+++ b/auton_survival/__init__.py
@@ -27,41 +27,13 @@
* There is presence of censoring ie. a large number of instances of data are
lost to follow up.
-Auton Survival
-----------------
+The Auton Survival Package
+---------------------------
-Repository of reusable code utilities for Survival Analysis projects.
-
-Dataset Loading and Preprocessing
----------------------------------
-
-Helper functions to load various trial data like `TOPCAT`, `BARI2D` and `ALLHAT`.
-
-```python
-# Load the TOPCAT Dataset
-from auton_survival import dataset
-features, outcomes = datasets.load_topcat()
-```
-
-### `auton_survival.preprocessing`
-This module provides a flexible API to perform imputation and data
-normalization for downstream machine learning models. The module has
-3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
-class is a composite transform that does both Imputing ***and*** Scaling.
-
-```python
-# Preprocessing loaded Datasets
-from auton_survival import datasets
-features, outcomes = datasets.load_topcat()
-
-from auton_survival.preprocessing import Preprocessing
-features = Preprocessor().fit_transform(features,
- cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
- num_feats=['height', 'weight'])
-
-# The `cat_feats` and `num_feats` lists would contain all the categorical and numerical features in the dataset.
-
-```
+The package `auton_survival` is a repository of reusable utilities for projects
+involving censored time-to-event data. `auton_survival` allows rapid
+experimentation, including dataset preprocessing, regression, counterfactual
+estimation, clustering and phenotyping, and propensity-adjusted evaluation.
Survival Regression
@@ -77,33 +49,30 @@ class is a composite transform that does both Imputing ***and*** Scaling.
Currently supported Survival Models are:
-- Cox Proportional Hazards Model (`lifelines`):
-- Random Survival Forests (`pysurvival`):
-- Weibull Accelerated Failure Time (`lifelines`) :
-- Deep Survival Machines: **Not Implemented Yet**
-- Deep Cox Mixtures: **Not Implemented Yet**
+- `auton_survival.models.dsm.DeepSurvivalMachines`
+- `auton_survival.models.dcm.DeepCoxMixtures`
+- `auton_survival.models.cph.DeepCoxPH`
+`auton_survival` also provides convenient wrappers around other popular
+Python survival analysis packages to experiment with the following
+survival regression estimators:
-```python
-# Preprocessing loaded Datasets
-from auton_survival import datasets
-features, outcomes = datasets.load_topcat()
-
-from auton_survival.estimators import Preprocessing
-features = Preprocessing().fit_transform(features)
-```
+- Random Survival Forests (`pysurvival`)
+- Weibull Accelerated Failure Time (`lifelines`)
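+
+A minimal usage sketch with the `auton_survival.estimators.SurvivalModel`
+wrapper (the same wrapper used by `auton_survival.experiments`) is shown
+below; the `'cph'` model identifier and the prediction horizon are
+illustrative assumptions:
+
+```python
+from auton_survival import datasets
+from auton_survival.estimators import SurvivalModel
+
+features, outcomes = datasets.load_dataset('SUPPORT')
+# In practice, `features` would first be imputed and scaled with the
+# Preprocessor described below.
+
+# Fit a Cox Proportional Hazards model ('cph' is assumed here for illustration).
+model = SurvivalModel(model='cph', random_seed=0)
+model.fit(features, outcomes)
+
+# Predict survival probabilities at a 1-year horizon (time in days).
+predictions = model.predict_survival(features, times=[365])
+```
+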
### `auton_survival.experiments`
Modules to perform standard survival analysis experiments. This module
-provides a top-level interface to run `auton_survival` Style experiments
+provides a top-level interface to run `auton_survival` style experiments
of survival analysis, involving cross-validation style experiments with
-multiple different survival analysis models at different horizons of event times.
+multiple different survival analysis models at different horizons of
+event times.
The module further eases evaluation by automatically computing the
*censoring adjusted* estimates of the Metrics of interest, like
-**Time Dependent Concordance Index** and **Brier Score** with **IPCW** adjustment.
+**Time Dependent Concordance Index** and **Brier Score** with **IPCW**
+adjustment.
```python
# auton_survival Style Cross Validation Experiment.
@@ -131,13 +100,68 @@ class is a composite transform that does both Imputing ***and*** Scaling.
### `auton_survival.phenotyping`
+`auton_survival.phenotyping` allows extraction of latent clusters or subgroups
+of patients that demonstrate similar outcomes. In the context of this package,
+we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows:
+
+- **Unsupervised Phenotyping**: Involves first performing dimensionality
+reduction on the input covariates \( x \) followed by the use of a clustering
+algorithm on this representation (see the sketch after this list).
+
+- **Factual Phenotyping**: Involves the use of structured latent variable
+models, `auton_survival.models.dcm.DeepCoxMixtures` or
+`auton_survival.models.dsm.DeepSurvivalMachines` to recover phenogroups that
+demonstrate differential observed survival rates.
+
+- **Counterfactual Phenotyping**: Involves learning phenotypes that demonstrate
+heterogeneous treatment effects. That is, the learnt phenogroups have differential
+response to a specific intervention. Relies on the specially designed
+`auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects` latent variable model.
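+
+As an illustration of the unsupervised strategy, the same pipeline can be
+sketched directly with `scikit-learn` (PCA for dimensionality reduction
+followed by k-means clustering); the package's own phenotyper interfaces are
+not reproduced here and the hyperparameters are arbitrary:
+
+```python
+from sklearn.decomposition import PCA
+from sklearn.cluster import KMeans
+
+# `features` is assumed to be a preprocessed, numeric covariate matrix.
+representation = PCA(n_components=8).fit_transform(features)
+
+# Cluster the low-dimensional representation to recover candidate phenogroups.
+phenotypes = KMeans(n_clusters=3, random_state=0).fit_predict(representation)
+```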
+
+Dataset Loading and Preprocessing
+---------------------------------
+
+Helper functions to load and preprocess various time-to-event data like the
+popular `SUPPORT`, `FRAMINGHAM` and `PBC` datasets for survival analysis.
+
+
+### `auton_survival.datasets`
+
+```python
+# Load the SUPPORT Dataset
+from auton_survival import datasets
+features, outcomes = datasets.load_dataset('SUPPORT')
+```
+
+### `auton_survival.preprocessing`
+This module provides a flexible API to perform imputation and data
+normalization for downstream machine learning models. The module has
+3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
+class is a composite transform that does both Imputing ***and*** Scaling with
+a single function call.
+
+```python
+# Preprocessing loaded Datasets
+from auton_survival import datasets
+features, outcomes = datasets.load_dataset('SUPPORT')
+
+from auton_survival.preprocessing import Preprocessor
+features = Preprocessor().fit_transform(features,
+ cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
+ num_feats=['height', 'weight'])
+
+# The `cat_feats` and `num_feats` lists above are illustrative; they should
+# contain the categorical and numerical features present in the loaded dataset.
+
+```
+
Reporting
----------
### `auton_survival.reporting`
-Helper functions to generate standard reports for popular Survival Analysis problems.
+Helper functions to generate standard reports for common Survival Analysis tasks.
## Installation
@@ -148,22 +172,22 @@ class is a composite transform that does both Imputing ***and*** Scaling.
Compatibility
-------------
-`dsm` requires `python` 3.5+ and `pytorch` 1.1+.
+`auton_survival` requires `python` 3.5+ and `pytorch` 1.1+.
To evaluate performance using standard metrics
-`dsm` requires `scikit-survival`.
+`auton_survival` requires `scikit-survival`.
Contributing
------------
-`dsm` is [on GitHub]. Bug reports and pull requests are welcome.
+`auton_survival` is [on GitHub]. Bug reports and pull requests are welcome.
-[on GitHub]: https://github.com/chiragnagpal/deepsurvivalmachines
+[on GitHub]: https://github.com/autonlab/auton-survival
License
-------
MIT License
-Copyright (c) 2020 Carnegie Mellon University, [Auton Lab](http://autonlab.org)
+Copyright (c) 2022 Carnegie Mellon University, [Auton Lab](http://autonlab.org)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -184,10 +208,10 @@ class is a composite transform that does both Imputing ***and*** Scaling.
SOFTWARE.
-
-
+
diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py
index 0be441d..e432900 100644
--- a/auton_survival/experiments.py
+++ b/auton_survival/experiments.py
@@ -47,21 +47,21 @@ def fit(self, features, outcomes, ret_trained_model=True):
best_model = {}
best_score = np.inf
- for hyper_param in tqdm(self.hyperparam_grid):
+ for hyper_param in tqdm(self.hyperparam_grid):
predictions = np.zeros((len(features), len(unique_times)))
fold_models = {}
for fold in tqdm(range(self.cv_folds)):
# Fit the model
- fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
+ fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
- fold_models[fold] = fold_model
+ fold_models[fold] = fold_model
# Predict risk scores
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold], times=unique_times)
# Evaluate IBS
- score_per_fold = []
+ score_per_fold = []
for fold in range(self.cv_folds):
score = survival_regression_metric('ibs', predictions, outcomes, unique_times, folds, fold)
score_per_fold.append(score)
@@ -70,8 +70,8 @@ def fit(self, features, outcomes, ret_trained_model=True):
if current_score < best_score:
best_score = current_score
- best_model = fold_models
- best_hyper_param = hyper_param
+ best_model = fold_models
+ best_hyper_param = hyper_param
best_predictions = predictions
self.best_hyperparameter = best_hyper_param
diff --git a/auton_survival/metrics.py b/auton_survival/metrics.py
index afae64e..baeace8 100644
--- a/auton_survival/metrics.py
+++ b/auton_survival/metrics.py
@@ -30,7 +30,7 @@ def survival_diff_metric(metric, outcomes, treatment_indicator,
assigned treatment.
weights : pd.Series
Treatment assignment propensity scores, \( \widehat{\mathbb{P}}(A|X=x) \).
- If None, all weights are set to 0.5. Default is None.
+ If `None`, all weights are set to \( 0.5 \). Default is `None`.
horizon : float
The time horizon at which to compare the survival curves.
Must be specified for metric 'restricted_mean' and 'survival_at'.
diff --git a/auton_survival/preprocessing.py b/auton_survival/preprocessing.py
index 821917b..e792f93 100644
--- a/auton_survival/preprocessing.py
+++ b/auton_survival/preprocessing.py
@@ -10,24 +10,25 @@
class Imputer:
- r"""Imputation is the first key aspect of the preprocessing workflow.
- It replaces null values, allowing the machine learning process to continue.
- This class includes separate implementations for categorical and
- numerical/continuous features.
+ r"""A class to impute missing values in the input features.
- For categorical features, the user can choose between the
- following strategies:
+  Real world datasets are often subject to missing covariates.
+  Imputation replaces the missing values, allowing downstream
+  experiments to proceed.
+  This class allows multiple strategies to impute both categorical and
+  numerical/continuous covariates.
- - **replace**: Replace all null values with a constant.
- - **ignore**: Keep all null values
- - **mode**: Replace null values with most commonly occurring category.
+ For categorical features, the class allows:
+
+  - **replace**: Replace all null values with a user specified constant.
+  - **ignore**: Keep all missing values as is.
+  - **mode**: Replace null values with the most commonly occurring category.
For numerical/continuous features,
the user can choose between the following strategies:
- - **mean**: Replace all null values with the mean in the column.
- - **median**: Replace all null values with the median in the column.
- - **knn**: Use the KNN model to predict the null values.
+ - **mean**: Replace all missing values with the mean in the column.
+ - **median**: Replace all missing values with the median in the column.
+  - **knn**: Use a k-Nearest Neighbours model to predict the missing values.
- **missforest**: Use the MissForest model to predict the null values.
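+
+  A minimal usage sketch is shown below. The `cat_feat_strat` and
+  `num_feat_strat` keyword arguments and the `transform` call mirror the
+  attributes and scikit-learn style convention suggested by this class, but
+  are assumptions here; the feature names are placeholders:
+
+  ```python
+  from auton_survival.preprocessing import Imputer
+
+  # `data` is a pandas DataFrame with missing entries.
+  # Impute categorical features with their mode and numerical features
+  # with the column mean.
+  imputer = Imputer(cat_feat_strat='mode', num_feat_strat='mean')
+  imputer.fit(data, cat_feats=['sex'], num_feats=['age'])
+  data = imputer.transform(data)
+  ```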
Parameters
@@ -67,7 +68,8 @@ def fit(self, data, cat_feats=None, num_feats=None,
if cat_feats is None: cat_feats = []
if num_feats is None: num_feats = []
- assert len(cat_feats + num_feats) != 0, "Please specify categorical and numerical features."
+    assert len(cat_feats + num_feats) != 0, \
+      "Please specify categorical and numerical features."
self._cat_feats = cat_feats
self._num_feats = num_feats
@@ -83,9 +85,11 @@ def fit(self, data, cat_feats=None, num_feats=None,
####### CAT VARIABLES
if len(cat_feats):
if self.cat_feat_strat == 'replace':
- self._cat_base_imputer = SimpleImputer(strategy='constant', fill_value=fill_value).fit(df[cat_feats])
+ self._cat_base_imputer = SimpleImputer(strategy='constant',
+ fill_value=fill_value).fit(df[cat_feats])
elif self.cat_feat_strat == 'mode':
- self._cat_base_imputer = SimpleImputer(strategy='most_frequent', fill_value=fill_value).fit(df[cat_feats])
+ self._cat_base_imputer = SimpleImputer(strategy='most_frequent',
+ fill_value=fill_value).fit(df[cat_feats])
####### NUM VARIABLES
if len(num_feats):
@@ -153,10 +157,6 @@ class Scaler:
"""Scaler to rescale numerical features.
- Scaling is the second key aspect of the preprocessing workflow.
- It transforms continuous values to improve the performance of the
- machine learning algorithms.
-
For scaling, the user can choose between the following strategies:
- **standard**: Perform the standard scaling method.
@@ -167,7 +167,8 @@ class Scaler:
----------
scaling_strategy: str
Strategy to use for scaling numerical/continuous data.
- One of `'standard'`, `'minmax'`, `'none'`. Default is `standard`.
+ One of `'standard'`, `'minmax'`, `'none'`.
+ Default is `standard`.
"""
_VALID_SCALING_STRAT = ['standard', 'minmax', 'none']
@@ -186,8 +187,8 @@ def fit_transform(self, data, feats=[]):
data: pandas.DataFrame
Dataframe to be scaled.
feats: list
- List of numerical/continuous features to be scaled - if left empty,
- all features are interpreted as numerical features.
+ List of numerical/continuous features to be scaled.
+ **NOTE**: if left empty, all features are interpreted as numerical.
Returns:
pandas.DataFrame: Scaled dataset.
@@ -204,7 +205,7 @@ def fit_transform(self, data, feats=[]):
else:
scaler = None
- if scaler != None:
+ if scaler is not None:
if feats: df[feats] = scaler.fit_transform(df[feats])
else: df[df.columns] = scaler.fit_transform(df)
@@ -212,7 +213,7 @@ def fit_transform(self, data, feats=[]):
class Preprocessor:
- """Class to perform full preprocessing pipeline.
+  """A composite transform involving both imputation and scaling.
Parameters
----------
@@ -237,7 +238,8 @@ def __init__(self, cat_feat_strat='ignore',
self.scaler = Scaler(scaling_strategy=scaling_strategy)
- def fit_transform(self, data, cat_feats, num_feats, one_hot=True, fill_value=-1, n_neighbors=5, **kwargs):
+ def fit_transform(self, data, cat_feats, num_feats,
+ one_hot=True, fill_value=-1, n_neighbors=5, **kwargs):
"""Imputes and scales dataset.
Parameters