Commit

modified: experiments.py
	modified:   models/cmhe/__init__.py
	modified:   models/cmhe/cmhe_torch.py
	modified:   models/dcm/__init__.py
	modified:   phenotyping.py
	modified:   reporting.py
	new file:   ../examples/.ipynb_checkpoints/Demo of CMHE on Synthetic Data-checkpoint.ipynb
	modified:   ../examples/Demo of CMHE on Synthetic Data.ipynb
	new file:   ../examples/Phenotyping Censored Time-to-Events.ipynb
	new file:   ../examples/Survival Regression with Auton-Survival.ipynb
	new file:   ../examples/__pycache__/cmhe_demo_utils.cpython-38.pyc
	new file:   ../examples/__pycache__/estimators_demo_utils.cpython-38.pyc
	modified:   ../examples/cmhe_demo_utils.py
	new file:   ../examples/scratch.ipynb
chiragnagpal committed Apr 2, 2022
1 parent e347556 commit 5087c89
Showing 16 changed files with 3,506 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -23,7 +23,7 @@ persistent=no
load-plugins=

# Use multiple processes to speed up Pylint.
jobs=4
jobs=16

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
4 changes: 4 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,4 @@
{
"python.linting.pylintEnabled": true,
"python.linting.enabled": true
}
7 changes: 4 additions & 3 deletions auton_survival/experiments.py
@@ -63,7 +63,8 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

def fit(self, features, outcomes, ret_trained_model=True):

r"""Fits the Survival Regression Model to the data in a Cross Validation fashion.
r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
Parameters
-----------
@@ -76,8 +77,8 @@ def fit(self, features, outcomes, ret_trained_model=True):
a column named 'event' that contains the censoring status.
\( \delta_i = 1 \) if the event is observed.
ret_trained_model : bool
If True, the trained model is returned. If False, the fit function returns
self.
If True, the trained model is returned. If False, the fit function
returns self.
Returns
-----------
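For reference, a minimal usage sketch of the cross-validated fit documented above. The class name SurvivalRegressionCV, the 'cph' model identifier, the 'l2' hyperparameter key, and the synthetic data are assumptions, not taken from this diff; the constructor and fit signatures follow the hunk headers above.

import numpy as np
import pandas as pd
from auton_survival.experiments import SurvivalRegressionCV  # assumed class name

n = 500
features = pd.DataFrame({'age': np.random.randint(40, 80, n),
                         'bp': np.random.randint(100, 160, n)})
outcomes = pd.DataFrame({'time': np.random.exponential(10, n),
                         'event': np.random.binomial(1, 0.7, n)})  # 'event' = 1 if the event is observed

experiment = SurvivalRegressionCV(model='cph', cv_folds=5, random_seed=0,
                                  hyperparam_grid={'l2': [1e-3, 1e-4]})  # assumed grid keys
model = experiment.fit(features, outcomes, ret_trained_model=True)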
21 changes: 17 additions & 4 deletions auton_survival/models/cmhe/__init__.py
@@ -116,12 +116,17 @@ class DeepCoxMixturesHeterogenousEffects:
"""

def __init__(self, k, g, layers=None):
def __init__(self, k, g, layers=None, gamma=100,
smoothing_factor=1e-4,
random_seed=0):

self.layers = layers
self.fitted = False
self.k = k
self.g = g
self.layers = layers
self.fitted = False
self.gamma = gamma
self.smoothing_factor = smoothing_factor
self.random_seed = random_seed

def __call__(self):
if self.fitted:
@@ -176,8 +181,14 @@ def _preprocess_training_data(self, x, t, e, a, vsize, val_data,

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""

np.random.seed(self.random_seed)
torch.manual_seed(self.random_seed)

return DeepCMHETorch(self.k, self.g, inputdim,
layers=self.layers,
gamma=self.gamma,
smoothing_factor=self.smoothing_factor,
optimizer=optimizer)

def fit(self, x, t, e, a, vsize=0.15, val_data=None,
@@ -235,7 +246,9 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None,
lr=learning_rate,
bs=batch_size,
patience=patience,
return_losses=True)
return_losses=True,
use_posteriors=True,
random_seed=self.random_seed)

self.torch_model = (model[0].eval(), model[1])
self.fitted = True
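A minimal sketch of the widened CMHE constructor and fit call. The gamma, smoothing_factor and random_seed keywords come from this diff; the synthetic arrays, k, g and layer sizes are illustrative assumptions.

import numpy as np
from auton_survival.models.cmhe import DeepCoxMixturesHeterogenousEffects

n, d = 1000, 8
x = np.random.randn(n, d)                  # covariates
t = np.random.exponential(10, size=n)      # event/censoring times
e = np.random.binomial(1, 0.7, size=n)     # event indicators
a = np.random.binomial(1, 0.5, size=n)     # treatment assignments

model = DeepCoxMixturesHeterogenousEffects(k=2, g=2, layers=[50, 50],
                                           gamma=100,              # clamp on the log hazard ratios
                                           smoothing_factor=1e-4,
                                           random_seed=0)          # seeds numpy and torch before model creation
model.fit(x, t, e, a, vsize=0.15)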
22 changes: 8 additions & 14 deletions auton_survival/models/cmhe/cmhe_torch.py
@@ -43,7 +43,8 @@ def _init_dcmhe_layers(self, lastdim):
self.phi_gate = torch.nn.Linear(lastdim, self.g, bias=False)
self.omega = torch.nn.Parameter(torch.rand(self.g)-0.5)

def __init__(self, k, g, inputdim, layers=None, optimizer='Adam'):
def __init__(self, k, g, inputdim, layers=None, gamma=100,
smoothing_factor=1e-4, optimizer='Adam'):

super(DeepCMHETorch, self).__init__()

@@ -56,6 +57,9 @@ def __init__(self, k, g, inputdim, layers=None, optimizer='Adam'):
self.k = k # Base Physiology groups
self.g = g # Treatment Effect groups

self.gamma = gamma
self.smoothing_factor = smoothing_factor

if len(layers) == 0: lastdim = inputdim
else: lastdim = layers[-1]

@@ -69,7 +73,9 @@ def forward(self, x, a):
x = self.embedding(x)
a = 2*(a-0.5)

log_hrs = torch.clamp(self.expert(x), min=-100, max=100)
log_hrs = torch.clamp(self.expert(x),
min=-self.gamma,
max=self.gamma)

logp_z_gate = torch.nn.LogSoftmax(dim=1)(self.z_gate(x)) #
logp_phi_gate = torch.nn.LogSoftmax(dim=1)(self.phi_gate(x))
@@ -87,15 +93,3 @@ def forward(self, x, a):
logp_joint_hrs[:, i, j] = log_hrs[:, i] + (j!=2)*a*self.omega[j]

return logp_jointlatent_gate, logp_joint_hrs

# class DeepCoxMixtureHETorch(CoxMixtureHETorch):

# def __init__(self, k, g, inputdim, hidden):

# super(DeepCoxMixtureHETorch, self).__init__(k, g, inputdim, hidden)

# # Get rich feature representations of the covariates
# self.embedding = torch.nn.Sequential(torch.nn.Linear(inputdim, hidden),
# torch.nn.Tanh(),
# torch.nn.Linear(hidden, hidden),
# torch.nn.Tanh())
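A self-contained illustration of the clamping and treatment-interaction terms in forward() above; the batch size, k, g and the random tensors are assumptions for illustration only.

import torch

gamma, k, g = 100, 3, 3
log_hrs = torch.clamp(torch.randn(16, k) * 1e4, min=-gamma, max=gamma)  # bounded log hazard ratios

a = torch.randint(0, 2, (16,)).float()
a = 2 * (a - 0.5)                        # recode treatment from {0, 1} to {-1, +1}, as in forward()

omega = torch.rand(g) - 0.5
logp_joint_hrs = torch.zeros(16, k, g)
for i in range(k):
    for j in range(g):
        # the (j != 2) factor removes the treatment term for the last effect group, mirroring the diff
        logp_joint_hrs[:, i, j] = log_hrs[:, i] + (j != 2) * a * omega[j]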
6 changes: 6 additions & 0 deletions auton_survival/models/dcm/__init__.py
@@ -79,6 +79,12 @@ class DeepCoxMixtures:
layers: list
A list of integers consisting of the number of neurons in each
hidden layer.
random_seed: int
Controls the reproducibility of called functions.
Example
-------
>>> from auton_survival.models.dcm import DeepCoxMixtures
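A short hedged sketch of the documented random_seed argument; the value of k, the layer sizes and the commented fit call are assumptions based on the surrounding docstring rather than this diff.

from auton_survival.models.dcm import DeepCoxMixtures

model = DeepCoxMixtures(k=3, layers=[100, 100], random_seed=0)  # assumed parameter values
# model.fit(x, t, e) should then be reproducible across runs with the same seed.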
10 changes: 6 additions & 4 deletions auton_survival/phenotyping.py
@@ -320,8 +320,9 @@ def fit(self, features):
"""

if self.dim_red_method is not None:
print("Fitting the following Dimensionality Reduction Model:\n", self.dim_red_model)
if self.dim_red_method is not None:
print("Fitting the following Dimensionality Reduction Model:\n",
self.dim_red_model)
self.dim_red_model = self.dim_red_model.fit(features)
features = self.dim_red_model.transform(features)

@@ -457,8 +458,9 @@ def fit(self, features, outcomes, interventions, horizon):
self.cf_model = cf_model.fit(features, outcomes, interventions)

times = np.unique(outcomes.times.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features, interventions, times)

cf_predictions = self.cf_model.predict_counterfactual_survival(features,
times)

ite_estimates = cf_predictions[1] - cf_predictions[0]

if self.phenotyping_method == 'rsf':
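A schematic sketch of the individual treatment effect (ITE) step touched above; the array shapes and values are placeholders, not the actual counterfactual model output.

import numpy as np

times = np.array([1.0, 2.0, 5.0])
# counterfactual survival estimates under control (index 0) and treatment (index 1):
# one row per individual, one column per evaluation time
cf_predictions = [np.random.uniform(0.5, 1.0, size=(100, len(times))),
                  np.random.uniform(0.5, 1.0, size=(100, len(times)))]

ite_estimates = cf_predictions[1] - cf_predictions[0]   # as in the diff: treated minus control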
10 changes: 4 additions & 6 deletions auton_survival/reporting.py
@@ -3,9 +3,7 @@

from lifelines import KaplanMeierFitter, NelsonAalenFitter

from collections import Counter

from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts


@@ -75,10 +73,10 @@ def plot_nelsonaalen(outcomes, groups=None, **kwargs):
for group in sorted(set(groups)):
if pd.isna(group): continue

print("Group:", group)
print('Group:', group)

NelsonAalenFitter().fit(outcomes[groups==group]['time'],
outcomes[groups==group]['event']).plot(label=group,
NelsonAalenFitter().fit(outcomes[groups==group]['time'],
outcomes[groups==group]['event']).plot(label=group,
**kwargs)
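
Minimal usage sketch for the reflowed plot_nelsonaalen call; the outcomes DataFrame and the group labels are assumptions.

import pandas as pd
from auton_survival.reporting import plot_nelsonaalen

outcomes = pd.DataFrame({'time': [5.0, 3.2, 8.4, 1.7, 6.3],
                         'event': [1, 0, 1, 1, 0]})
groups = pd.Series(['A', 'A', 'B', 'B', 'A'])

plot_nelsonaalen(outcomes, groups=groups)   # one Nelson-Aalen cumulative hazard curve per group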

