-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathutils_general.py
304 lines (247 loc) · 12 KB
/
utils_general.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from utils_libs import *
from utils_dataset import *
from utils_models import *
# Global parameters
# Respect a GPU selection already made in the caller's environment
# (e.g. `CUDA_VISIBLE_DEVICES=0 python main.py`); the values below are only
# defaults applied when the variable is not set at all.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
# Device that every model and batch in this module is moved to via `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Max (L2) gradient norm passed to clip_grad_norm_ by every training routine.
max_norm = 10
# --- Evaluate a NN model
def get_acc_loss(data_x, data_y, model, dataset_name, w_decay=None):
    """Evaluate ``model`` on ``(data_x, data_y)`` and return ``(loss, accuracy)``.

    Args:
        data_x, data_y: Samples and labels, wrapped in the project ``Dataset``;
            ``data_x.shape[0]`` is taken as the number of samples.
        model: Classifier; its output is scored with cross-entropy and the
            prediction is the argmax over dim 1.
        dataset_name: Forwarded to ``Dataset``.
        w_decay: If not None, ``w_decay / 2 * sum(theta**2)`` over all model
            parameters is added to the reported loss.

    Returns:
        Mean per-sample cross-entropy (plus the optional L2 term), and the
        fraction of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the evaluation set is empty.

    Side effects: moves ``model`` to ``device``; always returns with the
    model in training mode, whatever its mode was on entry.
    """
    n_tst = data_x.shape[0]
    if n_tst == 0:
        # Previously surfaced as an opaque DataLoader batch_size error.
        raise ValueError("get_acc_loss: evaluation set is empty")
    batch_size = min(6000, n_tst)
    # Summed loss so that dividing by n_tst below yields a per-sample mean.
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    tst_gen = data.DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=batch_size, shuffle=False)
    model.eval(); model = model.to(device)
    loss_overall = 0.0
    acc_overall = 0
    with torch.no_grad():
        for batch_x, batch_y in tst_gen:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).reshape(-1).long()
            y_pred = model(batch_x)
            loss_overall += loss_fn(y_pred, batch_y).item()
            # Count correct predictions on-device; no host copy of the logits.
            acc_overall += (y_pred.argmax(dim=1).reshape(-1) == batch_y).sum().item()
    loss_overall /= n_tst
    if w_decay is not None:
        # Add L2 loss
        params = get_mdl_params([model], n_par=None)
        loss_overall += w_decay / 2 * np.sum(params * params)
    model.train()
    return loss_overall, acc_overall / n_tst
# --- Helper functions
def set_client_from_params(mdl, params):
    """Write a flat parameter vector into ``mdl`` in place and return ``mdl``.

    ``params`` is consumed in ``mdl.parameters()`` order, each slice reshaped
    to the matching parameter's shape. This is the inverse of one row of
    ``get_mdl_params``. Only parameters are written; buffers (e.g. BatchNorm
    running statistics) are left untouched, so models with buffers work too.

    Args:
        mdl: torch module whose parameters are overwritten.
        params: 1-D numpy array or tensor with at least as many entries as
            ``mdl`` has parameter elements.

    Returns:
        ``mdl`` itself (modified in place).
    """
    # Read-only view; no copy for numpy input.
    params = torch.as_tensor(params)
    idx = 0
    with torch.no_grad():
        for param in mdl.parameters():
            length = param.numel()
            # copy_ casts dtype and moves across devices as required.
            param.copy_(params[idx:idx + length].reshape(param.shape))
            idx += length
    return mdl
def get_mdl_params(model_list, n_par=None):
    """Flatten each model's parameters into one row of a float32 matrix.

    Args:
        model_list: Sequence of torch modules; row ``i`` belongs to
            ``model_list[i]``. Must be non-empty when ``n_par`` is None.
        n_par: Row length. When None, it is the total number of parameter
            elements of ``model_list[0]``. If a model has fewer elements than
            ``n_par``, the rest of its row stays zero.

    Returns:
        Fresh ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype
        float32. Each row holds the model's parameters concatenated in
        ``parameters()`` order.
    """
    if n_par is None:
        n_par = sum(param.numel() for param in model_list[0].parameters())
    # Allocate directly as float32 instead of float64 followed by a conversion.
    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for param in mdl.parameters():
            flat = param.detach().cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + flat.size] = flat
            idx += flat.size
    return param_mat
