-
Notifications
You must be signed in to change notification settings - Fork 82
/
make_folds.py
74 lines (55 loc) · 1.85 KB
/
make_folds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
import argparse
import collections
import pickle
from pprint import pprint
import random
import numpy as np
from tqdm import tqdm
def get_args(argv=None):
    """Parse command-line options for the fold-assignment script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``n_fold`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    # main() opens both paths unconditionally, so require them here: argparse
    # then reports a clear usage error instead of a later TypeError from
    # open(None, ...).
    parser.add_argument('--input', required=True,
                        help='path to the pickled input DataFrame')
    parser.add_argument('--output', required=True,
                        help='path to write the pickled DataFrame with a fold column')
    parser.add_argument('--n-fold', type=int, default=5,
                        help='number of folds (default: 5)')
    parser.add_argument('--seed', type=int, default=10,
                        help='seed for random tie-breaking between folds (default: 10)')
    return parser.parse_args(argv)
def _make_folds(df, n_fold, seed):
    """Assign every patient to one of ``n_fold`` folds, balancing rare labels.

    Label frequencies are first counted over the whole DataFrame. Then, for
    each patient group (``df.groupby('PatientID')``), the patient's rarest
    label is found. The patient goes to the fold that currently holds the
    fewest occurrences of that label; ties are broken at random. All of
    the patient's labels are then credited to the chosen fold.

    Args:
        df: DataFrame with a ``PatientID`` column and a ``labels`` column of
            whitespace-separated label strings. Non-string entries (e.g.
            missing values) are treated as "no labels".
        n_fold: number of folds; must be >= 1.
        seed: seed for the tie-breaking random generator.

    Returns:
        dict mapping each PatientID (as produced by the groupby) to a fold
        index in ``range(n_fold)``.

    Raises:
        ValueError: if ``n_fold`` < 1.
    """
    if n_fold < 1:
        raise ValueError('n_fold must be >= 1, got %r' % (n_fold,))

    def _split(value):
        # Non-string cells (NaN/None/pd.NA) mean "no labels" rather than a crash.
        return value.split() if isinstance(value, str) else []

    # Global frequency of each label across all rows.
    counter_gt = collections.Counter()
    for value in df.labels:
        counter_gt.update(_split(value))

    # (fold, label) -> number of occurrences of label already placed in fold.
    counter_folds = collections.Counter()
    folds = {}
    # Private generator: produces the same draws as random.seed(seed) followed
    # by random.choice, without reseeding the interpreter-wide RNG.
    rng = random.Random(seed)
    groups = df.groupby('PatientID')
    print('making %d folds...' % n_fold)
    for patient_id, group in tqdm(groups, total=len(groups)):
        labels = [label for value in group.labels for label in _split(value)]
        if not labels:
            # Unlabelled patients are balanced under a shared '' pseudo-label.
            labels = ['']
        # Rarest label overall for this patient; the first one wins on ties.
        min_label = min(labels, key=lambda label: counter_gt[label])
        per_fold = [counter_folds[(f, min_label)] for f in range(n_fold)]
        min_count = min(per_fold)
        candidates = [f for f, count in enumerate(per_fold) if count == min_count]
        fold = rng.choice(candidates)
        folds[patient_id] = fold
        counter_folds.update((fold, label) for label in labels)
    pprint(counter_folds)
    return folds
def main():
    """Add a ``fold`` column to a pickled DataFrame and save it.

    Reads the DataFrame pickled at ``--input`` and computes a patient-to-fold
    mapping with ``_make_folds``. That mapping is applied to the
    ``PatientID`` column as a new ``fold`` column, and the result is pickled
    to ``--output``.
    """
    options = get_args()

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point --input at trusted, locally produced pickles.
    with open(options.input, 'rb') as src:
        frame = pickle.load(src)

    patient_to_fold = _make_folds(frame, options.n_fold, options.seed)
    frame['fold'] = frame.PatientID.map(patient_to_fold)

    with open(options.output, 'wb') as dst:
        pickle.dump(frame, dst)
    print('saved to %s' % options.output)
if __name__ == '__main__':
    # Echo the raw command line first so run logs show how the script was invoked,
    # then hand off to the entry point.
    print(sys.argv)
    main()