diff --git a/auton_survival/models/cmhe/__init__.py b/auton_survival/models/cmhe/__init__.py new file mode 100644 index 0000000..c31434f --- /dev/null +++ b/auton_survival/models/cmhe/__init__.py @@ -0,0 +1 @@ +from .cmhe_api import DeepCoxMixturesHE \ No newline at end of file diff --git a/auton_survival/models/cmhe/cmhe_api.py b/auton_survival/models/cmhe/cmhe_api.py index f7400ee..0d943f4 100644 --- a/auton_survival/models/cmhe/cmhe_api.py +++ b/auton_survival/models/cmhe/cmhe_api.py @@ -1,11 +1,12 @@ -from .cmhe_torch import CoxMixtureHeterogenousEffects -from .cmhe_torch import DeepCoxMixtureHeterogenousEffects +from .cmhe_torch import CoxMixtureHETorch +from .cmhe_torch import DeepCoxMixtureHETorch from .cmhe_utilities import train_cmhe, predict_survival +from .cmhe_utilities import predict_latent_phi, predict_latent_z import torch import numpy as np -class DeepCoxMixturesHeterogenousEffects: +class DeepCoxMixturesHE: """A Deep Cox Mixtures with Heterogenous Effects model. This is the main interface to a Deep Cox Mixture with Heterogenous Effects. @@ -17,8 +18,8 @@ class DeepCoxMixturesHeterogenousEffects: References ---------- - [1] <a href="https://arxiv.org/abs/2101.06536">Deep Cox Mixtures - for Survival Regression. Machine Learning in Health Conference (2021)</a> + [1] Nagpal, C., Goswami, M., Dufendach, K., and Dubrawski, A. + "Counterfactual Phenotyping with Censored Time-to-Events" (2022). Parameters ---------- k: int The number of underlying Cox distributions. layers: list A list of integers consisting of the number of neurons in each hidden layer. 
+ Example ------- >>> from auton_survival import CoxMixturesHeterogenousEffects @@ -37,11 +39,13 @@ class DeepCoxMixturesHeterogenousEffects: """ - def __init__(self, layers=None): + def __init__(self, k, g, layers=None): self.layers = layers self.fitted = False - + self.k = k + self.g = g + def __call__(self): if self.fitted: print("A fitted instance of the CMHE model") @@ -50,8 +54,11 @@ def __call__(self): print("Hidden Layers:", self.layers) - def _preprocess_test_data(self, x, a): - return torch.from_numpy(x).float(), torch.from_numpy(a).float() + def _preprocess_test_data(self, x, a=None): + if a is not None: + return torch.from_numpy(x).float(), torch.from_numpy(a).float() + else: + return torch.from_numpy(x).float() def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state): @@ -60,42 +67,44 @@ def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state): np.random.seed(random_state) np.random.shuffle(idx) - x_train, t_train, e_train, a_train = x[idx], t[idx], e[idx], a[idx] + x_tr, t_tr, e_tr, a_tr = x[idx], t[idx], e[idx], a[idx] - x_train = torch.from_numpy(x_train).float() - t_train = torch.from_numpy(t_train).float() - e_train = torch.from_numpy(e_train).float() - a_train = torch.from_numpy(a_train).float() + x_tr = torch.from_numpy(x_tr).float() + t_tr = torch.from_numpy(t_tr).float() + e_tr = torch.from_numpy(e_tr).float() + a_tr = torch.from_numpy(a_tr).float() if val_data is None: - vsize = int(vsize*x_train.shape[0]) - x_val, t_val, e_val, a_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:], a_train[-vsize:] + vsize = int(vsize*x_tr.shape[0]) + x_vl, t_vl, e_vl, a_vl = x_tr[-vsize:], t_tr[-vsize:], e_tr[-vsize:], a_tr[-vsize:] - x_train = x_train[:-vsize] - t_train = t_train[:-vsize] - e_train = e_train[:-vsize] - a_train = a_train[:-vsize] + x_tr = x_tr[:-vsize] + t_tr = t_tr[:-vsize] + e_tr = e_tr[:-vsize] + a_tr = a_tr[:-vsize] else: - x_val, t_val, e_val, a_val = val_data + x_vl, t_vl, e_vl, a_vl 
= val_data - x_val = torch.from_numpy(x_val).float() - t_val = torch.from_numpy(t_val).float() - e_val = torch.from_numpy(e_val).float() - a_val = torch.from_numpy(a_val).float() + x_vl = torch.from_numpy(x_vl).float() + t_vl = torch.from_numpy(t_vl).float() + e_vl = torch.from_numpy(e_vl).float() + a_vl = torch.from_numpy(a_vl).float() - return (x_train, t_train, e_train, a_train, - x_val, t_val, e_val, a_val) + return (x_tr, t_tr, e_tr, a_tr, + x_vl, t_vl, e_vl, a_vl) def _gen_torch_model(self, inputdim, optimizer): """Helper function to return a torch model.""" if len(self.layers): - return DeepCoxMixtureHeterogenousEffects(inputdim, layers=self.layers, - optimizer=optimizer) + return DeepCoxMixtureHETorch(self.k, self.g, + inputdim, + hidden=self.layers[0]) else: - return CoxMixtureHeterogenousEffects(inputdim, optimizer=optimizer) + return CoxMixtureHETorch(self.k, self.g, + inputdim) def fit(self, x, t, e, a, vsize=0.15, val_data=None, iters=1, learning_rate=1e-3, batch_size=100, @@ -137,17 +146,17 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None, vsize, val_data, random_state) - x_train, t_train, e_train, a_train, x_val, t_val, e_val, a_val = processed_data + x_tr, t_tr, e_tr, a_tr, x_vl, t_vl, e_vl, a_vl = processed_data #Todo: Change this somehow. 
The base design shouldn't depend on child - inputdim = x_train.shape[-1] + inputdim = x_tr.shape[-1] model = self._gen_torch_model(inputdim, optimizer) model, _ = train_cmhe(model, - (x_train, t_train, e_train, a_train), - (x_val, t_val, e_val, a_val), + (x_tr, t_tr, e_tr, a_tr), + (x_vl, t_vl, e_vl, a_vl), epochs=iters, lr=learning_rate, bs=batch_size, @@ -198,3 +207,30 @@ def predict_survival(self, x, a, t=None): scores = predict_survival(self.torch_model, x, a, t) return scores + def predict_latent_z(self, x): + + """Returns the estimated latent base survival group \( z \) given the confounders \( x \).""" + + x = self._preprocess_test_data(x) + + if self.fitted: + scores = predict_latent_z(self.torch_model, x) + return scores + else: + raise Exception("The model has not been fitted yet. Please fit the " + + "model using the `fit` method on some training data " + + "before calling `predict_latent_z`.") + + def predict_latent_phi(self, x): + + """Returns the estimated latent treatment effect group \( \phi \) given the confounders \( x \).""" + + x = self._preprocess_test_data(x) + + if self.fitted: + scores = predict_latent_phi(self.torch_model, x) + return scores + else: + raise Exception("The model has not been fitted yet. Please fit the " + + "model using the `fit` method on some training data " + + "before calling `predict_latent_phi`.") \ No newline at end of file diff --git a/auton_survival/models/cmhe/cmhe_torch.py b/auton_survival/models/cmhe/cmhe_torch.py index 891f24b..3dcc725 100644 --- a/auton_survival/models/cmhe/cmhe_torch.py +++ b/auton_survival/models/cmhe/cmhe_torch.py @@ -1,6 +1,6 @@ import torch -class CoxMixtureHeterogenousEffects(torch.nn.Module): +class CoxMixtureHETorch(torch.nn.Module): """PyTorch model definition of the Cox Mixture with Hereogenous Effects Model. 
Cox Mixtures with Heterogenous Effects involves the assuming that the @@ -14,7 +14,7 @@ class CoxMixtureHeterogenousEffects(torch.nn.Module): def __init__(self, k, g, inputdim, hidden=None): - super(CoxMixtureHeterogenousEffects, self).__init__() + super(CoxMixtureHETorch, self).__init__() assert isinstance(k, int) @@ -57,13 +57,11 @@ def forward(self, x, a): return logp_jointlatent_gate, logp_joint_hrs -class DeepCoxMixtureHeterogenousEffects(CoxMixtureHeterogenousEffects): +class DeepCoxMixtureHETorch(CoxMixtureHETorch): def __init__(self, k, g, inputdim, hidden): - super(DeepCoxMixtureHeterogenousEffects, self).__init__(k, g, - inputdim, - hidden) + super(DeepCoxMixtureHETorch, self).__init__(k, g, inputdim, hidden) # Get rich feature representations of the covariates self.embedding = torch.nn.Sequential(torch.nn.Linear(inputdim, hidden), diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py index 687fa28..530b40c 100644 --- a/auton_survival/models/cmhe/cmhe_utilities.py +++ b/auton_survival/models/cmhe/cmhe_utilities.py @@ -315,3 +315,21 @@ def predict_survival(model, x, a, t): predictions.append((gates*expert_outputs).sum(axis=1)) return np.array(predictions).T + +def predict_latent_z(model, x): + + model, _ = model + gates = model.model.embedding(x) + + z_gate_probs = torch.exp(gates).sum(axis=2).detach().numpy() + + return z_gate_probs + +def predict_latent_phi(model, x): + + model, _ = model + gates, _ = model(x) + + phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy() + + return phi_gate_probs