@@ -31,19 +31,24 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
31
31
< span > Expand source code</ span >
32
32
</ summary >
33
33
< pre > < code class ="python "> # coding=utf-8
34
- # Copyright 2020 Chirag Nagpal, Auton Lab.
34
+ # Copyright 2020 Chirag Nagpal
35
35
#
36
- # Licensed under the Apache License, Version 2.0 (the "License");
37
- # you may not use this file except in compliance with the License.
38
- # You may obtain a copy of the License at
39
- #
40
- # http://www.apache.org/licenses/LICENSE-2.0
41
- #
42
- # Unless required by applicable law or agreed to in writing, software
43
- # distributed under the License is distributed on an "AS IS" BASIS,
44
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45
- # See the License for the specific language governing permissions and
46
- # limitations under the License.
36
+ # This file is part of Deep Survival Machines.
37
+
38
+ # Deep Survival Machines is free software: you can redistribute it and/or modify
39
+ # it under the terms of the GNU General Public License as published by
40
+ # the Free Software Foundation, either version 3 of the License, or
41
+ # (at your option) any later version.
42
+
43
+ # Deep Survival Machines is distributed in the hope that it will be useful,
44
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
45
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
46
+ # GNU General Public License for more details.
47
+
48
+ # You should have received a copy of the GNU General Public License
49
+ # along with Deep Survival Machines.
50
+ # If not, see <https://www.gnu.org/licenses/>.
51
+
47
52
48
53
"""Utility functions to load standard datasets to train and evaluate the
49
54
Deep Survival Machines models.
@@ -75,21 +80,58 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
75
80
76
81
return e, t
77
82
78
- def _load_pbc_dataset():
83
+ def _load_pbc_dataset(sequential ):
79
84
"""Helper function to load and preprocess the PBC dataset
80
85
81
86
The Primary biliary cirrhosis (PBC) Dataset [1] is well known
82
87
dataset for evaluating survival analysis models with time
83
88
dependent covariates.
84
89
90
+ Parameters
91
+ ----------
92
+ sequential: bool
93
+ If True returns a list of np.arrays for each individual.
94
+ else, returns collapsed results for each time step. To train
95
+ recurrent neural models you would typically use True.
96
+
97
+
85
98
References
86
99
----------
87
100
[1] Fleming, Thomas R., and David P. Harrington. Counting processes and
88
101
survival analysis. Vol. 169. John Wiley & Sons, 2011.
89
102
90
103
"""
91
104
92
- raise NotImplementedError('')
105
+ data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
106
+ data = pd.read_csv(io.BytesIO(data))
107
+
108
+ data['histologic'] = data['histologic'].astype(str)
109
+ dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
110
+ 'spiders', 'edema', 'histologic']]
111
+ dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline',
112
+ 'SGOT', 'platelets', 'prothrombin']]
113
+ age = data['age'] + data['years']
114
+
115
+ x1 = pd.get_dummies(dat_cat).values
116
+ x2 = dat_num.values
117
+ x3 = age.values.reshape(-1, 1)
118
+ x = np.hstack([x1, x2, x3])
119
+
120
+ time = (data['years'] - data['year']).values
121
+ event = data['status2'].values
122
+
123
+ x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
124
+ x_ = StandardScaler().fit_transform(x)
125
+
126
+ if not sequential:
127
+ return x_, time, event
128
+ else:
129
+ x, t, e = [], [], []
130
+ for id_ in sorted(list(set(data['id']))):
131
+ x.append(x_[data['id'] == id_])
132
+ t.append(time[data['id'] == id_])
133
+ e.append(event[data['id'] == id_])
134
+ return x, t, e
93
135
94
136
def _load_support_dataset():
95
137
"""Helper function to load and preprocess the SUPPORT dataset.
@@ -128,13 +170,16 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
128
170
return x[remove], t[remove], e[remove]
129
171
130
172
131
- def load_dataset(dataset='SUPPORT'):
173
+ def load_dataset(dataset='SUPPORT', **kwargs ):
132
174
"""Helper function to load datasets to test Survival Analysis models.
133
175
134
176
Parameters
135
177
----------
136
178
dataset: str
137
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
179
+ The choice of dataset to load. Currently implemented is 'SUPPORT'
180
+ and 'PBC'.
181
+ **kwargs: dict
182
+ Dataset specific keyword arguments.
138
183
139
184
Returns
140
185
----------
@@ -146,6 +191,9 @@ <h1 class="title">Module <code>dsm.datasets</code></h1>
146
191
147
192
if dataset == 'SUPPORT':
148
193
return _load_support_dataset()
194
+ if dataset == 'PBC':
195
+ sequential = kwargs.get('sequential', False)
196
+ return _load_pbc_dataset(sequential)
149
197
else:
150
198
return NotImplementedError('Dataset '+dataset+' not implemented.')</ code > </ pre >
151
199
</ details >
@@ -184,14 +232,17 @@ <h2 class="section-title" id="header-functions">Functions</h2>
184
232
</ details >
185
233
</ dd >
186
234
< dt id ="dsm.datasets.load_dataset "> < code class ="name flex ">
187
- < span > def < span class ="ident "> load_dataset</ span > </ span > (< span > dataset='SUPPORT')</ span >
235
+ < span > def < span class ="ident "> load_dataset</ span > </ span > (< span > dataset='SUPPORT', **kwargs )</ span >
188
236
</ code > </ dt >
189
237
< dd >
190
238
< div class ="desc "> < p > Helper function to load datasets to test Survival Analysis models.</ p >
191
239
< h2 id ="parameters "> Parameters</ h2 >
192
240
< dl >
193
241
< dt > < strong > < code > dataset</ code > </ strong > : < code > str</ code > </ dt >
194
- < dd > The choice of dataset to load. Currently implemented is 'SUPPORT'.</ dd >
242
+ < dd > The choice of dataset to load. Currently implemented is 'SUPPORT'
243
+ and 'PBC'.</ dd >
244
+ < dt > < strong > < code > **kwargs</ code > </ strong > : < code > dict</ code > </ dt >
245
+ < dd > Dataset specific keyword arguments.</ dd >
195
246
</ dl >
196
247
< h2 id ="returns "> Returns</ h2 >
197
248
< dl >
@@ -203,13 +254,16 @@ <h2 id="returns">Returns</h2>
203
254
< summary >
204
255
< span > Expand source code</ span >
205
256
</ summary >
206
- < pre > < code class ="python "> def load_dataset(dataset='SUPPORT'):
257
+ < pre > < code class ="python "> def load_dataset(dataset='SUPPORT', **kwargs ):
207
258
"""Helper function to load datasets to test Survival Analysis models.
208
259
209
260
Parameters
210
261
----------
211
262
dataset: str
212
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
263
+ The choice of dataset to load. Currently implemented is 'SUPPORT'
264
+ and 'PBC'.
265
+ **kwargs: dict
266
+ Dataset specific keyword arguments.
213
267
214
268
Returns
215
269
----------
@@ -221,6 +275,9 @@ <h2 id="returns">Returns</h2>
221
275
222
276
if dataset == 'SUPPORT':
223
277
return _load_support_dataset()
278
+ if dataset == 'PBC':
279
+ sequential = kwargs.get('sequential', False)
280
+ return _load_pbc_dataset(sequential)
224
281
else:
225
282
return NotImplementedError('Dataset '+dataset+' not implemented.')</ code > </ pre >
226
283
</ details >
0 commit comments