omnifold.py
import argparse
import gc
import os
import sys
import time
import energyflow as ef
import numpy as np

# DCTR-style reweighting: multiplies the event weights by the classifier odds
# f/(1-f), reweighting the class-0 (source) distribution to the class-1 (target) one
# X: features
# Y: categorical labels
# w: event weights
# model: model with fit/predict/save_weights
# filepath: where to save the trained model weights
# fitargs: model fit arguments
# val_data: optional (X_val, Y_val, w_val) tuple used for validation
def reweight(X, Y, w, model, filepath, fitargs, val_data=None):

    # fit the model and get predictions (the explicit permutation is left commented out)
    #perm = np.random.permutation(len(X))
    #model.fit(X[perm], Y[perm], sample_weight=w[perm], **fitargs)
    val_dict = {'validation_data': val_data} if val_data is not None else {}
    model.fit(X, Y, sample_weight=w, **fitargs, **val_dict)
    model.save_weights(filepath)
    preds = model.predict(X, batch_size=10*fitargs.get('batch_size', 500))[:,1]

    # concatenate validation predictions into training predictions
    if val_data is not None:
        preds_val = model.predict(val_data[0], batch_size=10*fitargs.get('batch_size', 500))[:,1]
        preds = np.concatenate((preds, preds_val))
        w = np.concatenate((w, val_data[2]))

    # update the weights with the (clipped) likelihood ratio f/(1-f)
    w *= np.clip(preds/(1 - preds + 10**-50), fitargs.get('weight_clip_min', 0.), fitargs.get('weight_clip_max', np.inf))
    return w
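
# A minimal numeric illustration of the weight update above (added for clarity,
# not part of the original logic): the classifier output f(x) estimates
# P(class 1 | x), so f/(1 - f) estimates the density ratio p1(x)/p0(x) and
# events that look more like class 1 have their weights scaled up:
#
#   >>> f = np.array([0.5, 0.75, 0.9])
#   >>> f / (1 - f)          # multipliers applied to the event weights
#   array([1., 3., 9.])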

# OmniFold
# X_gen/Y_gen: particle-level features/labels (or names of global arrays holding them)
# X_det/Y_det: detector-level features/labels, note these should be ordered as (data, sim)
# wdata/winit: initial weights of the data/simulation
# det_model/mc_model: (model class, constructor kwargs) tuples; the class must
#                     provide fit/predict/save_weights/load_weights
# fitargs: model fit arguments
# val: fraction of events held out for validation
# it: number of iterations
# weights_filename: if given, the list of weight arrays is saved there with np.save
# trw_ind: which previous weights to use in the second step, 0 means use initial, -2 means use previous
# delete_global_arrays: if True, delete the input (global) arrays after slicing to free memory
def omnifold(X_gen_i, Y_gen_i, X_det_i, Y_det_i, wdata, winit, det_model, mc_model, fitargs,
             val=0.2, it=10, weights_filename=None, trw_ind=0, delete_global_arrays=False):

    # get arrays (looked up in globals() if names were passed as strings)
    X_gen_arr = globals()[X_gen_i] if isinstance(X_gen_i, str) else X_gen_i
    Y_gen_arr = globals()[Y_gen_i] if isinstance(Y_gen_i, str) else Y_gen_i
    X_det_arr = globals()[X_det_i] if isinstance(X_det_i, str) else X_det_i
    Y_det_arr = globals()[Y_det_i] if isinstance(Y_det_i, str) else Y_det_i

    # initialize the truth weights to the prior
    ws = [winit]

    # get permutation for det
    perm_det = np.random.permutation(len(winit) + len(wdata))
    invperm_det = np.argsort(perm_det)
    nval_det = int(val*len(perm_det))
    X_det_train, X_det_val = X_det_arr[perm_det[:-nval_det]], X_det_arr[perm_det[-nval_det:]]
    Y_det_train, Y_det_val = Y_det_arr[perm_det[:-nval_det]], Y_det_arr[perm_det[-nval_det:]]

    # remove X_det, Y_det (and their global references) to free memory
    if delete_global_arrays:
        del X_det_arr, Y_det_arr
        if isinstance(X_det_i, str):
            del globals()[X_det_i]
        if isinstance(Y_det_i, str):
            del globals()[Y_det_i]
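
    # NOTE (added comment): the train/validation split is made once, outside the
    # iteration loop, so the per-event weights can be re-split consistently with
    # perm_det (and perm_gen below) at every iteration, and predictions can be
    # restored to the original event order via invperm_det/invperm_gen.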

    # get an initial permutation for gen and duplicate (offset) it
    nval = int(val*len(winit))
    baseperm0 = np.random.permutation(len(winit))
    baseperm1 = baseperm0 + len(winit)

    # training examples are at the beginning, validation examples at the end;
    # concatenate into single train and val perms (shuffle each)
    trainperm = np.concatenate((baseperm0[:-nval], baseperm1[:-nval]))
    valperm = np.concatenate((baseperm0[-nval:], baseperm1[-nval:]))
    np.random.shuffle(trainperm)
    np.random.shuffle(valperm)

    # get final permutation for gen (ensures that the same events end up in val)
    perm_gen = np.concatenate((trainperm, valperm))
    invperm_gen = np.argsort(perm_gen)
    nval_gen = int(val*len(perm_gen))
    X_gen_train, X_gen_val = X_gen_arr[perm_gen[:-nval_gen]], X_gen_arr[perm_gen[-nval_gen:]]
    Y_gen_train, Y_gen_val = Y_gen_arr[perm_gen[:-nval_gen]], Y_gen_arr[perm_gen[-nval_gen:]]

    # remove X_gen, Y_gen (and their global references) to free memory
    if delete_global_arrays:
        del X_gen_arr, Y_gen_arr
        if isinstance(X_gen_i, str):
            del globals()[X_gen_i]
        if isinstance(Y_gen_i, str):
            del globals()[Y_gen_i]

    # store model filepaths
    model_det_fp, model_mc_fp = det_model[1].get('filepath', None), mc_model[1].get('filepath', None)

    # iterate the procedure
    for i in range(it):

        # format the det/mc filepaths for this iteration
        if model_det_fp is not None:
            model_det_fp_i = model_det_fp.format(i)
            det_model[1]['filepath'] = model_det_fp_i + '_Epoch-{epoch}'
        if model_mc_fp is not None:
            model_mc_fp_i = model_mc_fp.format(i)
            mc_model[1]['filepath'] = model_mc_fp_i + '_Epoch-{epoch}'

        # define models
        model_det = det_model[0](**det_model[1])
        model_mc = mc_model[0](**mc_model[1])

        # load the previous iteration's weights if not the first iteration
        if i > 0:
            model_det.load_weights(model_det_fp.format(i-1))
            model_mc.load_weights(model_mc_fp.format(i-1))

        # step 1: reweight sim to look like data
        w = np.concatenate((wdata, ws[-1]))
        w_train, w_val = w[perm_det[:-nval_det]], w[perm_det[-nval_det:]]
        rw = reweight(X_det_train, Y_det_train, w_train, model_det, model_det_fp_i,
                      fitargs, val_data=(X_det_val, Y_det_val, w_val))[invperm_det]
        ws.append(rw[len(wdata):])

        # step 2: reweight the prior to the learned weighting
        w = np.concatenate((ws[-1], ws[trw_ind]))
        w_train, w_val = w[perm_gen[:-nval_gen]], w[perm_gen[-nval_gen:]]
        rw = reweight(X_gen_train, Y_gen_train, w_train, model_mc, model_mc_fp_i,
                      fitargs, val_data=(X_gen_val, Y_gen_val, w_val))[invperm_gen]
        ws.append(rw[len(ws[-1]):])

        # save the weights if specified
        if weights_filename is not None:
            np.save(weights_filename, ws)

    return ws

# note: main() is not defined in this module, so running the file directly would
# raise a NameError; it is presumably provided elsewhere in the repository
if __name__ == '__main__':
    main(sys.argv[1:])
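

# ----------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).  It shows
# one plausible way to drive omnifold(), assuming energyflow's DNN architecture
# and synthetic placeholder arrays; the shapes, layer sizes, file paths, and
# label convention (class 1 = target/data, class 0 = simulation, as required by
# the f/(1-f) update in reweight) below are hypothetical.
#
# from energyflow.archs import DNN
# from energyflow.utils import to_categorical
#
# ndata, nsim, nfeat = 10000, 10000, 4
# X_data_det = np.random.rand(ndata, nfeat)   # "measured" detector-level data
# X_sim_det  = np.random.rand(nsim, nfeat)    # simulated detector-level events
# X_sim_gen  = np.random.rand(nsim, nfeat)    # simulated particle-level events
#
# # detector level: data first, then simulation, as omnifold() expects
# X_det = np.concatenate((X_data_det, X_sim_det))
# Y_det = to_categorical(np.concatenate((np.ones(ndata), np.zeros(nsim))))
#
# # particle level: two copies of the generated sample, labeled 1 (target) and 0 (prior)
# X_gen = np.concatenate((X_sim_gen, X_sim_gen))
# Y_gen = to_categorical(np.concatenate((np.ones(nsim), np.zeros(nsim))))
#
# model_args = {'input_dim': nfeat, 'dense_sizes': [100, 100],
#               'patience': 10, 'save_weights_only': True}
# fitargs = {'batch_size': 500, 'epochs': 50, 'verbose': 0}
#
# ws = omnifold(X_gen, Y_gen, X_det, Y_det,
#               wdata=np.ones(ndata), winit=np.ones(nsim),
#               det_model=(DNN, dict(model_args, filepath='Det_Iter-{}')),
#               mc_model=(DNN, dict(model_args, filepath='MC_Iter-{}')),
#               fitargs=fitargs, it=4, weights_filename='omnifold_weights')
#
# # ws holds 2*it + 1 arrays: winit, then the step-1 and step-2 weights of each
# # iteration; ws[-1] are the final unfolded (particle-level) weights.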