Skip to content

Commit

Permalink
modified: docs/dsm_api.html
Browse files Browse the repository at this point in the history
	modified:   dsm/dsm_api.py
  • Loading branch information
chiragnagpal committed Jan 4, 2021
1 parent 8c08196 commit d60789d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 5 deletions.
84 changes: 84 additions & 0 deletions docs/dsm_api.html
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,33 @@ <h2 id="parameters">Parameters</h2>
<dd>random seed that determines how the validation set is chosen.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepSurvivalMachines.compute_nll"><code class="name flex">
<span>def <span class="ident">compute_nll</span></span>(<span>self, x, t, e)</span>
</code></dt>
<dd>
<p class="inheritance">
<em>Inherited from:</em>
<code><a title="dsm.dsm_api.DSMBase" href="#dsm.dsm_api.DSMBase">DSMBase</a></code>.<code><a title="dsm.dsm_api.DSMBase.compute_nll" href="#dsm.dsm_api.DSMBase.compute_nll">compute_nll</a></code>
</p>
<div class="desc"><p>This function computes the negative log likelihood of the given data.
In case of competing risks, the negative log likelihoods are summed over
the different events' type.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the input features, <span><span class="MathJax_Preview"> x </span><script type="math/tex"> x </script></span>.</dd>
<dt><strong><code>t</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring times, <span><span class="MathJax_Preview"> t </span><script type="math/tex"> t </script></span>.</dd>
<dt><strong><code>e</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring indicators, <span><span class="MathJax_Preview"> \delta </span><script type="math/tex"> \delta </script></span>.
<span><span class="MathJax_Preview"> \delta = r </span><script type="math/tex"> \delta = r </script></span> means the event r took place.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><code>float</code></dt>
<dd>Negative log likelihood.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepSurvivalMachines.predict_mean"><code class="name flex">
<span>def <span class="ident">predict_mean</span></span>(<span>self, x, risk=1)</span>
</code></dt>
Expand Down Expand Up @@ -231,6 +258,33 @@ <h2 id="parameters">Parameters</h2>
<dd>random seed that determines how the validation set is chosen.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepRecurrentSurvivalMachines.compute_nll"><code class="name flex">
<span>def <span class="ident">compute_nll</span></span>(<span>self, x, t, e)</span>
</code></dt>
<dd>
<p class="inheritance">
<em>Inherited from:</em>
<code><a title="dsm.dsm_api.DSMBase" href="#dsm.dsm_api.DSMBase">DSMBase</a></code>.<code><a title="dsm.dsm_api.DSMBase.compute_nll" href="#dsm.dsm_api.DSMBase.compute_nll">compute_nll</a></code>
</p>
<div class="desc"><p>This function computes the negative log likelihood of the given data.
In case of competing risks, the negative log likelihoods are summed over
the different events' type.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the input features, <span><span class="MathJax_Preview"> x </span><script type="math/tex"> x </script></span>.</dd>
<dt><strong><code>t</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring times, <span><span class="MathJax_Preview"> t </span><script type="math/tex"> t </script></span>.</dd>
<dt><strong><code>e</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring indicators, <span><span class="MathJax_Preview"> \delta </span><script type="math/tex"> \delta </script></span>.
<span><span class="MathJax_Preview"> \delta = r </span><script type="math/tex"> \delta = r </script></span> means the event r took place.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><code>float</code></dt>
<dd>Negative log likelihood.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_mean"><code class="name flex">
<span>def <span class="ident">predict_mean</span></span>(<span>self, x, risk=1)</span>
</code></dt>
Expand Down Expand Up @@ -347,6 +401,33 @@ <h2 id="parameters">Parameters</h2>
<dd>random seed that determines how the validation set is chosen.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepConvolutionalSurvivalMachines.compute_nll"><code class="name flex">
<span>def <span class="ident">compute_nll</span></span>(<span>self, x, t, e)</span>
</code></dt>
<dd>
<p class="inheritance">
<em>Inherited from:</em>
<code><a title="dsm.dsm_api.DSMBase" href="#dsm.dsm_api.DSMBase">DSMBase</a></code>.<code><a title="dsm.dsm_api.DSMBase.compute_nll" href="#dsm.dsm_api.DSMBase.compute_nll">compute_nll</a></code>
</p>
<div class="desc"><p>This function computes the negative log likelihood of the given data.
In case of competing risks, the negative log likelihoods are summed over
the different events' type.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the input features, <span><span class="MathJax_Preview"> x </span><script type="math/tex"> x </script></span>.</dd>
<dt><strong><code>t</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring times, <span><span class="MathJax_Preview"> t </span><script type="math/tex"> t </script></span>.</dd>
<dt><strong><code>e</code></strong> :&ensp;<code>np.ndarray</code></dt>
<dd>A numpy array of the event/censoring indicators, <span><span class="MathJax_Preview"> \delta </span><script type="math/tex"> \delta </script></span>.
<span><span class="MathJax_Preview"> \delta = r </span><script type="math/tex"> \delta = r </script></span> means the event r took place.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><code>float</code></dt>
<dd>Negative log likelihood.</dd>
</dl></div>
</dd>
<dt id="dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_mean"><code class="name flex">
<span>def <span class="ident">predict_mean</span></span>(<span>self, x, risk=1)</span>
</code></dt>
Expand Down Expand Up @@ -437,6 +518,7 @@ <h1>Index</h1>
<h4><code><a title="dsm.dsm_api.DeepSurvivalMachines" href="#dsm.dsm_api.DeepSurvivalMachines">DeepSurvivalMachines</a></code></h4>
<ul class="">
<li><code><a title="dsm.dsm_api.DeepSurvivalMachines.fit" href="#dsm.dsm_api.DeepSurvivalMachines.fit">fit</a></code></li>
<li><code><a title="dsm.dsm_api.DeepSurvivalMachines.compute_nll" href="#dsm.dsm_api.DeepSurvivalMachines.compute_nll">compute_nll</a></code></li>
<li><code><a title="dsm.dsm_api.DeepSurvivalMachines.predict_mean" href="#dsm.dsm_api.DeepSurvivalMachines.predict_mean">predict_mean</a></code></li>
<li><code><a title="dsm.dsm_api.DeepSurvivalMachines.predict_risk" href="#dsm.dsm_api.DeepSurvivalMachines.predict_risk">predict_risk</a></code></li>
<li><code><a title="dsm.dsm_api.DeepSurvivalMachines.predict_survival" href="#dsm.dsm_api.DeepSurvivalMachines.predict_survival">predict_survival</a></code></li>
Expand All @@ -446,6 +528,7 @@ <h4><code><a title="dsm.dsm_api.DeepSurvivalMachines" href="#dsm.dsm_api.DeepSur
<h4><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines">DeepRecurrentSurvivalMachines</a></code></h4>
<ul class="">
<li><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines.fit" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines.fit">fit</a></code></li>
<li><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines.compute_nll" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines.compute_nll">compute_nll</a></code></li>
<li><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_mean" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_mean">predict_mean</a></code></li>
<li><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_risk" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_risk">predict_risk</a></code></li>
<li><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_survival" href="#dsm.dsm_api.DeepRecurrentSurvivalMachines.predict_survival">predict_survival</a></code></li>
Expand All @@ -455,6 +538,7 @@ <h4><code><a title="dsm.dsm_api.DeepRecurrentSurvivalMachines" href="#dsm.dsm_ap
<h4><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines">DeepConvolutionalSurvivalMachines</a></code></h4>
<ul class="">
<li><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines.fit" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines.fit">fit</a></code></li>
<li><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines.compute_nll" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines.compute_nll">compute_nll</a></code></li>
<li><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_mean" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_mean">predict_mean</a></code></li>
<li><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_risk" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_risk">predict_risk</a></code></li>
<li><code><a title="dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_survival" href="#dsm.dsm_api.DeepConvolutionalSurvivalMachines.predict_survival">predict_survival</a></code></li>
Expand Down
14 changes: 9 additions & 5 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
from dsm.losses import predict_cdf

