Skip to content

Commit

Permalink
new file: __init__.py
Browse files Browse the repository at this point in the history
	modified:   cmhe_api.py
	modified:   cmhe_torch.py
	modified:   cmhe_utilities.py
  • Loading branch information
chiragnagpal committed Feb 13, 2022
1 parent caa4145 commit 3ce2edf
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 40 deletions.
1 change: 1 addition & 0 deletions auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .cmhe_api import DeepCoxMixturesHE
104 changes: 70 additions & 34 deletions auton_survival/models/cmhe/cmhe_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .cmhe_torch import CoxMixtureHeterogenousEffects
from .cmhe_torch import DeepCoxMixtureHeterogenousEffects
from .cmhe_torch import CoxMixtureHETorch
from .cmhe_torch import DeepCoxMixtureHETorch
from .cmhe_utilities import train_cmhe, predict_survival
from .cmhe_utilities import predict_latent_phi, predict_latent_z

import torch
import numpy as np

class DeepCoxMixturesHeterogenousEffects:
class DeepCoxMixturesHE:
"""A Deep Cox Mixtures with Heterogenous Effects model.
This is the main interface to a Deep Cox Mixture with Heterogenous Effects.
Expand All @@ -17,8 +18,8 @@ class DeepCoxMixturesHeterogenousEffects:
References
----------
[1] <a href="https://arxiv.org/abs/2101.06536">Deep Cox Mixtures
for Survival Regression. Machine Learning in Health Conference (2021)</a>
[1] Nagpal, C., Goswami M., Dufendach K., and Artur Dubrawski.
"Counterfactual phenotyping for censored Time-to-Events" (2022).
Parameters
----------
Expand All @@ -29,6 +30,7 @@ class DeepCoxMixturesHeterogenousEffects:
layers: list
A list of integers consisting of the number of neurons in each
hidden layer.
Example
-------
>>> from auton_survival.models.cmhe import DeepCoxMixturesHE
Expand All @@ -37,11 +39,13 @@ class DeepCoxMixturesHeterogenousEffects:
"""

def __init__(self, layers=None):
def __init__(self, k, g, layers=None):

self.layers = layers
self.fitted = False

self.k = k
self.g = g

def __call__(self):
if self.fitted:
print("A fitted instance of the CMHE model")
Expand All @@ -50,8 +54,11 @@ def __call__(self):

print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x, a):
return torch.from_numpy(x).float(), torch.from_numpy(a).float()
def _preprocess_test_data(self, x, a=None):
    """Cast numpy inference inputs to float32 torch tensors.

    Returns the tensor for ``x`` alone when no treatment indicator ``a``
    is supplied, otherwise the pair of tensors ``(x, a)``.
    """
    x_tensor = torch.from_numpy(x).float()

    if a is None:
        return x_tensor

    return x_tensor, torch.from_numpy(a).float()

def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state):

Expand All @@ -60,42 +67,44 @@ def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state):
np.random.seed(random_state)
np.random.shuffle(idx)

x_train, t_train, e_train, a_train = x[idx], t[idx], e[idx], a[idx]
x_tr, t_tr, e_tr, a_tr = x[idx], t[idx], e[idx], a[idx]

x_train = torch.from_numpy(x_train).float()
t_train = torch.from_numpy(t_train).float()
e_train = torch.from_numpy(e_train).float()
a_train = torch.from_numpy(a_train).float()
x_tr = torch.from_numpy(x_tr).float()
t_tr = torch.from_numpy(t_tr).float()
e_tr = torch.from_numpy(e_tr).float()
a_tr = torch.from_numpy(a_tr).float()

if val_data is None:

vsize = int(vsize*x_train.shape[0])
x_val, t_val, e_val, a_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:], a_train[-vsize:]
vsize = int(vsize*x_tr.shape[0])
x_vl, t_vl, e_vl, a_vl = x_tr[-vsize:], t_tr[-vsize:], e_tr[-vsize:], a_tr[-vsize:]

x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]
a_train = a_train[:-vsize]
x_tr = x_tr[:-vsize]
t_tr = t_tr[:-vsize]
e_tr = e_tr[:-vsize]
a_tr = a_tr[:-vsize]

else:

x_val, t_val, e_val, a_val = val_data
x_vl, t_vl, e_vl, a_vl = val_data

x_val = torch.from_numpy(x_val).float()
t_val = torch.from_numpy(t_val).float()
e_val = torch.from_numpy(e_val).float()
a_val = torch.from_numpy(a_val).float()
x_vl = torch.from_numpy(x_vl).float()
t_vl = torch.from_numpy(t_vl).float()
e_vl = torch.from_numpy(e_vl).float()
a_vl = torch.from_numpy(a_vl).float()

return (x_train, t_train, e_train, a_train,
x_val, t_val, e_val, a_val)
return (x_tr, t_tr, e_tr, a_tr,
x_vl, t_vl, e_vl, a_vl)

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
if len(self.layers):
return DeepCoxMixtureHeterogenousEffects(inputdim, layers=self.layers,
optimizer=optimizer)
return DeepCoxMixtureHETorch(inputdim, self.k, self.g,
layers=self.layers,
optimizer=optimizer)
else:
return CoxMixtureHeterogenousEffects(inputdim, optimizer=optimizer)
return CoxMixtureHETorch(inputdim, self.k, self.g,
optimizer=optimizer)

