Skip to content

Commit 8c08196

Browse files
authored
Evaluation updated to competing risks (#29)
* Implemented computation of Negative Log-Likelihood.
1 parent 0b9404a commit 8c08196

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

dsm/dsm_api.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
3333
from dsm.losses import predict_cdf
3434
import dsm.losses as losses
35-
from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets
35+
from dsm.utilities import train_dsm
36+
from dsm.utilities import _get_padded_features, _get_padded_targets
37+
from dsm.utilities import _reshape_tensor_with_nans
3638

3739
import torch
3840
import numpy as np
@@ -120,6 +122,40 @@ def fit(self, x, t, e, vsize=0.15,
120122

121123
return self
122124

125+
def _eval_nll(self, x, t, e):
126+
r"""This function computes the negative log likelihood of the given data.
127+
In case of competing risks, the negative log likelihoods are summed over
128+
the different events' type.
129+
130+
Parameters
131+
----------
132+
x: np.ndarray
133+
A numpy array of the input features, \( x \).
134+
t: np.ndarray
135+
A numpy array of the event/censoring times, \( t \).
136+
e: np.ndarray
137+
A numpy array of the event/censoring indicators, \( \delta \).
138+
\( \delta = r \) means the event r took place.
139+
140+
Returns:
141+
float: Negative log likelihood.
142+
"""
143+
if not(self.fitted):
144+
raise Exception("The model has not been fitted yet. Please fit the " +
145+
"model using the `fit` method on some training data " +
146+
"before calling `_eval_nll`.")
147+
processed_data = self._prepocess_training_data(x, t, e, 0, 0)
148+
_, _, _, x_val, t_val, e_val = processed_data
149+
x_val, t_val, e_val = x_val,\
150+
_reshape_tensor_with_nans(t_val),\
151+
_reshape_tensor_with_nans(e_val)
152+
loss = 0
153+
for r in range(self.torch_model.risks):
154+
loss += float(losses.conditional_loss(self.torch_model,
155+
x_val, t_val, e_val, elbo=False,
156+
risk=str(r+1)).detach().numpy())
157+
return loss
158+
123159
def _prepocess_test_data(self, x):
124160
return torch.from_numpy(x)
125161

0 commit comments

Comments
 (0)