# coding=utf-8
# Copyright 2020 Chirag Nagpal
#
# This file is part of Deep Survival Machines.
# Deep Survival Machines is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# Deep Survival Machines is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with Deep Survival Machines.
# If not, see <https://www.gnu.org/licenses/>.
"""Utility functions to load standard datasets to train and evaluate the
Deep Survival Machines models.
"""
import io
import pkgutil
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler


def increase_censoring(e, t, p):
  """Randomly censors a fraction `p` of the uncensored samples in (e, t).

  Selected samples get their event indicator `e` set to 0 and a new time
  drawn uniformly between 1 and their original event time. Note that `e`
  and `t` are modified in place.
  """

  uncens = np.where(e == 1)[0]
  mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
  toswitch = uncens[mask]

  e[toswitch] = 0
  t_ = t[toswitch]

  newt = []
  for t__ in t_:
    newt.append(np.random.uniform(1, t__))
  t[toswitch] = newt

  return e, t
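
# Usage sketch (illustrative only, not part of the library API): artificially
# censor roughly 30% of the uncensored samples of a dataset. `e` and `t` are
# assumed to be 1-d numpy arrays of event indicators (1 = event observed) and
# event/censoring times; pass copies because the arrays are modified in place.
#
#   e_cens, t_cens = increase_censoring(e.copy(), t.copy(), p=0.3)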


def _load_framingham_dataset(sequential):
  """Helper function to load and preprocess the Framingham dataset.

  The Framingham Dataset is a subset of 4,434 participants of the well known,
  ongoing Framingham Heart study [1] for studying the epidemiology of
  hypertensive and arteriosclerotic cardiovascular disease. It is a popular
  dataset for longitudinal survival analysis with time dependent covariates.

  Parameters
  ----------
  sequential: bool
    If True, returns a list of np.arrays for each individual.
    Else, returns collapsed results for each time step. To train
    recurrent neural models you would typically use True.

  References
  ----------
  [1] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr.
  "Epidemiological approaches to heart disease: the Framingham Study."
  American Journal of Public Health and the Nations Health 41.3 (1951).

  """

  data = pkgutil.get_data(__name__, 'datasets/framingham.csv')
  data = pd.read_csv(io.BytesIO(data))

  dat_cat = data[['SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS',
                  'educ', 'PREVCHD', 'PREVAP', 'PREVMI',
                  'PREVSTRK', 'PREVHYP']]
  dat_num = data[['TOTCHOL', 'AGE', 'SYSBP', 'DIABP',
                  'CIGPDAY', 'BMI', 'HEARTRTE', 'GLUCOSE']]

  # One-hot encode the categorical covariates and stack them with the
  # numeric covariates.
  x1 = pd.get_dummies(dat_cat).values
  x2 = dat_num.values
  x = np.hstack([x1, x2])

  time = (data['TIMEDTH'] - data['TIME']).values
  event = data['DEATH'].values

  # Mean-impute missing values and standardize the covariates.
  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
  x_ = StandardScaler().fit_transform(x)

  if not sequential:
    return x_, time, event
  else:
    # Group visits by participant (RANDID) for recurrent models.
    x, t, e = [], [], []
    for id_ in sorted(list(set(data['RANDID']))):
      x.append(x_[data['RANDID'] == id_])
      t.append(time[data['RANDID'] == id_])
      e.append(event[data['RANDID'] == id_])
    return x, t, e
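
# Output format sketch (illustrative, assuming the bundled framingham.csv):
# with sequential=False the helper returns a single covariate matrix of shape
# (n_visits, n_features) plus flat time/event vectors,
#
#   x, t, e = _load_framingham_dataset(sequential=False)
#
# whereas sequential=True returns lists whose i-th entries hold all visits of
# the i-th participant (grouped by RANDID), as expected by recurrent models.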


def _load_pbc_dataset(sequential):
  """Helper function to load and preprocess the PBC dataset.

  The Primary Biliary Cirrhosis (PBC) Dataset [1] is a well known
  dataset for evaluating survival analysis models with time
  dependent covariates.

  Parameters
  ----------
  sequential: bool
    If True, returns a list of np.arrays for each individual.
    Else, returns collapsed results for each time step. To train
    recurrent neural models you would typically use True.

  References
  ----------
  [1] Fleming, Thomas R., and David P. Harrington. Counting processes and
  survival analysis. Vol. 169. John Wiley & Sons, 2011.

  """

  data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
  data = pd.read_csv(io.BytesIO(data))

  data['histologic'] = data['histologic'].astype(str)
  dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
                  'spiders', 'edema', 'histologic']]
  dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline',
                  'SGOT', 'platelets', 'prothrombin']]
  age = data['age'] + data['years']

  # One-hot encode the categorical covariates and stack them with the
  # numeric covariates and the derived age feature.
  x1 = pd.get_dummies(dat_cat).values
  x2 = dat_num.values
  x3 = age.values.reshape(-1, 1)
  x = np.hstack([x1, x2, x3])

  time = (data['years'] - data['year']).values
  event = data['status2'].values

  # Mean-impute missing values and standardize the covariates.
  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
  x_ = StandardScaler().fit_transform(x)

  if not sequential:
    return x_, time, event
  else:
    # Group visits by subject ('id') for recurrent models.
    x, t, e = [], [], []
    for id_ in sorted(list(set(data['id']))):
      x.append(x_[data['id'] == id_])
      t.append(time[data['id'] == id_])
      e.append(event[data['id'] == id_])
    return x, t, e


def _load_support_dataset():
  """Helper function to load and preprocess the SUPPORT dataset.

  The SUPPORT Dataset comes from the Vanderbilt University study
  to estimate survival for seriously ill hospitalized adults [1].
  Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
  for the original data source.

  References
  ----------
  [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
  model: Objective estimates of survival for seriously ill hospitalized
  adults. Annals of Internal Medicine 122:191-203.

  """

  data = pkgutil.get_data(__name__, 'datasets/support2.csv')
  data = pd.read_csv(io.BytesIO(data))

  x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
             'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
             'urine', 'adlp', 'adls']]

  catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
  x2 = pd.get_dummies(data[catfeats])

  x = np.concatenate([x1, x2], axis=1)
  t = data['d.time'].values
  e = data['death'].values

  # Mean-impute missing values and standardize the covariates.
  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
  x = StandardScaler().fit_transform(x)

  # Drop rows with a missing event/censoring time.
  remove = ~np.isnan(t)
  return x[remove], t[remove], e[remove]


def load_dataset(dataset='SUPPORT', **kwargs):
  """Helper function to load datasets to test Survival Analysis models.

  Parameters
  ----------
  dataset: str
    The choice of dataset to load. Currently implemented are 'SUPPORT'
    and 'PBC'.
  **kwargs: dict
    Dataset specific keyword arguments.

  Returns
  ----------
  tuple: (np.ndarray, np.ndarray, np.ndarray)
    A tuple of the form (x, t, e), where x, t and e are the input covariates,
    event times and censoring indicators respectively.

  """

  if dataset == 'SUPPORT':
    return _load_support_dataset()
  if dataset == 'PBC':
    sequential = kwargs.get('sequential', False)
    return _load_pbc_dataset(sequential)
  else:
    raise NotImplementedError('Dataset '+dataset+' not implemented.')
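
# Usage sketch (illustrative): loading the bundled datasets through the public
# helper. The PBC loader accepts the `sequential` keyword documented above;
# the SUPPORT loader always returns collapsed arrays.
#
#   x, t, e = load_dataset('SUPPORT')
#   x, t, e = load_dataset('PBC', sequential=True)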