def fit(self, x, t, e, a, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
Expand Down Expand Up @@ -137,17 +146,17 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None,
vsize, val_data,
random_state)

x_train, t_train, e_train, a_train, x_val, t_val, e_val, a_val = processed_data
x_tr, t_tr, e_tr, a_tr, x_vl, t_vl, e_vl, a_vl = processed_data

#Todo: Change this somehow. The base design shouldn't depend on child

inputdim = x_train.shape[-1]
inputdim = x_tr.shape[-1]

model = self._gen_torch_model(inputdim, optimizer)

model, _ = train_cmhe(model,
(x_train, t_train, e_train, a_train),
(x_val, t_val, e_val, a_val),
(x_tr, t_tr, e_tr, a_tr),
(x_vl, t_vl, e_vl, a_vl),
epochs=iters,
lr=learning_rate,
bs=batch_size,
Expand Down Expand Up @@ -198,3 +207,30 @@ def predict_survival(self, x, a, t=None):
scores = predict_survival(self.torch_model, x, a, t)
return scores

def predict_latent_z(self, x):
    r"""Return the estimated latent base survival group \( z \) given the confounders \( x \).

    Raises an Exception when called before `fit`.
    """
    x = self._preprocess_test_data(x)

    # Guard clause: refuse to predict with an unfitted model.
    if not self.fitted:
        raise Exception("The model has not been fitted yet. Please fit the " +
                        "model using the `fit` method on some training data " +
                        "before calling `predict_latent_z`.")

    return predict_latent_z(self.torch_model, x)

def predict_latent_phi(self, x):
    r"""Return the estimated latent treatment effect group \( \phi \) given the confounders \( x \).

    Raises an Exception when called before `fit`.
    """
    x = self._preprocess_test_data(x)

    # Guard clause: refuse to predict with an unfitted model.
    if not self.fitted:
        raise Exception("The model has not been fitted yet. Please fit the " +
                        "model using the `fit` method on some training data " +
                        "before calling `predict_latent_phi`.")

    return predict_latent_phi(self.torch_model, x)
10 changes: 4 additions & 6 deletions auton_survival/models/cmhe/cmhe_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

class CoxMixtureHeterogenousEffects(torch.nn.Module):
class CoxMixtureHETorch(torch.nn.Module):
"""PyTorch model definition of the Cox Mixture with Heterogenous Effects Model.
Cox Mixtures with Heterogenous Effects involves the assuming that the
Expand All @@ -14,7 +14,7 @@ class CoxMixtureHeterogenousEffects(torch.nn.Module):

def __init__(self, k, g, inputdim, hidden=None):

super(CoxMixtureHeterogenousEffects, self).__init__()
super(CoxMixtureHETorch, self).__init__()

assert isinstance(k, int)

Expand Down Expand Up @@ -57,13 +57,11 @@ def forward(self, x, a):

return logp_jointlatent_gate, logp_joint_hrs

class DeepCoxMixtureHeterogenousEffects(CoxMixtureHeterogenousEffects):
class DeepCoxMixtureHETorch(CoxMixtureHETorch):

def __init__(self, k, g, inputdim, hidden):

super(DeepCoxMixtureHeterogenousEffects, self).__init__(k, g,
inputdim,
hidden)
super(DeepCoxMixtureHETorch, self).__init__(k, g, inputdim, hidden)

# Get rich feature representations of the covariates
self.embedding = torch.nn.Sequential(torch.nn.Linear(inputdim, hidden),
Expand Down
18 changes: 18 additions & 0 deletions auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,21 @@ def predict_survival(model, x, a, t):

predictions.append((gates*expert_outputs).sum(axis=1))
return np.array(predictions).T

def predict_latent_z(model, x):
    """Return per-sample scores over the latent base-survival groups z.

    Parameters
    ----------
    model: tuple
        A (torch model, extra) pair as produced by training; only the
        torch model is used here.
    x: torch.Tensor
        Preprocessed confounder tensor.
    """
    torch_model, _ = model

    # NOTE(review): reads the gate activations off the wrapped embedding
    # network; assumes its output is 3-D so that summing exp(.) over axis 2
    # yields one score per z group -- confirm intended shape upstream.
    gate_logits = torch_model.model.embedding(x)

    return torch.exp(gate_logits).sum(axis=2).detach().numpy()

def predict_latent_phi(model, x):
    """Return per-sample scores over the latent treatment-effect groups phi.

    Parameters
    ----------
    model: tuple
        A (torch model, extra) pair as produced by training; only the
        torch model is used here.
    x: torch.Tensor
        Preprocessed confounder tensor.
    """
    torch_model, _ = model

    # The forward pass returns (gate log-probabilities, hazard terms);
    # only the gates matter for group membership.
    gate_logits, _ = torch_model(x)

    return torch.exp(gate_logits).sum(axis=1).detach().numpy()

0 comments on commit 3ce2edf

Please sign in to comment.