Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
PotosnakW committed Apr 18, 2022
1 parent 0d199c5 commit ff77c7f
Show file tree
Hide file tree
Showing 7 changed files with 1,015 additions and 523 deletions.
109 changes: 2 additions & 107 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,62 +118,6 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):

return model

# if len(layers): model = DeepCoxMixture(k=k, inputdim=features.shape[1], hidden=layers[0])
# else: model = CoxMixture(k=k, inputdim=features.shape[1])

# x = torch.from_numpy(features.values.astype('float32'))
# t = torch.from_numpy(outcomes['time'].values.astype('float32'))
# e = torch.from_numpy(outcomes['event'].values.astype('float32'))

# vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)

# train_data = (x[~vidx], t[~vidx], e[~vidx])
# val_data = (x[vidx], t[vidx], e[vidx])

# (model, breslow_splines, unique_times) = train(model,
# train_data,
# val_data,
# epochs=epochs,
# lr=lr, bs=bs,
# use_posteriors=True,
# patience=5,
# return_losses=False,
# smoothing_factor=smoothing_factor)

#return (model, breslow_splines, unique_times)

# NOTE: This is one of two _predict_dcm implementations in this file. This one raises an error, so the _predict_dcm function below is used instead.
# def _predict_dcm(model, features, times):

# """Predict survival probabilities at specified time(s) using the
# Deep Cox Mixtures model.

# Parameters
# -----------
# model : Trained instance of the Deep Cox Mixtures model.
# features : pd.DataFrame
# A pandas dataframe with rows corresponding to individual
# samples and columns as covariates.
# times: float or list
# A float or list of the times at which to compute
# the survival probability.

# Returns
# -----------
# np.array : An array of the survival probabilities at each
# time point in times.

# """

# #raise NotImplementedError()

# survival_predictions = model.predict_survival(features, times)
# if len(times)>1:
# survival_predictions = pd.DataFrame(survival_predictions, columns=times).T
# return __interpolate_missing_times(survival_predictions, times)
# else:
# return survival_predictions

def _fit_dcph(features, outcomes, random_seed, **hyperparams):

"""Fit a Deep Cox Proportional Hazards Model/Faraggi-Simon Network [1,2]
Expand Down Expand Up @@ -228,55 +172,6 @@ def _fit_dcph(features, outcomes, random_seed, **hyperparams):

return model

#raise NotImplementedError()
# import torch
# import torchtuples as ttup

# from pycox.models import CoxPH

# torch.manual_seed(random_seed)
# np.random.seed(random_seed)

# layers = hyperparams.get('layers', [100])
# lr = hyperparams.get('lr', 1e-3)
# bs = hyperparams.get('bs', 100)
# epochs = hyperparams.get('epochs', 50)
# activation = hyperparams.get('activation', 'relu')

# if activation == 'relu': activation = torch.nn.ReLU
# elif activation == 'relu6': activation = torch.nn.ReLU6
# elif activation == 'tanh': activation = torch.nn.Tanh
# else: raise NotImplementedError("Activation function not implemented")

# x = features.values.astype('float32')
# t = outcomes['time'].values.astype('float32')
# e = outcomes['event'].values.astype('bool')

# in_features = x.shape[1]
# out_features = 1
# batch_norm = False
# dropout = 0.0

# net = ttup.practical.MLPVanilla(in_features, layers,
# out_features, batch_norm, dropout,
# activation=activation,
# output_bias=False)

# model = CoxPH(net, torch.optim.Adam)

# vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)

# y_train, y_val = (t[~vidx], e[~vidx]), (t[vidx], e[vidx])
# val_data = x[vidx], y_val

# callbacks = [ttup.callbacks.EarlyStopping()]
# model.fit(x[~vidx], y_train, bs, epochs, callbacks, True,
# val_data=val_data,
# val_batch_size=bs)
# model.compute_baseline_hazards()

# return model

def __interpolate_missing_times(survival_predictions, times):
"""Interpolate survival probabilities at missing time points.
Expand Down Expand Up @@ -771,14 +666,14 @@ def __init__(self, treated_model, control_model):

def predict_counterfactual_survival(self, features, times):
    """Predict survival under both the treated and the control model.

    Parameters
    -----------
    features :
        Covariates, forwarded unchanged to each model's
        ``predict_survival`` method.
    times :
        Time point(s), forwarded unchanged to each model's
        ``predict_survival`` method.

    Returns
    -----------
    tuple : ``(treated_outcomes, control_outcomes)``, i.e. the value
        returned by ``self.treated_model.predict_survival`` followed by
        the value returned by ``self.control_model.predict_survival``.

    """

    # Query each model exactly once; a second identical call to the
    # control model would only repeat the same work.
    treated_outcomes = self.treated_model.predict_survival(features, times)
    control_outcomes = self.control_model.predict_survival(features, times)

    return treated_outcomes, control_outcomes

