|
32 | 32 | from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
|
33 | 33 | from dsm.losses import predict_cdf
|
34 | 34 | import dsm.losses as losses
|
35 |
| -from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets |
| 35 | +from dsm.utilities import train_dsm |
| 36 | +from dsm.utilities import _get_padded_features, _get_padded_targets |
| 37 | +from dsm.utilities import _reshape_tensor_with_nans |
36 | 38 |
|
37 | 39 | import torch
|
38 | 40 | import numpy as np
|
@@ -120,6 +122,40 @@ def fit(self, x, t, e, vsize=0.15,
|
120 | 122 |
|
121 | 123 | return self
|
122 | 124 |
|
| 125 | + def _eval_nll(self, x, t, e): |
| 126 | + r"""This function computes the negative log likelihood of the given data. |
| 127 | + In case of competing risks, the negative log likelihoods are summed over |
| 128 | + the different events' type. |
| 129 | +
|
| 130 | + Parameters |
| 131 | + ---------- |
| 132 | + x: np.ndarray |
| 133 | + A numpy array of the input features, \( x \). |
| 134 | + t: np.ndarray |
| 135 | + A numpy array of the event/censoring times, \( t \). |
| 136 | + e: np.ndarray |
| 137 | + A numpy array of the event/censoring indicators, \( \delta \). |
| 138 | + \( \delta = r \) means the event r took place. |
| 139 | + |
| 140 | + Returns: |
| 141 | + float: Negative log likelihood. |
| 142 | + """ |
| 143 | + if not(self.fitted): |
| 144 | + raise Exception("The model has not been fitted yet. Please fit the " + |
| 145 | + "model using the `fit` method on some training data " + |
| 146 | + "before calling `_eval_nll`.") |
| 147 | + processed_data = self._prepocess_training_data(x, t, e, 0, 0) |
| 148 | + _, _, _, x_val, t_val, e_val = processed_data |
| 149 | + x_val, t_val, e_val = x_val,\ |
| 150 | + _reshape_tensor_with_nans(t_val),\ |
| 151 | + _reshape_tensor_with_nans(e_val) |
| 152 | + loss = 0 |
| 153 | + for r in range(self.torch_model.risks): |
| 154 | + loss += float(losses.conditional_loss(self.torch_model, |
| 155 | + x_val, t_val, e_val, elbo=False, |
| 156 | + risk=str(r+1)).detach().numpy()) |
| 157 | + return loss |
| 158 | + |
123 | 159 | def _prepocess_test_data(self, x):
|
124 | 160 | return torch.from_numpy(x)
|
125 | 161 |
|
|
0 commit comments