# --- Train functions
def train_model(model, trn_x, trn_y, learning_rate, batch_size, epoch, print_per, weight_decay, dataset_name):
    """Train ``model`` with plain SGD and return it frozen.

    Runs ``epoch`` passes of ``ceil(n / batch_size)`` shuffled mini-batches,
    where ``n = trn_x.shape[0]``.
    - Objective: each step minimizes the batch-mean cross-entropy.
    - Regularization: L2 comes from the optimizer's ``weight_decay``.
    - Clipping: gradient norm is clipped to ``max_norm``.
    - Logging: every ``print_per`` epochs the full training set is
      re-evaluated with ``get_acc_loss`` and printed.

    On return every parameter has ``requires_grad=False`` and the model is in
    eval mode.
    """
    sample_count = trn_x.shape[0]
    loader = data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), batch_size=batch_size, shuffle=True)
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    sgd = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model.train()
    model = model.to(device)
    steps_per_epoch = int(np.ceil(sample_count / batch_size))

    for epoch_idx in range(1, epoch + 1):
        batches = iter(loader)
        for _ in range(steps_per_epoch):
            xb, yb = next(batches)
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            # Summed loss over the batch divided by its size == batch-mean loss.
            batch_loss = criterion(logits, yb.reshape(-1).long()) / yb.size(0)
            sgd.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_norm)
            sgd.step()

        if epoch_idx % print_per == 0:
            loss_trn, acc_trn = get_acc_loss(trn_x, trn_y, model, dataset_name, weight_decay)
            print("Epoch %3d, Training Accuracy: %.4f, Loss: %.4f" %(epoch_idx, acc_trn, loss_trn))
            # get_acc_loss may switch modes; make sure training continues in train mode.
            model.train()

    # Freeze model
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    return model
def train_scaffold_mdl(model, model_func, state_params_diff, trn_x, trn_y, learning_rate, batch_size, n_minibatch, print_per, weight_decay, dataset_name):
    """Run ``n_minibatch`` local SGD steps with a linear correction and return the frozen model.

    Each step minimizes ``CE_mean(batch) + <theta, state_params_diff>``, where
    ``theta`` is all parameters flattened in ``model.parameters()`` order.
    L2 regularization comes from the optimizer's ``weight_decay``, and
    gradients are clipped to ``max_norm``. The data is reshuffled at every
    epoch boundary (every ``ceil(n / batch_size)`` steps).

    Args:
        state_params_diff: Flat tensor aligned with ``theta``. Presumably the
            SCAFFOLD control-variate difference prepared by the caller, and
            assumed to already be on ``device`` (TODO confirm).
        n_minibatch: Total number of optimizer steps.
        print_per: Print the running mean loss every ``print_per`` steps.

    On return every parameter has ``requires_grad=False`` and the model is in
    eval mode.
    """
    n_trn = trn_x.shape[0]
    trn_gen = data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), batch_size=batch_size, shuffle=True)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model.train(); model = model.to(device)
    # Parameter count for the L2 term in the printout. NOTE: this builds a
    # throwaway model; kept so any RNG draws made by model_func() are unchanged.
    n_par = get_mdl_params([model_func()]).shape[1]
    n_iter_per_epoch = int(np.ceil(n_trn/batch_size))

    step_loss = 0; n_data_step = 0
    trn_gen_iter = None
    for count_step in range(1, int(n_minibatch) + 1):
        # Start a fresh (reshuffled) pass over the data at each epoch boundary.
        if (count_step - 1) % n_iter_per_epoch == 0:
            trn_gen_iter = iter(trn_gen)
        batch_x, batch_y = next(trn_gen_iter)
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_n = batch_y.size(0)
        y_pred = model(batch_x)
        ## Get f_i estimate
        loss_f_i = loss_fn(y_pred, batch_y.reshape(-1).long()) / batch_n
        # Linear penalty on the current parameter estimates. One torch.cat
        # instead of re-concatenating the growing vector once per layer.
        local_par_list = torch.cat([param.reshape(-1) for param in model.parameters()])
        loss_algo = torch.sum(local_par_list * state_params_diff)
        loss = loss_f_i + loss_algo
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_norm) # Clip gradients
        optimizer.step()
        step_loss += loss.item() * batch_n; n_data_step += batch_n
        if count_step % print_per == 0:
            step_loss /= n_data_step
            if weight_decay is not None:
                # Add L2 loss to complete f_i
                params = get_mdl_params([model], n_par)
                step_loss += (weight_decay)/2 * np.sum(params * params)
            print("Step %3d, Training Loss: %.4f" %(count_step, step_loss))
            step_loss = 0; n_data_step = 0
            model.train()
    # Freeze model
    for params in model.parameters():
        params.requires_grad = False
    model.eval()
    return model
