Skip to content

Commit

Permalink
CV updates
Browse files Browse the repository at this point in the history
  • Loading branch information
PotosnakW committed Apr 28, 2022
1 parent 8ae1ef4 commit cd5cb93
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 151 deletions.
482 changes: 367 additions & 115 deletions auton_survival/experiments.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
from .cmhe_utilities import train_cmhe, predict_survival
from .cmhe_utilities import predict_latent_phi, predict_latent_z

from auton_survival.preprocessing import _dataframe_to_array
from auton_survival.utils import _dataframe_to_array


class DeepCoxMixturesHeterogenousEffects:
Expand Down
2 changes: 1 addition & 1 deletion auton_survival/models/cph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .dcph_torch import DeepCoxPHTorch, DeepRecurrentCoxPHTorch
from .dcph_utilities import train_dcph, predict_survival

from auton_survival.preprocessing import _dataframe_to_array
from auton_survival.utils import _dataframe_to_array
from auton_survival.models.dsm.utilities import _get_padded_features
from auton_survival.models.dsm.utilities import _get_padded_targets

Expand Down
2 changes: 1 addition & 1 deletion auton_survival/models/dcm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .dcm_torch import DeepCoxMixturesTorch
from .dcm_utilities import train_dcm, predict_survival, predict_latent_z

from auton_survival.preprocessing import _dataframe_to_array
from auton_survival.utils import _dataframe_to_array


class DeepCoxMixtures:
Expand Down
2 changes: 1 addition & 1 deletion auton_survival/models/dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
from .utilities import _get_padded_features, _get_padded_targets
from .utilities import _reshape_tensor_with_nans

from auton_survival.preprocessing import _dataframe_to_array
from auton_survival.utils import _dataframe_to_array


__pdoc__ = {}
Expand Down
10 changes: 8 additions & 2 deletions auton_survival/phenotyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ def __init__(self,

self.random_seed = random_seed

def fit(self, features, outcomes, interventions, horizon):
def fit(self, features, outcomes, interventions, metric,
horizon, cat_feats, num_feats):

"""Fit a counterfactual model and regress the difference of the estimated
counterfactual Restricted Mean Survival Time using a Random Forest regressor.
Expand All @@ -502,6 +503,10 @@ def fit(self, features, outcomes, interventions, horizon):
horizon : np.float
The event horizon at which to compute the counterfactual RMST for
regression.
cat_feats: list
List of categorical features.
num_feats: list
List of numerical/continuous features.
Returns
-----------
Expand All @@ -512,7 +517,8 @@ def fit(self, features, outcomes, interventions, horizon):
cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method,
hyperparam_grid=self.cf_hyperparams)

self.cf_model = cf_model.fit(features, outcomes, interventions)
self.cf_model = cf_model.fit(features, outcomes, interventions,
metric, cat_feats, num_feats)

times = np.unique(outcomes.time.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features,
Expand Down
6 changes: 0 additions & 6 deletions auton_survival/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,3 @@ def fit_transform(self, data, cat_feats, num_feats,
output = pd.get_dummies(output, dummy_na=False, drop_first=True)

return output

def _dataframe_to_array(data):
    """Convert a pandas container to a NumPy array; pass anything else through.

    Parameters
    ----------
    data : pd.Series, pd.DataFrame or any other object
        The object to convert.

    Returns
    -------
    The result of ``data.to_numpy()`` when ``data`` is a ``pd.Series`` or a
    ``pd.DataFrame``; otherwise ``data`` itself, unchanged (same object).
    """
    if isinstance(data, (pd.Series, pd.DataFrame)):
        return data.to_numpy()
    # Non-pandas inputs (arrays, lists, None, ...) are returned untouched.
    return data
7 changes: 7 additions & 0 deletions auton_survival/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import pandas as pd

def _get_method_kwargs(method, kwargs):

Expand All @@ -11,3 +12,9 @@ def _get_method_kwargs(method, kwargs):
method_kwargs = {k: kwargs[k] for k in method_params}

return method_kwargs

def _dataframe_to_array(data):
    """Convert a pandas container to a NumPy array; pass anything else through.

    Parameters
    ----------
    data : pd.Series, pd.DataFrame or any other object
        The object to convert.

    Returns
    -------
    The result of ``data.to_numpy()`` when ``data`` is a ``pd.Series`` or a
    ``pd.DataFrame``; otherwise ``data`` itself, unchanged (same object).
    """
    if isinstance(data, (pd.Series, pd.DataFrame)):
        return data.to_numpy()
    # Non-pandas inputs (arrays, lists, None, ...) are returned untouched.
    return data
46 changes: 22 additions & 24 deletions examples/CV Survival Regression on SUPPORT Dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
"`auton-survival` offers a simple-to-use API for training Survival Regression Models with cross-validated model selection that minimizes the integrated Brier score. In this notebook we demonstrate the use of `auton-survival` to train survival models on the *SUPPORT* dataset in a cross-validated fashion."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -37,10 +32,8 @@
"\n",
"cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']\n",
"num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', \n",
"\t 'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', \n",
" 'glucose', 'bun', 'urine', 'adlp', 'adls']\n",
"\n",
"features = Preprocessor().fit_transform(features, cat_feats=cat_feats, num_feats=num_feats)"
" 'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', \n",
" 'glucose', 'bun', 'urine', 'adlp', 'adls']"
]
},
{
Expand All @@ -59,25 +52,23 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from auton_survival.experiments import SurvivalRegressionCV\n",
"\n",
"param_grid = {'k' : [3],\n",
" 'distribution' : ['Weibull'],\n",
" 'learning_rate' : [ 1e-4, 1e-3],\n",
" 'layers' : [[], [100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', cv_folds=5, hyperparam_grid=param_grid, random_seed=0)\n",
"model = experiment.fit(features, outcomes)\n"
]
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment.folds"
"from auton_survival.experiments import SurvivalRegressionCV\n",
"\n",
"param_grid = {'k' : [3],\n",
" 'distribution' : ['Weibull'],\n",
" 'learning_rate' : [1e-4, 1e-3],\n",
" 'layers' : [[100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
"model = experiment.fit(features, outcomes, cat_feats=cat_feats, num_feats=num_feats, one_hot=True)\n"
]
},
{
Expand All @@ -86,8 +77,7 @@
"metadata": {},
"outputs": [],
"source": [
"out_risk = model.predict_risk(features, times)\n",
"out_survival = model.predict_survival(features, times)"
"experiment.folds"
]
},
{
Expand All @@ -96,7 +86,15 @@
"metadata": {},
"outputs": [],
"source": [
"from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc"
"from auton_survival.preprocessing import Preprocessor\n",
"\n",
"preprocessor = Preprocessor(cat_feat_strat='replace', num_feat_strat='median',\n",
" scaling_strategy='standard', one_hot=True)\n",
"features_preprocessed = preprocessor.fit_transform(features, cat_feats=cat_feats, \n",
" num_feats=num_feats, fill_value=-1)\n",
"\n",
"out_risk = model.predict_risk(features_preprocessed, times)\n",
"out_survival = model.predict_survival(features_preprocessed, times)"
]
},
{
Expand Down

0 comments on commit cd5cb93

Please sign in to comment.