diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index c5ac790..51d0dc4 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -275,6 +275,34 @@ def predict_survival(self, x, t, risk=1): "model using the `fit` method on some training data " + "before calling `predict_survival`.") + def predict_pdf(self, x, t, risk=1): + r"""Returns the estimated pdf at time \( t \), + \( \widehat{\mathbb{P}}(T = t|X) \) for some input data \( x \). + + Parameters + ---------- + x: np.ndarray + A numpy array of the input features, \( x \). + t: list or float + a list or float of the times at which pdf is + to be computed + Returns: + np.array: numpy array of the estimated pdf at each time in t. + + """ + x = self._prepocess_test_data(x) + if not isinstance(t, list): + t = [t] + if self.fitted: + scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk)) + return np.exp(np.array(scores)).T + else: + raise Exception("The model has not been fitted yet. Please fit the " + + "model using the `fit` method on some training data " + + "before calling `predict_survival`.") + + + class DeepSurvivalMachines(DSMBase): """A Deep Survival Machines model. diff --git a/dsm/losses.py b/dsm/losses.py index 915b1c4..1cbc202 100644 --- a/dsm/losses.py +++ b/dsm/losses.py @@ -293,6 +293,40 @@ def conditional_loss(model, x, t, e, elbo=True, risk='1'): raise NotImplementedError('Distribution: '+model.dist+ ' not implemented yet.') +def _weibull_pdf(model, x, t_horizon, risk='1'): + + squish = nn.LogSoftmax(dim=1) + + shape, scale, logits = model.forward(x, risk) + logits = squish(logits) + + k_ = shape + b_ = scale + + t_horz = torch.tensor(t_horizon).double() + t_horz = t_horz.repeat(shape.shape[0], 1) + + pdfs = [] + for j in range(len(t_horizon)): + + t = t_horz[:, j] + lpdfs = [] + + for g in range(model.k): + + k = k_[:, g] + b = b_[:, g] + s = - (torch.pow(torch.exp(b)*t, torch.exp(k))) + f = k + b + ((torch.exp(k)-1)*(b+torch.log(t))) + f = f + s + lpdfs.append(f) + + lpdfs = torch.stack(lpdfs, dim=1) + lpdfs = lpdfs+logits + lpdfs = torch.logsumexp(lpdfs, dim=1) + pdfs.append(lpdfs.detach().numpy()) + + return lpdfs def _weibull_cdf(model, x, t_horizon, risk='1'): @@ -327,6 +361,35 @@ def _weibull_cdf(model, x, t_horizon, risk='1'): return cdfs +def _weibull_mean(model, x, risk='1'): + + squish = nn.LogSoftmax(dim=1) + + shape, scale, logits = model.forward(x, risk) + logits = squish(logits) + + k_ = shape + b_ = scale + + lmeans = [] + + for g in range(model.k): + + k = k_[:, g] + b = b_[:, g] + + one_over_k = torch.reciprocal(torch.exp(k)) + lmean = -(one_over_k*b) + torch.lgamma(1+one_over_k) + lmeans.append(lmean) + + lmeans = torch.stack(lmeans, dim=1) + lmeans = lmeans+logits + lmeans = torch.logsumexp(lmeans, dim=1) + + return torch.exp(lmeans).detach().numpy() + + + def _lognormal_cdf(model, x, t_horizon, risk='1'): @@ -424,15 +487,29 @@ def _normal_mean(model, x, risk='1'): return lmeans.detach().numpy() + def predict_mean(model, x, risk='1'): torch.no_grad() if model.dist == 'Normal': return _normal_mean(model, x, risk) + elif model.dist == 'Weibull': + return _weibull_mean(model, x, risk) else: raise NotImplementedError('Mean of Distribution: '+model.dist+ ' not implemented yet.') +def predict_pdf(model, x, t_horizon, risk='1'): + torch.no_grad() + if model.dist == 'Weibull': + return _weibull_pdf(model, x, t_horizon, risk) + # if model.dist == 'LogNormal': + # return _lognormal_pdf(model, x, t_horizon, risk) + # if model.dist == 'Normal': + # return _normal_pdf(model, x, t_horizon, risk) + else: + raise NotImplementedError('Distribution: '+model.dist+ + ' not implemented yet.') def predict_cdf(model, x, t_horizon, risk='1'):