diff --git a/auton_survival/__init__.py b/auton_survival/__init__.py
index 84c4a74..49fce29 100644
--- a/auton_survival/__init__.py
+++ b/auton_survival/__init__.py
@@ -27,41 +27,13 @@
 * There is presence of censoring ie. a large number of instances of data
 are lost to follow up.

-Auton Survival
-----------------
+The Auton Survival Package
+---------------------------

-Repository of reusable code utilities for Survival Analysis projects.
-
-Dataset Loading and Preprocessing
----------------------------------
-
-Helper functions to load various trial data like `TOPCAT`, `BARI2D` and `ALLHAT`.
-
-```python
-# Load the TOPCAT Dataset
-from auton_survival import dataset
-features, outcomes = datasets.load_topcat()
-```
-
-### `auton_survival.preprocessing`
-This module provides a flexible API to perform imputation and data
-normalization for downstream machine learning models. The module has
-3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
-class is a composite transform that does both Imputing ***and*** Scaling.
-
-```python
-# Preprocessing loaded Datasets
-from auton_survival import datasets
-features, outcomes = datasets.load_topcat()
-
-from auton_survival.preprocessing import Preprocessing
-features = Preprocessor().fit_transform(features,
-                                        cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
-                                        num_feats=['height', 'weight'])
-
-# The `cat_feats` and `num_feats` lists would contain all the categorical and numerical features in the dataset.
-
-```
+The package `auton_survival` is a repository of reusable utilities for projects
+involving censored Time-to-Event Data. `auton_survival` allows rapid
+experimentation with dataset preprocessing, regression, counterfactual
+estimation, clustering and phenotyping, and propensity-adjusted evaluation.


 Survival Regression
@@ -77,33 +49,30 @@ class is a composite transform that does both Imputing ***and*** Scaling.

 Currently supported Survival Models are:

-- Cox Proportional Hazards Model (`lifelines`):
-- Random Survival Forests (`pysurvival`):
-- Weibull Accelerated Failure Time (`lifelines`) :
-- Deep Survival Machines: **Not Implemented Yet**
-- Deep Cox Mixtures: **Not Implemented Yet**
+- `auton_survival.models.dsm.DeepSurvivalMachines`
+- `auton_survival.models.dcm.DeepCoxMixtures`
+- `auton_survival.models.cph.DeepCoxPH`
+
+`auton_survival` also provides convenient wrappers around other popular
+Python survival analysis packages to experiment with the following
+survival regression estimators (a usage sketch follows the list):

-```python
-# Preprocessing loaded Datasets
-from auton_survival import datasets
-features, outcomes = datasets.load_topcat()
-
-from auton_survival.estimators import Preprocessing
-features = Preprocessing().fit_transform(features)
-```
+- Random Survival Forests (`pysurvival`)
+- Weibull Accelerated Failure Time (`lifelines`)
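+
+A minimal usage sketch with the `SurvivalModel` wrapper used by
+`auton_survival.experiments` (assumed here to be importable from
+`auton_survival.estimators`); the `'dsm'` model key, the one-year horizon and
+the column names passed to the preprocessor are illustrative assumptions:
+
+```python
+from auton_survival import datasets
+from auton_survival.preprocessing import Preprocessor
+from auton_survival.estimators import SurvivalModel
+
+features, outcomes = datasets.load_dataset('SUPPORT')
+
+# Impute and scale the covariates. The column names below are placeholders
+# for the full lists of categorical and numerical features in the dataset.
+features = Preprocessor().fit_transform(features,
+                                        cat_feats=['sex'], num_feats=['age'])
+
+# Fit a survival regression model and predict survival at one year.
+model = SurvivalModel(model='dsm', random_seed=0)
+model.fit(features, outcomes)
+predictions = model.predict_survival(features, times=[365])
+```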

 ### `auton_survival.experiments`

 Modules to perform standard survival analysis experiments. This module
-provides a top-level interface to run `auton_survival` Style experiments
+provides a top-level interface to run `auton_survival` style experiments
 of survival analysis, involving cross-validation style experiments with
-multiple different survival analysis models at different horizons of event times.
+multiple different survival analysis models at different horizons of
+event times.

 The module further eases evaluation by automatically computing the
 *censoring adjusted* estimates of the Metrics of interest, like
-**Time Dependent Concordance Index** and **Brier Score** with **IPCW** adjustment.
+**Time Dependent Concordance Index** and **Brier Score** with **IPCW**
+adjustment.

 ```python
 # auton_survival Style Cross Validation Experiment.
@@ -131,13 +100,68 @@ class is a composite transform that does both Imputing ***and*** Scaling.

 ### `auton_survival.phenotyping`

+`auton_survival.phenotyping` allows extraction of latent clusters or subgroups
+of patients that demonstrate similar outcomes. In the context of this package,
+we refer to this task as **phenotyping**. `auton_survival.phenotyping` allows:
+
+- **Unsupervised Phenotyping**: Involves first performing dimensionality
+reduction on the input covariates \( x \), followed by the use of a clustering
+algorithm on this representation (a brief sketch follows this list).
+
+- **Factual Phenotyping**: Involves the use of structured latent variable
+models, `auton_survival.models.dcm.DeepCoxMixtures` or
+`auton_survival.models.dsm.DeepSurvivalMachines`, to recover phenogroups that
+demonstrate differential observed survival rates.
+
+- **Counterfactual Phenotyping**: Involves learning phenotypes that demonstrate
+heterogeneous treatment effects. That is, the learnt phenogroups have
+differential response to a specific intervention. Relies on the specially
+designed `auton_survival.models.cmhe.DeepCoxMixturesHeterogenousEffects`
+latent variable model.
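+
+A minimal sketch of the unsupervised recipe using `scikit-learn` primitives,
+purely for illustration; this is not the `auton_survival.phenotyping` API, and
+the covariates are assumed to be fully numeric (e.g. after preprocessing):
+
+```python
+import numpy as np
+from sklearn.decomposition import PCA
+from sklearn.cluster import KMeans
+
+x = np.random.randn(1000, 20)  # placeholder for preprocessed covariates
+
+# Dimensionality reduction followed by clustering yields the phenogroups.
+x_reduced = PCA(n_components=8).fit_transform(x)
+phenotypes = KMeans(n_clusters=3, random_state=0).fit_predict(x_reduced)
+```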
+
+Dataset Loading and Preprocessing
+---------------------------------
+
+Helper functions to load and preprocess various time-to-event data like the
+popular `SUPPORT`, `FRAMINGHAM` and `PBC` datasets for survival analysis.
+
+
+### `auton_survival.datasets`
+
+```python
+# Load the SUPPORT Dataset
+from auton_survival import datasets
+features, outcomes = datasets.load_dataset('SUPPORT')
+```
+
+### `auton_survival.preprocessing`
+This module provides a flexible API to perform imputation and data
+normalization for downstream machine learning models. The module has
+3 distinct classes, `Scaler`, `Imputer` and `Preprocessor`. The `Preprocessor`
+class is a composite transform that does both Imputing ***and*** Scaling with
+a single function call.
+
+```python
+# Preprocessing loaded Datasets
+from auton_survival import datasets
+features, outcomes = datasets.load_topcat()
+
+from auton_survival.preprocessing import Preprocessor
+features = Preprocessor().fit_transform(features,
+                                        cat_feats=['GENDER', 'ETHNICITY', 'SMOKE'],
+                                        num_feats=['height', 'weight'])
+
+# The `cat_feats` and `num_feats` lists would contain all the categorical and
+# numerical features in the dataset.
+
+```
+
 Reporting
 ----------

 ### `auton_survival.reporting`

-Helper functions to generate standard reports for popular Survival Analysis problems.
+Helper functions to generate standard reports for common Survival Analysis tasks.


 ## Installation

@@ -148,22 +172,22 @@ class is a composite transform that does both Imputing ***and*** Scaling.

 Compatibility
 -------------

-`dsm` requires `python` 3.5+ and `pytorch` 1.1+.
+`auton_survival` requires `python` 3.5+ and `pytorch` 1.1+.

 To evaluate performance using standard metrics
-`dsm` requires `scikit-survival`.
+`auton_survival` requires `scikit-survival`.

 Contributing
 ------------

-`dsm` is [on GitHub]. Bug reports and pull requests are welcome.
+`auton_survival` is [on GitHub]. Bug reports and pull requests are welcome.

-[on GitHub]: https://github.com/chiragnagpal/deepsurvivalmachines
+[on GitHub]: https://github.com/autonlab/auton-survival

 License
 -------
 MIT License

-Copyright (c) 2020 Carnegie Mellon University, [Auton Lab](http://autonlab.org)
+Copyright (c) 2022 Carnegie Mellon University, [Auton Lab](http://autonlab.org)

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -184,10 +208,10 @@ class is a composite transform that does both Imputing ***and*** Scaling.
 SOFTWARE.

-
-
+









diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py
index 0be441d..e432900 100644
--- a/auton_survival/experiments.py
+++ b/auton_survival/experiments.py
@@ -47,21 +47,21 @@ def fit(self, features, outcomes, ret_trained_model=True):

     best_model = {}
     best_score = np.inf

-    for hyper_param in tqdm(self.hyperparam_grid): 
+    for hyper_param in tqdm(self.hyperparam_grid):

      predictions = np.zeros((len(features), len(unique_times)))

      fold_models = {}
      for fold in tqdm(range(self.cv_folds)):

        # Fit the model
-        fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param) 
+        fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
        fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
-        fold_models[fold] = fold_model 
+        fold_models[fold] = fold_model

        # Predict risk scores
        predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold], times=unique_times)

      # Evaluate IBS
-      score_per_fold = [] 
+      score_per_fold = []
      for fold in range(self.cv_folds):
        score = survival_regression_metric('ibs', predictions, outcomes, unique_times, folds, fold)
        score_per_fold.append(score)
@@ -70,8 +70,8 @@ def fit(self, features, outcomes, ret_trained_model=True):

      if current_score < best_score:

        best_score = current_score
-        best_model = fold_models 
-        best_hyper_param = hyper_param 
+        best_model = fold_models
+        best_hyper_param = hyper_param
        best_predictions = predictions

    self.best_hyperparameter = best_hyper_param

diff --git a/auton_survival/metrics.py b/auton_survival/metrics.py
index afae64e..baeace8 100644
--- a/auton_survival/metrics.py
+++ b/auton_survival/metrics.py
@@ -30,7 +30,7 @@ def survival_diff_metric(metric, outcomes, treatment_indicator,
     assigned treatment.
   weights : pd.Series
     Treatment assignment propensity scores, \( \widehat{\mathbb{P}}(A|X=x) \).
-    If None, all weights are set to 0.5. Default is None.
+    If `None`, all weights are set to \( 0.5 \). Default is `None`.
   horizon : float
     The time horizon at which to compare the survival curves. Must be
     specified for metric 'restricted_mean' and 'survival_at'.

diff --git a/auton_survival/preprocessing.py b/auton_survival/preprocessing.py
index 821917b..e792f93 100644
--- a/auton_survival/preprocessing.py
+++ b/auton_survival/preprocessing.py
@@ -10,24 +10,25 @@ class Imputer:

-  r"""Imputation is the first key aspect of the preprocessing workflow.
-  It replaces null values, allowing the machine learning process to continue.
-  This class includes separate implementations for categorical and
-  numerical/continuous features.
+  r"""A class to impute missing values in the input features.

-  For categorical features, the user can choose between the
-  following strategies:
+  Real-world datasets are often subject to missing covariates.
+  Imputation replaces the missing values, allowing downstream experiments
+  to proceed. This class allows multiple strategies to impute both
+  categorical and numerical/continuous covariates.

-  - **replace**: Replace all null values with a constant.
-  - **ignore**: Keep all null values
-  - **mode**: Replace null values with most commonly occurring category.
+  For categorical features, the class allows:
+
+  - **replace**: Replace all null values with a user-specified constant.
+  - **ignore**: Keep all missing values as is.
+  - **mode**: Replace null values with the most commonly occurring category.

   For numerical/continuous features, the user can choose between the
   following strategies:

-  - **mean**: Replace all null values with the mean in the column.
-  - **median**: Replace all null values with the median in the column.
-  - **knn**: Use the KNN model to predict the null values.
+  - **mean**: Replace all missing values with the mean in the column.
+  - **median**: Replace all missing values with the median in the column.
+  - **knn**: Use a k-Nearest Neighbours model to predict the missing values.
   - **missforest**: Use the MissForest model to predict the null values.
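+
+  A hypothetical usage sketch; the `cat_feat_strat` argument and the chained
+  `fit`/`transform` calls are illustrative assumptions rather than a
+  prescribed API:
+
+  ```python
+  import numpy as np
+  import pandas as pd
+  from auton_survival.preprocessing import Imputer
+
+  data = pd.DataFrame({'sex': ['M', None, 'F'], 'age': [61.0, 58.0, np.nan]})
+
+  # Impute the categorical column with its mode and keep other defaults.
+  imputer = Imputer(cat_feat_strat='mode').fit(data, cat_feats=['sex'],
+                                               num_feats=['age'])
+  imputed = imputer.transform(data)
+  ```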

   Parameters
   ----------

@@ -67,7 +68,8 @@ def fit(self, data, cat_feats=None, num_feats=None,

     if cat_feats is None: cat_feats = []
     if num_feats is None: num_feats = []

-    assert len(cat_feats + num_feats) != 0, "Please specify categorical and numerical features."
+    assert len(cat_feats + num_feats) != 0, \
+           "Please specify categorical and numerical features."

     self._cat_feats = cat_feats
     self._num_feats = num_feats

@@ -83,9 +85,11 @@ def fit(self, data, cat_feats=None, num_feats=None,

     ####### CAT VARIABLES
     if len(cat_feats):
       if self.cat_feat_strat == 'replace':
-        self._cat_base_imputer = SimpleImputer(strategy='constant', fill_value=fill_value).fit(df[cat_feats])
+        self._cat_base_imputer = SimpleImputer(strategy='constant',
+                                               fill_value=fill_value).fit(df[cat_feats])
       elif self.cat_feat_strat == 'mode':
-        self._cat_base_imputer = SimpleImputer(strategy='most_frequent', fill_value=fill_value).fit(df[cat_feats])
+        self._cat_base_imputer = SimpleImputer(strategy='most_frequent',
+                                               fill_value=fill_value).fit(df[cat_feats])

     ####### NUM VARIABLES
     if len(num_feats):

@@ -153,10 +157,6 @@ class Scaler:

   """Scaler to rescale numerical features.

-  Scaling is the second key aspect of the preprocessing workflow.
-  It transforms continuous values to improve the performance of the
-  machine learning algorithms.
-
   For scaling, the user can choose between the following strategies:

   - **standard**: Perform the standard scaling method.

@@ -167,7 +167,8 @@ class Scaler:
   ----------
   scaling_strategy: str
     Strategy to use for scaling numerical/continuous data.
-    One of `'standard'`, `'minmax'`, `'none'`. Default is `standard`.
+    One of `'standard'`, `'minmax'`, `'none'`.
+    Default is `'standard'`.

   """

   _VALID_SCALING_STRAT = ['standard', 'minmax', 'none']

@@ -186,8 +187,8 @@ def fit_transform(self, data, feats=[]):
     data: pandas.DataFrame
       Dataframe to be scaled.
     feats: list
-      List of numerical/continuous features to be scaled - if left empty,
-      all features are interpreted as numerical features.
+      List of numerical/continuous features to be scaled.
+      **NOTE**: if left empty, all features are interpreted as numerical.

     Returns:
       pandas.DataFrame: Scaled dataset.

@@ -204,7 +205,7 @@ def fit_transform(self, data, feats=[]):

     else: scaler = None

-    if scaler != None:
+    if scaler is not None:
       if feats: df[feats] = scaler.fit_transform(df[feats])
       else: df[df.columns] = scaler.fit_transform(df)

@@ -212,7 +213,7 @@ def fit_transform(self, data, feats=[]):

 class Preprocessor:

-  """Class to perform full preprocessing pipeline.
+  """A composite transform involving both imputation and scaling.

   Parameters
   ----------

@@ -237,7 +238,8 @@ def __init__(self, cat_feat_strat='ignore',

     self.scaler = Scaler(scaling_strategy=scaling_strategy)

-  def fit_transform(self, data, cat_feats, num_feats, one_hot=True, fill_value=-1, n_neighbors=5, **kwargs):
+  def fit_transform(self, data, cat_feats, num_feats,
+                    one_hot=True, fill_value=-1, n_neighbors=5, **kwargs):
     """Imputes and scales dataset.

     Parameters