Skip to content

Commit

Permalink
modified: cmhe_api.py
Browse files Browse the repository at this point in the history
	modified:   cmhe_utilities.py
  • Loading branch information
chiragnagpal committed Feb 13, 2022
1 parent 43d56ed commit caa4145
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 32 deletions.
78 changes: 47 additions & 31 deletions auton_survival/models/cmhe/cmhe_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .cmhe_torch import DeepCoxPHTorch
from .cmhe_utilities import train_cmhe, predict_scores
from .cmhe_torch import CoxMixtureHeterogenousEffects
from .cmhe_torch import DeepCoxMixtureHeterogenousEffects
from .cmhe_utilities import train_cmhe, predict_survival

import torch
import numpy as np

class CoxMixturesHeterogenousEffects:
"""A Cox Mixtures with Heterogenous Effects model.
class DeepCoxMixturesHeterogenousEffects:
"""A Deep Cox Mixtures with Heterogenous Effects model.
This is the main interface to a Deep Cox Mixture model.
This is the main interface to a Deep Cox Mixture with Heterogenous Effects.
A model is instantiated with approporiate set of hyperparameters and
fit on numpy arrays consisting of the features, event/censoring times
and the event/censoring indicators.
Expand All @@ -21,15 +23,17 @@ class CoxMixturesHeterogenousEffects:
Parameters
----------
k: int
The number of underlying Cox distributions.
The number of underlying base survival phenotypes.
g: int
The number of underlying treatment effect phenotypes.
layers: list
A list of integers consisting of the number of neurons in each
hidden layer.
Example
-------
>>> from dsm.contrib import DeepCoxMixtures
>>> model = DeepCoxMixtures()
>>> model.fit(x, t, e)
>>> from auton_survival import CoxMixturesHeterogenousEffects
>>> model = CoxMixturesHeterogenousEffects()
>>> model.fit(x, t, e, a)
"""

Expand All @@ -40,53 +44,60 @@ def __init__(self, layers=None):

def __call__(self):
if self.fitted:
print("A fitted instance of the Deep Cox PH model")
print("A fitted instance of the CMHE model")
else:
print("An unfitted instance of the Deep Cox PH model")
print("An unfitted instance of the CMHE model")

print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x):
return torch.from_numpy(x).float()
def _preprocess_test_data(self, x, a):
return torch.from_numpy(x).float(), torch.from_numpy(a).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
def _preprocess_training_data(self, x, t, e, a, vsize, val_data, random_state):

idx = list(range(x.shape[0]))

np.random.seed(random_state)
np.random.shuffle(idx)

x_train, t_train, e_train = x[idx], t[idx], e[idx]
x_train, t_train, e_train, a_train = x[idx], t[idx], e[idx], a[idx]

x_train = torch.from_numpy(x_train).float()
t_train = torch.from_numpy(t_train).float()
e_train = torch.from_numpy(e_train).float()
a_train = torch.from_numpy(a_train).float()

if val_data is None:

vsize = int(vsize*x_train.shape[0])
x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]
x_val, t_val, e_val, a_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:], a_train[-vsize:]

x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]
a_train = a_train[:-vsize]

else:

x_val, t_val, e_val = val_data
x_val, t_val, e_val, a_val = val_data

x_val = torch.from_numpy(x_val).float()
t_val = torch.from_numpy(t_val).float()
e_val = torch.from_numpy(e_val).float()
a_val = torch.from_numpy(a_val).float()

return (x_train, t_train, e_train, x_val, t_val, e_val)
return (x_train, t_train, e_train, a_train,
x_val, t_val, e_val, a_val)

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
return DeepCoxPHTorch(inputdim, layers=self.layers,
optimizer=optimizer)
if len(self.layers):
return DeepCoxMixtureHeterogenousEffects(inputdim, layers=self.layers,
optimizer=optimizer)
else:
return CoxMixtureHeterogenousEffects(inputdim, optimizer=optimizer)

def fit(self, x, t, e, vsize=0.15, val_data=None,
def fit(self, x, t, e, a, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
optimizer="Adam", random_state=100):

Expand All @@ -101,6 +112,9 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
e: np.ndarray
A numpy array of the event/censoring indicators, \( \delta \).
\( \delta = 1 \) means the event took place.
a: np.ndarray
A numpy array of the treatment assignment indicators, \( a \).
\( a = 1 \) means the individual was treated.
vsize: float
Amount of data to set aside as the validation set.
val_data: tuple
Expand All @@ -119,21 +133,21 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
random seed that determines how the validation set is chosen.
"""

processed_data = self._preprocess_training_data(x, t, e,
processed_data = self._preprocess_training_data(x, t, e, a,
vsize, val_data,
random_state)

x_train, t_train, e_train, x_val, t_val, e_val = processed_data
x_train, t_train, e_train, a_train, x_val, t_val, e_val, a_val = processed_data

#Todo: Change this somehow. The base design shouldn't depend on child

inputdim = x_train.shape[-1]

model = self._gen_torch_model(inputdim, optimizer)

model, _ = train_dcph(model,
(x_train, t_train, e_train),
(x_val, t_val, e_val),
model, _ = train_cmhe(model,
(x_train, t_train, e_train, a_train),
(x_val, t_val, e_val, a_val),
epochs=iters,
lr=learning_rate,
bs=batch_size,
Expand All @@ -144,23 +158,25 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,

return self

def predict_risk(self, x, t=None):
def predict_risk(self, x, a, t=None):

if self.fitted:
return 1-self.predict_survival(x, t)
return 1-self.predict_survival(x, a, t)
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_risk`.")

def predict_survival(self, x, t=None):
def predict_survival(self, x, a, t=None):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Parameters
----------
x: np.ndarray
A numpy array of the input features, \( x \).
a: np.ndarray
A numpy array of the treatmeant assignment, \( a \).
t: list or float
a list or float of the times at which survival probability is
to be computed
Expand All @@ -173,12 +189,12 @@ def predict_survival(self, x, t=None):
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")

x = self._preprocess_test_data(x)
x = self._preprocess_test_data(x, a)

if t is not None:
if not isinstance(t, list):
t = [t]

scores = predict_survival(self.torch_model, x, t)
scores = predict_survival(self.torch_model, x, a, t)
return scores

2 changes: 1 addition & 1 deletion auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def train_cmhe(model, train_data, val_data, epochs=50,
if return_losses: return (model, breslow_splines), losses
else: return (model, breslow_splines)

def predict_scores(model, x, a, t):
def predict_survival(model, x, a, t):

if isinstance(t, (int, float)): t = [t]

Expand Down

0 comments on commit caa4145

Please sign in to comment.