Skip to content

Commit

Permalink
Rolled back SurvivalRegressionCV class. Removed support for nested CV.
Browse files Browse the repository at this point in the history
	modified:   experiments.py
	modified:   metrics.py
  • Loading branch information
chiragnagpal committed Jun 11, 2022
1 parent d38621e commit 304d615
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 169 deletions.
228 changes: 60 additions & 168 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

"""Utilities to perform cross-validation."""

from matplotlib.pyplot import hot
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -100,7 +101,9 @@ def __init__(self, model='dcph', folds=None, num_folds=5,
self.random_seed = random_seed
self.hyperparam_grid = list(ParameterGrid(hyperparam_grid))

def fit(self, features, outcomes, metric='ibs', horizon=None):
assert len(self.hyperparam_grid), "Cross Validation Grid is empty."

def fit(self, features, outcomes, horizons, metric='ibs'):

r"""Fits the survival regression model to the data in a cross-
validation fashion.
Expand All @@ -113,190 +116,79 @@ def fit(self, features, outcomes, metric='ibs', horizon=None):
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
horizon : int or float, default=None
Event-horizon at which to evaluate model performance.
If None, then the maximum permissible event-time from the data is used.
Returns
-----------
Trained survival regression model(s).
"""

self.metric = metric
self.horizon = horizon

if (horizon is None) & (metric not in ['auc', 'ctd', 'brs']):
warnings.warn("Horizon is not specificed for {} metric, so the maximum \
permissible time horizon from the data is used to evaluate model \
performance".format(metric))

if self.folds is None:
self.folds = self._get_stratified_folds(outcomes, 'event',
self.num_folds,
self.random_seed)

if self.num_nested_folds is None:
best_params = self._cv_select_parameters(features, outcomes, self.folds)
model = SurvivalModel(self.model, self.random_seed, **best_params)

return model.fit(features, outcomes)

else:
return self._train_nested_cv_models(features, outcomes)

def _train_nested_cv_models(self, features, outcomes):

"""Train models in a nested CV fashion.
Parameters
-----------
features : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
Returns
-----------
Trained survival regression models.
"""

models = {}
for fi, fold in enumerate(set(self.folds)):
x_tr = features.copy().loc[self.folds!=fold]
y_tr = outcomes.loc[self.folds!=fold]

self.nested_folds = self._get_stratified_folds(y_tr, 'event',
self.num_nested_folds,
self.random_seed)

best_params = self._cv_select_parameters(x_tr, y_tr, self.nested_folds)
model = SurvivalModel(self.model, self.random_seed, **best_params)
models[fi] = model.fit(x_tr, y_tr)

return models

def _cv_select_parameters(self, features, outcomes, folds):

"""Evaluate model performance on validation set in a CV fashion and
select hyperparameters.
Parameters
-----------
features : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively.
folds : list, default=None
A list of fold assignment values for each sample.
Returns
-----------
Dictionary of the best model hyperparameters based on the user-specified
metric.
"""

unique_times = np.unique(outcomes.time.values)
if self.horizon is not None:
unique_times = unique_times[unique_times<self.horizon]
unique_times = list(unique_times)+[self.horizon]
unique_times = np.array(sorted(unique_times))
self.times = self._check_times(outcomes, unique_times, folds)

if (self.horizon is not None) & (self.horizon not in self.times):
warnings.warn("Specified horizon is not permissible based on \
training and validation set times. %0.2f is used as the closest permissible \
event-horizon" %(max(self.times)))

fold_results = pd.DataFrame()
for fold in set(folds):
x_tr = features.copy().loc[folds!=fold]
x_val = features.copy().loc[folds==fold]
y_tr = outcomes.loc[folds!=fold]
y_val = outcomes.loc[folds==fold]

param_results = self._fit_evaluate_model(x_tr, y_tr, x_val, y_val)

# Hyperparameter results as row items and fold results as columns.
fold_results = pd.concat([fold_results, pd.DataFrame(param_results)],
axis=1)

agg_params_results = fold_results.mean(axis=1)
if (self.metric == 'ibs') | (self.metric == 'brs'):
best_params_idx = agg_params_results.reset_index(drop=True).idxmin()
elif (self.metric == 'auc') | (self.metric == 'ctd'):
best_params_idx = agg_params_results.reset_index(drop=True).idxmax()

return self.hyperparam_grid[best_params_idx]

def _fit_evaluate_model(self, features_tr, outcomes_tr,
features_val, outcomes_val):

"""Train the model and evaluate model performance on validation set.
assert horizons is not None, "Horizons must be specified."
if isinstance(horizons, (int, float)):
horizons = [horizons]

Parameters
-----------
features_tr : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates for the training set data.
outcomes_tr : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively
for the training set data.
features_val : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates for the validation set data.
outcomes_val : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the
survival time and censoring status \( \delta_i = 1 \), respectively
for the validation set data.
Returns
-----------
Model performance results in terms of the user-specific metric for each
hyperparameter combination.
"""

# Cannot compute metrics for evaluation set samples with time > follow-up time
max_follow_up = outcomes_tr.time.max()
val_sample_idx = outcomes_val.time.values < max_follow_up
outcomes_val = outcomes_val.loc[val_sample_idx]

if self.metric == 'ibs':
times = self.times
else:
times = [self.times[-1]]

param_results = []
for hyper_param in tqdm(self.hyperparam_grid):
model = SurvivalModel(self.model, self.random_seed, **hyper_param)
model.fit(features_tr, outcomes_tr)

predictions_val = model.predict_survival(features_val, times)
predictions_val = predictions_val[val_sample_idx]
self.metric = metric
self.horizons = horizons

metric_val = survival_regression_metric(metric=self.metric,
outcomes=outcomes_val,
predictions=predictions_val,
outcomes_train=outcomes_tr,
times=times)
param_results.append(metric_val)
if self.folds is None:
self.folds = self._get_stratified_folds(outcomes,
'event',
self.num_folds,
self.random_seed)
# Set the time horizon boundaries to be within all folds.
time_max, time_min = outcomes.time.max(), outcomes.time.min()
for fold in set(self.folds):
fold_time_max = outcomes.loc[self.folds==fold].time.max()
fold_time_min = outcomes.loc[self.folds==fold].time.min()

if fold_time_max < time_max: time_max = fold_time_max
if fold_time_min > time_min: time_min = fold_time_min

assert max(horizons) < time_max, "Horizons exceeds max time range."
assert min(horizons) > time_min, "Horizons exceeds min time range."

# if self.horizon is None:
# assert (self.metric == 'ibs'), "Horizon must be specified for the selected metric"
# self.horizon = time_max

hyper_param_scores = []
for i, hyper_param in enumerate(self.hyperparam_grid):
print("At hyper-param", hyper_param)

fold_scores = []
for fold in set(self.folds):
print("At fold:", fold)
model = SurvivalModel(self.model, random_seed=self.random_seed, **hyper_param)
model.fit(features.loc[self.folds!=fold], outcomes.loc[self.folds!=fold])
predictions = model.predict_survival(features.loc[self.folds==fold], times=horizons)

score = survival_regression_metric(metric=self.metric,
outcomes=outcomes.loc[self.folds==fold],
predictions=predictions,
times=horizons,
outcomes_train=outcomes.loc[self.folds!=fold])
fold_scores.append(score)
hyper_param_scores.append(np.mean(fold_scores))

if self.metric in ['ibs', 'brs']:
best_hyper_param = self.hyperparam_grid[np.argmin(hyper_param_scores)]
elif self.metric in ['auc', 'ctd']:
best_hyper_param = self.hyperparam_grid[np.argmax(hyper_param_scores)]

model = SurvivalModel(self.model,
random_seed=self.random_seed,
**best_hyper_param).fit(features, outcomes)
return model

return param_results

def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed):

Expand Down
2 changes: 1 addition & 1 deletion auton_survival/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def survival_regression_metric(metric, outcomes, predictions,

assert max(times) < outcomes_train.time.max(), "Times should \
be within the range of event times to avoid exterpolation."
assert max(times) < outcomes.time.max(), "Times \
assert max(times) <= outcomes.time.max(), "Times \
must be within the range of event times."

survival_train = util.Surv.from_dataframe('event', 'time', outcomes_train)
Expand Down

0 comments on commit 304d615

Please sign in to comment.