import dsm.losses as losses

from dsm.utilities import train_dsm
from dsm.utilities import _get_padded_features, _get_padded_targets
from dsm.utilities import _reshape_tensor_with_nans
Expand All @@ -40,8 +41,10 @@
import numpy as np

__pdoc__ = {}
__pdoc__["DSMBase"] = False
__pdoc__["DeepSurvivalMachines.fit"] = True
__pdoc__["DeepSurvivalMachines._eval_nll"] = True
__pdoc__["DeepConvolutionalSurvivalMachines._eval_nll"] = True
__pdoc__["DSMBase"] = False


class DSMBase():
Expand Down Expand Up @@ -120,9 +123,10 @@ def fit(self, x, t, e, vsize=0.15,
self.torch_model = model.eval()
self.fitted = True

return self
return self


def _eval_nll(self, x, t, e):
def compute_nll(self, x, t, e):
r"""This function computes the negative log likelihood of the given data.
In case of competing risks, the negative log likelihoods are summed over
the different events' type.
Expand Down Expand Up @@ -243,7 +247,7 @@ def predict_survival(self, x, t, risk=1):
if not isinstance(t, list):
t = [t]
if self.fitted:
scores = predict_cdf(self.torch_model, x, t, risk=str(risk))
scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk))
return np.exp(np.array(scores)).T
else:
raise Exception("The model has not been fitted yet. Please fit the " +
Expand Down

0 comments on commit d60789d

Please sign in to comment.