def predict_counterfactual_risk(self, features, times):
    """Predict risk under both the treated and the control model.

    Parameters
    -----------
    features :
        Covariates, forwarded unchanged to each model's
        ``predict_risk`` method.
    times :
        Time point(s), forwarded unchanged to each model's
        ``predict_risk`` method.

    Returns
    -----------
    tuple : ``(treated_outcomes, control_outcomes)``, i.e. the value
        returned by ``self.treated_model.predict_risk`` followed by
        the value returned by ``self.control_model.predict_risk``.

    """

    # Query each model exactly once; a second identical call to the
    # control model would only repeat the same work.
    treated_outcomes = self.treated_model.predict_risk(features, times)
    control_outcomes = self.control_model.predict_risk(features, times)

    return treated_outcomes, control_outcomes
80 changes: 59 additions & 21 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
# coding=utf-8
# MIT License

# Copyright (c) 2022 Carnegie Mellon University, Auton Lab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Utilities to perform cross-validation."""

from copy import deepcopy
import numpy as np

from auton_survival.estimators import SurvivalModel, CounterfactualSurvivalModel
Expand Down Expand Up @@ -94,17 +120,23 @@ def fit(self, features, outcomes, ret_trained_model=True):

self.folds = folds

unique_times = np.unique(outcomes['time'].values)

time_min, time_max = unique_times.min(), unique_times.max()

unique_times = np.unique(outcomes.time.values)
time_max, time_min = unique_times.max(), unique_times.min()

for fold in range(self.cv_folds):

fold_outcomes = outcomes.loc[folds==fold, 'time']

if fold_outcomes.min() > time_min: time_min = fold_outcomes.min()
if fold_outcomes.max() < time_max: time_max = fold_outcomes.max()

time_test = outcomes.loc[folds==fold, 'time']
time_train = outcomes.loc[folds!=fold, 'time']

if time_test.min() > time_min:
time_min = time_test.min()

if (time_test.max() < time_max)|(time_train.max() < time_max):
if time_test.max() > time_train.max():
time_max = max(time_test[time_test < time_train.max()])
else:
time_max = max(time_test[time_test < time_test.max()])

unique_times = unique_times[unique_times>=time_min]
unique_times = unique_times[unique_times<time_max]

Expand All @@ -119,24 +151,30 @@ def fit(self, features, outcomes, ret_trained_model=True):

fold_models = {}
for fold in tqdm(range(self.cv_folds)):

# Fit the model
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
fold_models[fold] = fold_model

# Predict survival probabilities for the held-out fold
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold],
times=unique_times)
times=unique_times.tolist())

score_per_fold = []
for fold in range(self.cv_folds):
for fold in range(self.cv_folds):
outcomes_train = outcomes.loc[folds!=fold]
outcomes_test = outcomes.loc[folds==fold]
predictions_test = predictions[folds==fold]
outcomes_test = outcomes.loc[folds==fold].copy()
predictions_test = deepcopy(predictions[folds==fold])

# Cannot compute IBS for test set samples with time >= the maximum follow-up time in the training set
max_follow_up = outcomes_train.time.max()
predictions_test = predictions_test[outcomes_test.time.values < max_follow_up]
outcomes_test = outcomes_test.loc[outcomes_test.time.values < max_follow_up]

# Compute IBS
score = survival_regression_metric('ibs', outcomes_train, outcomes_test,
predictions_test, unique_times)
score = survival_regression_metric('ibs', outcomes_train, predictions_test,
unique_times, outcomes_test)
score_per_fold.append(score)

current_score = np.mean(score_per_fold)
Expand Down Expand Up @@ -235,6 +273,8 @@ class CounterfactualSurvivalRegressionCV:
"""

_VALID_CF_METHODS = ['dsm', 'dcph', 'dcm', 'rsf', 'cph']

def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

self.model = model
Expand Down Expand Up @@ -277,11 +317,9 @@ def fit(self, features, outcomes, interventions):
"""


treated, control = interventions==1, interventions!=1
treated_model = self.treated_experiment.fit(features.loc[treated],
outcomes.loc[treated])
control_model = self.control_experiment.fit(features.loc[control],
outcomes.loc[control])
treated_model = self.treated_experiment.fit(features.loc[interventions==1],
outcomes.loc[interventions==1])
control_model = self.control_experiment.fit(features.loc[interventions!=1],
outcomes.loc[interventions!=1])

return CounterfactualSurvivalModel(treated_model, control_model)
Loading

0 comments on commit ff77c7f

Please sign in to comment.