def train_feddyn_mdl(model, model_func, alpha_coef, avg_mdl_param, local_grad_vector, trn_x, trn_y, learning_rate, batch_size, epoch, print_per, weight_decay, dataset_name):
    """Locally train ``model`` with the dynamic-regularization objective and return it frozen.

    Each step minimizes
    ``CE_mean(batch) + alpha_coef * <theta, local_grad_vector - avg_mdl_param>``,
    where ``theta`` is all parameters flattened in ``model.parameters()``
    order. The optimizer's weight decay is ``alpha_coef + weight_decay``.
    Gradients are clipped to ``max_norm``.

    Args:
        avg_mdl_param, local_grad_vector: Flat tensors aligned with ``theta``,
            assumed to already be on ``device`` (TODO confirm with caller).
        epoch: Number of passes of ``ceil(n / batch_size)`` shuffled batches.
        print_per: Print the epoch's mean loss plus
            ``(alpha_coef + weight_decay) / 2 * ||theta||^2`` every
            ``print_per`` epochs.

    On return every parameter has ``requires_grad=False`` and the model is in
    eval mode.
    """
    n_trn = trn_x.shape[0]
    n_iter_per_epoch = int(np.ceil(n_trn/batch_size))
    trn_gen = data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), batch_size=batch_size, shuffle=True)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=alpha_coef+weight_decay)
    model.train(); model = model.to(device)
    # Parameter count for the L2 term in the printout. NOTE: this builds a
    # throwaway model; kept so any RNG draws made by model_func() are unchanged.
    n_par = get_mdl_params([model_func()]).shape[1]
    # The linear-term coefficient does not change during the call: compute once.
    lin_coef = local_grad_vector - avg_mdl_param
    for e in range(epoch):
        # Training
        epoch_loss = 0
        trn_gen_iter = iter(trn_gen)
        for _ in range(n_iter_per_epoch):
            batch_x, batch_y = next(trn_gen_iter)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_n = batch_y.size(0)
            y_pred = model(batch_x)
            ## Get f_i estimate
            loss_f_i = loss_fn(y_pred, batch_y.reshape(-1).long()) / batch_n
            # Linear penalty on the current parameter estimates (single concat).
            local_par_list = torch.cat([param.reshape(-1) for param in model.parameters()])
            loss_algo = alpha_coef * torch.sum(local_par_list * lin_coef)
            loss = loss_f_i + loss_algo
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_norm) # Clip gradients
            optimizer.step()
            epoch_loss += loss.item() * batch_n
        if (e+1) % print_per == 0:
            epoch_loss /= n_trn
            if weight_decay is not None:
                # Add L2 loss to complete f_i
                params = get_mdl_params([model], n_par)
                epoch_loss += (alpha_coef+weight_decay)/2 * np.sum(params * params)
            print("Epoch %3d, Training Loss: %.4f" %(e+1, epoch_loss))
            model.train()
    # Freeze model
    for params in model.parameters():
        params.requires_grad = False
    model.eval()
    return model
###
def train_fedprox_mdl(model, avg_model_param_, mu, trn_x, trn_y, learning_rate, batch_size, epoch, print_per, weight_decay, dataset_name):
    """Locally train ``model`` with a proximal term toward ``avg_model_param_`` and return it frozen.

    Each step minimizes
    ``CE_mean(batch) + mu/2 * ||theta||^2 - mu * <theta, avg_model_param_>``,
    where ``theta`` is all parameters flattened in ``model.parameters()``
    order. The penalty equals ``mu/2 * ||theta - avg_model_param_||^2`` up to
    a constant independent of ``theta``. L2 regularization comes from the
    optimizer's ``weight_decay``, and gradients are clipped to ``max_norm``.

    Args:
        avg_model_param_: Flat tensor aligned with ``theta``; its length is the
            parameter count. Assumed to already be on ``device`` (TODO confirm).
        epoch: Number of passes of ``ceil(n / batch_size)`` shuffled batches.
        print_per: Print the epoch's mean loss every ``print_per`` epochs.

    On return every parameter has ``requires_grad=False`` and the model is in
    eval mode.
    """
    n_trn = trn_x.shape[0]
    n_iter_per_epoch = int(np.ceil(n_trn/batch_size))
    trn_gen = data.DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), batch_size=batch_size, shuffle=True)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model.train(); model = model.to(device)
    n_par = len(avg_model_param_)
    for e in range(epoch):
        # Training
        epoch_loss = 0
        trn_gen_iter = iter(trn_gen)
        for _ in range(n_iter_per_epoch):
            batch_x, batch_y = next(trn_gen_iter)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_n = batch_y.size(0)
            y_pred = model(batch_x)
            ## Get f_i estimate
            loss_f_i = loss_fn(y_pred, batch_y.reshape(-1).long()) / batch_n
            # Proximal penalty on the current parameter estimates (single concat).
            local_par_list = torch.cat([param.reshape(-1) for param in model.parameters()])
            loss_algo = mu/2 * torch.sum(local_par_list * local_par_list)
            loss_algo = loss_algo - mu * torch.sum(local_par_list * avg_model_param_)
            loss = loss_f_i + loss_algo
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_norm) # Clip gradients
            optimizer.step()
            epoch_loss += loss.item() * batch_n
        if (e+1) % print_per == 0:
            epoch_loss /= n_trn
            if weight_decay is not None:
                # Add L2 loss to complete f_i
                params = get_mdl_params([model], n_par)
                epoch_loss += weight_decay/2 * np.sum(params * params)
            print("Epoch %3d, Training Loss: %.4f" %(e+1, epoch_loss))
            model.train()
    # Freeze model
    for params in model.parameters():
        params.requires_grad = False
    model.eval()
    return model