Skip to content

Commit fb94ae7

Browse files
committed
modified: .travis.yml
modified: docs/datasets.html deleted: docs/datautils.html modified: docs/dsm_api.html modified: docs/dsm_torch.html modified: docs/index.html modified: docs/losses.html modified: docs/utilities.html
1 parent 9f82dcc commit fb94ae7

8 files changed

+635
-464
lines changed

.travis.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ python:
55
- "3.8"
66
os:
77
- linux
8-
- osx
98
# command to install dependencies
109
install:
1110
- pip install -r requirements.txt
@@ -15,4 +14,4 @@ install:
1514
# command to run tests
1615
script:
1716
- python -m pytest tests/
18-
- pylint --fail-under=9 dsm/
17+
- pylint --fail-under=8 dsm/

docs/datasets.html

+77-20
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,24 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
3131
<span>Expand source code</span>
3232
</summary>
3333
<pre><code class="python"># coding=utf-8
34-
# Copyright 2020 Chirag Nagpal, Auton Lab.
34+
# Copyright 2020 Chirag Nagpal
3535
#
36-
# Licensed under the Apache License, Version 2.0 (the &#34;License&#34;);
37-
# you may not use this file except in compliance with the License.
38-
# You may obtain a copy of the License at
39-
#
40-
# http://www.apache.org/licenses/LICENSE-2.0
41-
#
42-
# Unless required by applicable law or agreed to in writing, software
43-
# distributed under the License is distributed on an &#34;AS IS&#34; BASIS,
44-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45-
# See the License for the specific language governing permissions and
46-
# limitations under the License.
36+
# This file is part of Deep Survival Machines.
37+
38+
# Deep Survival Machines is free software: you can redistribute it and/or modify
39+
# it under the terms of the GNU General Public License as published by
40+
# the Free Software Foundation, either version 3 of the License, or
41+
# (at your option) any later version.
42+
43+
# Deep Survival Machines is distributed in the hope that it will be useful,
44+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
45+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
46+
# GNU General Public License for more details.
47+
48+
# You should have received a copy of the GNU General Public License
49+
# along with Deep Survival Machines.
50+
# If not, see &lt;https://www.gnu.org/licenses/&gt;.
51+
4752

4853
&#34;&#34;&#34;Utility functions to load standard datasets to train and evaluate the
4954
Deep Survival Machines models.
@@ -75,21 +80,58 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
7580

7681
return e, t
7782

78-
def _load_pbc_dataset():
83+
def _load_pbc_dataset(sequential):
7984
&#34;&#34;&#34;Helper function to load and preprocess the PBC dataset
8085

8186
The Primary biliary cirrhosis (PBC) Dataset [1] is a well-known
8287
dataset for evaluating survival analysis models with time
8388
dependent covariates.
8489

90+
Parameters
91+
----------
92+
sequential: bool
93+
If True, returns a list of np.arrays for each individual.
94+
Otherwise, returns collapsed results for each time step. To train
95+
recurrent neural models you would typically use True.
96+
97+
8598
References
8699
----------
87100
[1] Fleming, Thomas R., and David P. Harrington. Counting processes and
88101
survival analysis. Vol. 169. John Wiley &amp; Sons, 2011.
89102

90103
&#34;&#34;&#34;
91104

