import copy
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def running_uefl_avg(current, next_state, scale):
    """
    Accumulate a weighted running average of model parameters,
    skipping the codebook ('discretizer') entries.
    """
    if current is None:
        # Copy so the caller's state dict is not mutated in place.
        current = copy.deepcopy(next_state)
        for key in current:
            if 'discretizer' in key:
                continue
            current[key] = current[key] * scale
    else:
        for key in current:
            if 'discretizer' in key:
                continue
            current[key] = current[key] + (next_state[key] * scale)
    return current
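
# Hypothetical usage sketch (`silo_state_dicts` is illustrative, not part of
# this module): equal-weight averaging across silos; 'discretizer' entries
# are carried over from the first state dict and left untouched.
#
#   avg = None
#   for sd in silo_state_dicts:
#       avg = running_uefl_avg(avg, sd, 1.0 / len(silo_state_dicts))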

def validate(test_loader, model, device, args, net_idx):
    """
    Evaluate the model on the test set; returns the average cross-entropy
    loss, the accuracy, the per-sample softmax predictions, and the average
    VQ loss and codebook perplexity.
    """
    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct, total = 0.0, 0.0
    test_loss, test_vqloss, test_ppl = 0.0, 0.0, 0.0
    # One softmax column per class.
    if args.data == 'cifar100':
        prediction = np.empty((0, 100))
    elif args.data == 'gtsrb':
        prediction = np.empty((0, 43))
    else:
        prediction = np.empty((0, 10))
    with torch.no_grad():
        for xte, yte in test_loader:
            xte = xte.to(device)
            pte, vqloss, ppl = model(xte, net_idx)
            test_vqloss += vqloss.item()
            test_ppl += ppl.item()
            lte = criterion(pte.cpu(), yte)  # labels stay on the CPU
            prediction = np.append(prediction, F.softmax(pte, dim=1).cpu(), axis=0)
            cls = torch.argmax(pte.cpu(), dim=1)
            correct += torch.eq(cls, yte.cpu()).sum().item()
            total += xte.shape[0]
            test_loss += lte.item()
    return (test_loss / len(test_loader), correct / total, prediction,
            test_vqloss / len(test_loader), test_ppl / len(test_loader))
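
# Hypothetical usage sketch (assumes `args.data` is set and the model's
# forward returns (logits, vq_loss, perplexity), as the code above expects):
#
#   loss, acc, preds, vqloss, ppl = validate(test_loader, model, device, args, net_idx=0)
#   print(f'test loss {loss:.4f} | accuracy {acc:.4f} | perplexity {ppl:.2f}')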

def silo_training(train_loader, test_loader, model, device, args, lr, net_idx, init_and_ext=False):
    """
    Run local training for one silo and return the trained local model
    together with its test metrics.
    """
    localmodel = copy.deepcopy(model)
    localmodel = localmodel.to(device)
    optimizer = optim.Adam(localmodel.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    loss_tr = []
    train_loss = 0.0
    if init_and_ext:
        # Locally initialize the codebooks with k-means before training.
        localmodel.init_codebooks(train_loader, net_idx, device)
    for e in range(args.epoch):
        localmodel.train()
        for xtr, ytr in train_loader:
            xtr, ytr = xtr.to(device), ytr.to(device)
            optimizer.zero_grad()
            ptr, train_vqloss, train_ppl = localmodel(xtr, net_idx, ext=init_and_ext)
            # Extend the codebooks only once per silo, on the very first batch.
            init_and_ext = False
            ltr = criterion(ptr, ytr) + train_vqloss
            ltr.backward(retain_graph=True)
            optimizer.step()
            train_loss += ltr.item()
        # validate() already averages over the test loader, so append as-is.
        test_loss, acc, pred, vqloss, ppl = validate(test_loader, localmodel, device, args, net_idx)
        loss_tr.append(test_loss)
        train_loss = 0.0
    return localmodel, test_loss, acc, vqloss, ppl
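
# Hypothetical sketch of one federated round (loader and model names are
# illustrative): train every silo locally, then fold the resulting parameters
# into a running average with running_uefl_avg.
#
#   avg = None
#   for idx, (tr_loader, te_loader) in enumerate(silo_loaders):
#       local, loss, acc, vqloss, ppl = silo_training(
#           tr_loader, te_loader, global_model, device, args,
#           lr=1e-3, net_idx=idx, init_and_ext=(rnd == 0))
#       avg = running_uefl_avg(avg, local.state_dict(), 1.0 / len(silo_loaders))
#   global_model.load_state_dict(avg)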

def plot_lc(data, t, savepath):
    """
    Plot the learning curve of one metric, one line per silo.
    """
    x = np.arange(t) + 1
    data = np.asarray(data).T
    num_silo = data.shape[0]
    # Silo labels for the two supported federation layouts.
    labels = {
        5: ['silo 1a', 'silo 1b', 'silo 1c', 'silo 2a', 'silo 3a'],
        9: ['silo 1a', 'silo 1b', 'silo 1c', 'silo 2a', 'silo 2b',
            'silo 2c', 'silo 3a', 'silo 3b', 'silo 3c'],
    }
    if num_silo in labels:
        plt.figure()
        for i, label in enumerate(labels[num_silo]):
            plt.plot(x, data[i], label=label)
        plt.legend()
        plt.xlabel('round')
        plt.ylabel(savepath.split('_')[-1])  # metric name comes from the file name
        folder = savepath.split('/')[2]
        ttl = folder.split('_')[0]  # figure title
        plt.title(ttl.upper())
        plt.savefig(savepath)
        plt.close()

def plot_metrics(data_list, t, savepath_list):
    """
    Plot the learning curves for all metrics.
    """
    for data, savepath in zip(data_list, savepath_list):
        plot_lc(data, t, savepath)
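
# Hypothetical usage sketch (paths are illustrative; note that plot_lc takes
# its title from the third path component and its y-label from the text after
# the last underscore):
#
#   metrics = [loss_history, acc_history]  # each: one list of per-silo values per round
#   paths = ['./results/gtsrb_run/lc_loss.png', './results/gtsrb_run/lc_acc.png']
#   plot_metrics(metrics, t=len(loss_history), savepath_list=paths)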

def entropy(preds):
    """
    Entropy of the mean of a set of softmax predictions, with a small
    epsilon to avoid log(0).
    """
    epsilon = sys.float_info.min
    mean_preds = np.mean(preds, axis=0)
    return -np.sum(mean_preds * np.log(mean_preds + epsilon), axis=-1)
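
# Hypothetical sketch (array names are illustrative): stack the per-silo
# prediction arrays returned by validate() and score per-sample uncertainty.
#
#   preds = np.stack([pred_a, pred_b, pred_c])  # (num_silos, num_samples, num_classes)
#   uncertainty = entropy(preds)                # (num_samples,)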