Skip to content

Commit

Permalink
modified: dsm_api.py
Browse files Browse the repository at this point in the history
	modified:   losses.py
  • Loading branch information
chiragnagpal committed Apr 13, 2021
1 parent c454774 commit 8dd7c7b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
28 changes: 28 additions & 0 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,34 @@ def predict_survival(self, x, t, risk=1):
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")

def predict_pdf(self, x, t, risk=1):
r"""Returns the estimated pdf at time \( t \),
\( \widehat{\mathbb{P}}(T = t|X) \) for some input data \( x \).
Parameters
----------
x: np.ndarray
A numpy array of the input features, \( x \).
t: list or float
a list or float of the times at which pdf is
to be computed
Returns:
np.array: numpy array of the estimated pdf at each time in t.
"""
x = self._prepocess_test_data(x)
if not isinstance(t, list):
t = [t]
if self.fitted:
scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk))
return np.exp(np.array(scores)).T
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")




class DeepSurvivalMachines(DSMBase):
"""A Deep Survival Machines model.
Expand Down
77 changes: 77 additions & 0 deletions dsm/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,40 @@ def conditional_loss(model, x, t, e, elbo=True, risk='1'):
raise NotImplementedError('Distribution: '+model.dist+
' not implemented yet.')

def _weibull_pdf(model, x, t_horizon, risk='1'):

squish = nn.LogSoftmax(dim=1)

shape, scale, logits = model.forward(x, risk)
logits = squish(logits)

k_ = shape
b_ = scale

t_horz = torch.tensor(t_horizon).double()
t_horz = t_horz.repeat(shape.shape[0], 1)

pdfs = []
for j in range(len(t_horizon)):

t = t_horz[:, j]
lpdfs = []

for g in range(model.k):

k = k_[:, g]
b = b_[:, g]
s = - (torch.pow(torch.exp(b)*t, torch.exp(k)))
f = k + b + ((torch.exp(k)-1)*(b+torch.log(t)))
f = f + s
lpdfs.append(f)

lpdfs = torch.stack(lpdfs, dim=1)
lpdfs = lpdfs+logits
lpdfs = torch.logsumexp(lpdfs, dim=1)
pdfs.append(lpdfs.detach().numpy())

return lpdfs

def _weibull_cdf(model, x, t_horizon, risk='1'):

Expand Down Expand Up @@ -327,6 +361,35 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):

return cdfs

def _weibull_mean(model, x, risk='1'):

squish = nn.LogSoftmax(dim=1)

shape, scale, logits = model.forward(x, risk)
logits = squish(logits)

k_ = shape
b_ = scale

lmeans = []

for g in range(model.k):

k = k_[:, g]
b = b_[:, g]

one_over_k = torch.reciprocal(torch.exp(k))
lmean = -(one_over_k*b) + torch.lgamma(1+one_over_k)
lmeans.append(lmean)

lmeans = torch.stack(lmeans, dim=1)
lmeans = lmeans+logits
lmeans = torch.logsumexp(lmeans, dim=1)

return torch.exp(lmeans).detach().numpy()




def _lognormal_cdf(model, x, t_horizon, risk='1'):

Expand Down Expand Up @@ -424,15 +487,29 @@ def _normal_mean(model, x, risk='1'):

return lmeans.detach().numpy()


def predict_mean(model, x, risk='1'):
torch.no_grad()
if model.dist == 'Normal':
return _normal_mean(model, x, risk)
elif model.dist == 'Weibull':
return _weibull_mean(model, x, risk)
else:
raise NotImplementedError('Mean of Distribution: '+model.dist+
' not implemented yet.')


def predict_pdf(model, x, t_horizon, risk='1'):
torch.no_grad()
if model.dist == 'Weibull':
return _weibull_pdf(model, x, t_horizon, risk)
# if model.dist == 'LogNormal':
# return _lognormal_pdf(model, x, t_horizon, risk)
# if model.dist == 'Normal':
# return _normal_pdf(model, x, t_horizon, risk)
else:
raise NotImplementedError('Distribution: '+model.dist+
' not implemented yet.')


def predict_cdf(model, x, t_horizon, risk='1'):
Expand Down

0 comments on commit 8dd7c7b

Please sign in to comment.