92-
raise NotImplementedError(&#39;&#39;)
105+
data = pkgutil.get_data(__name__, &#39;datasets/pbc2.csv&#39;)
106+
data = pd.read_csv(io.BytesIO(data))
107+
108+
data[&#39;histologic&#39;] = data[&#39;histologic&#39;].astype(str)
109+
dat_cat = data[[&#39;drug&#39;, &#39;sex&#39;, &#39;ascites&#39;, &#39;hepatomegaly&#39;,
110+
&#39;spiders&#39;, &#39;edema&#39;, &#39;histologic&#39;]]
111+
dat_num = data[[&#39;serBilir&#39;, &#39;serChol&#39;, &#39;albumin&#39;, &#39;alkaline&#39;,
112+
&#39;SGOT&#39;, &#39;platelets&#39;, &#39;prothrombin&#39;]]
113+
age = data[&#39;age&#39;] + data[&#39;years&#39;]
114+
115+
x1 = pd.get_dummies(dat_cat).values
116+
x2 = dat_num.values
117+
x3 = age.values.reshape(-1, 1)
118+
x = np.hstack([x1, x2, x3])
119+
120+
time = (data[&#39;years&#39;] - data[&#39;year&#39;]).values
121+
event = data[&#39;status2&#39;].values
122+
123+
x = SimpleImputer(missing_values=np.nan, strategy=&#39;mean&#39;).fit_transform(x)
124+
x_ = StandardScaler().fit_transform(x)
125+
126+
if not sequential:
127+
return x_, time, event
128+
else:
129+
x, t, e = [], [], []
130+
for id_ in sorted(list(set(data[&#39;id&#39;]))):
131+
x.append(x_[data[&#39;id&#39;] == id_])
132+
t.append(time[data[&#39;id&#39;] == id_])
133+
e.append(event[data[&#39;id&#39;] == id_])
134+
return x, t, e
93135

94136
def _load_support_dataset():
95137
&#34;&#34;&#34;Helper function to load and preprocess the SUPPORT dataset.
@@ -128,13 +170,16 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
128170
return x[remove], t[remove], e[remove]
129171

130172

131-
def load_dataset(dataset=&#39;SUPPORT&#39;):
173+
def load_dataset(dataset=&#39;SUPPORT&#39;, **kwargs):
132174
&#34;&#34;&#34;Helper function to load datasets to test Survival Analysis models.
133175

134176
Parameters
135177
----------
136178
dataset: str
137-
The choice of dataset to load. Currently implemented is &#39;SUPPORT&#39;.
179+
The choice of dataset to load. Currently implemented are &#39;SUPPORT&#39;
180+
and &#39;PBC&#39;.
181+
**kwargs: dict
182+
Dataset specific keyword arguments.
138183

139184
Returns
140185
----------
@@ -146,6 +191,9 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
146191

147192
if dataset == &#39;SUPPORT&#39;:
148193
return _load_support_dataset()
194+
if dataset == &#39;PBC&#39;:
195+
sequential = kwargs.get(&#39;sequential&#39;, False)
196+
return _load_pbc_dataset(sequential)
149197
else:
150198
return NotImplementedError(&#39;Dataset &#39;+dataset+&#39; not implemented.&#39;)</code></pre>
151199
</details>
@@ -184,14 +232,17 @@ <h2 class="section-title" id="header-functions">Functions</h2>
184232
</details>
185233
</dd>
186234
<dt id="dsm.datasets.load_dataset"><code class="name flex">
187-
<span>def <span class="ident">load_dataset</span></span>(<span>dataset='SUPPORT')</span>
235+
<span>def <span class="ident">load_dataset</span></span>(<span>dataset='SUPPORT', **kwargs)</span>
188236
</code></dt>
189237
<dd>
190238
<div class="desc"><p>Helper function to load datasets to test Survival Analysis models.</p>
191239
<h2 id="parameters">Parameters</h2>
192240
<dl>
193241
<dt><strong><code>dataset</code></strong> :&ensp;<code>str</code></dt>
194-
<dd>The choice of dataset to load. Currently implemented is 'SUPPORT'.</dd>
242+
<dd>The choice of dataset to load. Currently implemented are 'SUPPORT'
243+
and 'PBC'.</dd>
244+
<dt><strong><code>**kwargs</code></strong> :&ensp;<code>dict</code></dt>
245+
<dd>Dataset specific keyword arguments.</dd>
195246
</dl>
196247
<h2 id="returns">Returns</h2>
197248
<dl>
@@ -203,13 +254,16 @@ <h2 id="returns">Returns</h2>
203254
<summary>
204255
<span>Expand source code</span>
205256
</summary>
206-
<pre><code class="python">def load_dataset(dataset=&#39;SUPPORT&#39;):
257+
<pre><code class="python">def load_dataset(dataset=&#39;SUPPORT&#39;, **kwargs):
207258
&#34;&#34;&#34;Helper function to load datasets to test Survival Analysis models.
208259

209260
Parameters
210261
----------
211262
dataset: str
212-
The choice of dataset to load. Currently implemented is &#39;SUPPORT&#39;.
263+
The choice of dataset to load. Currently implemented are &#39;SUPPORT&#39;
264+
and &#39;PBC&#39;.
265+
**kwargs: dict
266+
Dataset specific keyword arguments.
213267

214268
Returns
215269
----------
@@ -221,6 +275,9 @@ <h2 id="returns">Returns</h2>
221275

222276
if dataset == &#39;SUPPORT&#39;:
223277
return _load_support_dataset()
278+
if dataset == &#39;PBC&#39;:
279+
sequential = kwargs.get(&#39;sequential&#39;, False)
280+
return _load_pbc_dataset(sequential)
224281
else:
225282
return NotImplementedError(&#39;Dataset &#39;+dataset+&#39; not implemented.&#39;)</code></pre>
226283
</details>

0 commit comments

Comments
 (0)