backdoor_graph_clf.py
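"""Federated backdoor attack on graph classification over TU datasets.

Each client trains a local GNN (built via `gnn_model`); the first
`args.num_mali` clients poison their partitions with graph triggers produced
by `transform_dataset`. Every round, the server aggregates the selected local
models with the rule chosen by `args.defense` (FoolsGold, FedAvg, FedOpt,
FedProx, median, trimmed mean, (multi-)Krum, or Bulyan) and pushes the
aggregated weights back to the clients. The script reports average clean
accuracy, local attack success rate, and the transfer attack success rate
measured on the clean clients' models.
"""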
import copy
import json
import os
import random
import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from Graph_level_Models.helpers.config import args_parser
from Graph_level_Models.datasets.gnn_util import transform_dataset, split_dataset
from Graph_level_Models.datasets.TUs import TUsDataset
from Graph_level_Models.nets.TUs_graph_classification.load_net import gnn_model
from Graph_level_Models.helpers.evaluate import gnn_evaluate_accuracy
from Graph_level_Models.defenses.defense import foolsgold
from Graph_level_Models.trainer.workerbase import WorkerBase
from Graph_level_Models.aggregators.aggregation import (
    fed_avg, fed_opt, fed_median, fed_trimmedmean, fed_multi_krum, fed_bulyan,
)
def server_robust_agg(args, grad):  # server aggregation
    grad_in = np.array(grad).reshape((args.num_workers, -1)).mean(axis=0)
    return grad_in.tolist()
class ClearDenseClient(WorkerBase):
    def __init__(self, client_id, model, loss_func, train_iter, attack_iter, test_iter, config, optimizer, device,
                 grad_stub, args, scheduler):
        super(ClearDenseClient, self).__init__(model=model, loss_func=loss_func, train_iter=train_iter,
                                               attack_iter=attack_iter, test_iter=test_iter, config=config,
                                               optimizer=optimizer, device=device)
        self.client_id = client_id
        self.grad_stub = None
        self.args = args
        self.scheduler = scheduler

    def update(self):
        pass
class DotDict(dict):
    def __init__(self, **kwds):
        self.update(kwds)
        self.__dict__ = self
def main(args, logger):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    with open(args.config) as f:
        config = json.load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
    dataset = TUsDataset(args)

    args.device = device
    collate = dataset.collate
    MODEL_NAME = config['model']
    net_params = config['net_params']
    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print("[!] Adding graph self-loops for GCN/GAT models (central node trick).")
            dataset._add_self_loops()

    net_params['in_dim'] = dataset.all.graph_lists[0].ndata['feat'][0].shape[0]
    num_classes = torch.max(dataset.all.graph_labels).item() + 1
    net_params['n_classes'] = num_classes
    net_params['dropout'] = args.dropout
    # epoch_backdoor is given as a fraction of the total epochs; convert it to an absolute epoch index
    args.epoch_backdoor = int(args.epoch_backdoor * args.epochs)

    model = gnn_model(MODEL_NAME, net_params)
    global_model = gnn_model(MODEL_NAME, net_params)
    global_model = global_model.to(device)
    client = []
    loss_func = nn.CrossEntropyLoss()

    # Load data
    partition, avg_nodes = split_dataset(args, dataset)
    drop_last = True if MODEL_NAME == 'DiffPool' else False
    triggers = []
    all_workers_clean_test_list = []
    for i in range(args.num_workers):
        local_model = copy.deepcopy(model)
        local_model = local_model.to(device)
        optimizer = torch.optim.Adam(local_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=args.step_size, gamma=args.gamma)

        train_dataset = partition[i]
        test_dataset = partition[args.num_workers + i]
        print("Client %d training data num: %d" % (i, len(train_dataset)))
        print("Client %d testing data num: %d" % (i, len(test_dataset)))

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  drop_last=drop_last, collate_fn=dataset.collate)
        attack_loader = None
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 drop_last=drop_last, collate_fn=dataset.collate)
        all_workers_clean_test_list.append(test_loader)

        client.append(ClearDenseClient(client_id=i, model=local_model, loss_func=loss_func, train_iter=train_loader,
                                       attack_iter=attack_loader, test_iter=test_loader, config=config,
                                       optimizer=optimizer, device=device, grad_stub=None, args=args,
                                       scheduler=scheduler))

    # check model memory address
    for i in range(args.num_workers):
        add_m = id(client[i].model)
        add_o = id(client[i].optimizer)
        print('model {} address: {}'.format(i, add_m))
        print('optimizer {} address: {}'.format(i, add_o))
    # prepare the local backdoor datasets for the malicious clients
    train_loader_list = []
    attack_loader_list = []
    test_clean_loader_list = []
    test_unchanged_loader_list = []
    for i in range(args.num_mali):
        train_trigger_graphs, test_trigger_graphs, G_trigger, final_idx, test_clean_data, test_unchanged_data = \
            transform_dataset(partition[i], partition[args.num_workers + i], avg_nodes, args)
        # triggers.append(G_trigger)
        tmp_graphs = [partition[i][idx] for idx in range(len(partition[i])) if idx not in final_idx]
        train_dataset = train_trigger_graphs + tmp_graphs
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  drop_last=drop_last, collate_fn=dataset.collate)
        # only trigger data in the test dataloader
        attack_loader = DataLoader(test_trigger_graphs, batch_size=args.batch_size, shuffle=False,
                                   drop_last=drop_last, collate_fn=dataset.collate)
        # only clean data in the test dataloader
        test_clean_loader = DataLoader(test_clean_data, batch_size=args.batch_size, shuffle=False,
                                       drop_last=drop_last, collate_fn=dataset.collate)
        # only unchanged data in the test dataloader
        test_unchanged_loader = DataLoader(test_unchanged_data, batch_size=args.batch_size, shuffle=False,
                                           drop_last=drop_last, collate_fn=dataset.collate)

        train_loader_list.append(train_loader)
        attack_loader_list.append(attack_loader)
        test_clean_loader_list.append(test_clean_loader)
        test_unchanged_loader_list.append(test_unchanged_loader)

    weight_history = []
    for epoch in range(args.epochs):
        print('epoch:', epoch)
        # per-round worker results
        worker_results = {}
        for i in range(args.num_workers):
            worker_results[f"client_{i}"] = {"train_loss": None, "train_acc": None,
                                             "test_loss": None, "test_acc": None}

        if epoch >= args.epoch_backdoor:
            # malicious clients start the backdoor attack: switch to their poisoned loaders
            for i in range(0, args.num_mali):
                client[i].train_iter = train_loader_list[i]
                client[i].attack_iter = attack_loader_list[i]

        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        different_clients_test_accuracy_local_trigger = []
        for i in range(args.num_workers):
            att_list = []
            train_loss, train_acc, test_loss, test_acc = client[i].gnn_train(global_model, args)
            different_clients_test_accuracy_local_trigger.append(test_acc)
            client[i].scheduler.step()
            print('Client %d, loss %.4f, train acc %.3f, test loss %.4f, test acc %.3f'
                  % (i, train_loss, train_acc, test_loss, test_acc))

            # save worker results
            worker_results[f"client_{i}"]["train_loss"] = train_loss
            worker_results[f"client_{i}"]["train_acc"] = train_acc
            worker_results[f"client_{i}"]["test_loss"] = test_loss
            worker_results[f"client_{i}"]["test_acc"] = test_acc

            # `triggers` is never populated (the append above is commented out), so this loop is currently a no-op
            for j in range(len(triggers)):
                tmp_acc = gnn_evaluate_accuracy(attack_loader_list[j], client[i].model)
                print('Client %d with local trigger %d: %.3f' % (i, j, tmp_acc))
                att_list.append(tmp_acc)

        # wandb logger
        logger.log(worker_results)
        # select the clients that participate in this round's aggregation
        selected_clients = random.sample(client, args.num_selected_models)

        # apply the chosen defense / aggregation rule
        if args.defense == 'foolsgold':
            weights = []
            for i in range(args.num_workers):
                weights.append(client[i].get_weights())
                weight_history.append(client[i].get_weights())
            result, weight_history, alpha = foolsgold(args, weight_history, weights)
            for i in range(args.num_workers):
                client[i].set_weights(weights=result)
                client[i].upgrade()
        elif args.defense == 'fedavg':
            global_model = fed_avg(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fedopt':
            global_model = fed_opt(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fedprox':
            # server-side aggregation for FedProx reuses fed_avg (any proximal term is applied client-side)
            global_model = fed_avg(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fed_median':
            global_model = fed_median(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fed_trimmedmean':
            global_model = fed_trimmedmean(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fed_multi_krum':
            global_model = fed_multi_krum(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fed_krum':
            # fed_krum reuses the fed_multi_krum aggregator
            global_model = fed_multi_krum(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        elif args.defense == 'fed_bulyan':
            global_model = fed_bulyan(global_model, selected_clients, args)
            # send to local model
            for param_tensor in global_model.state_dict():
                global_para = global_model.state_dict()[param_tensor]
                for local_client in client:
                    local_client.model.state_dict()[param_tensor].copy_(global_para)
        else:
            # no matching defense: the FoolsGold result is computed but immediately
            # overwritten by the plain mean from server_robust_agg
            weights = []
            for i in range(args.num_workers):
                weights.append(client[i].get_weights())
                weight_history.append(client[i].get_weights())
            result, weight_history, alpha = foolsgold(args, weight_history, weights)
            result = server_robust_agg(args, weights)
            for i in range(args.num_workers):
                client[i].set_weights(weights=result)
                client[i].upgrade()
        # evaluate the global model on clean test data
        test_acc = gnn_evaluate_accuracy(client[0].test_iter, client[0].model)
        print('Global Test Acc: %.3f' % test_acc)

        # evaluate the global model on the trigger-injected testing data
        if args.num_mali > 0 and epoch >= args.epoch_backdoor:
            local_att_acc = []
            for i in range(args.num_mali):
                tmp_acc = gnn_evaluate_accuracy(attack_loader_list[i], client[0].model)
                print('Global model with local trigger %d: %.3f' % (i, tmp_acc))
                local_att_acc.append(tmp_acc)

    # final metrics: clean accuracy, poison accuracy, attack success rate, averaged over all workers
    all_clean_acc_list = []
    for i in range(args.num_workers):
        tmp_acc = gnn_evaluate_accuracy(all_workers_clean_test_list[i], client[i].model)
        print('Client %d with clean accuracy: %.3f' % (i, tmp_acc))
        all_clean_acc_list.append(tmp_acc)
    average_all_clean_acc = np.mean(np.array(all_clean_acc_list))

    local_attack_success_rate_list = []
    for i in range(args.num_mali):
        tmp_acc = gnn_evaluate_accuracy(attack_loader_list[i], client[i].model)
        print('Malicious client %d with local trigger, attack success rate: %.4f' % (i, tmp_acc))
        local_attack_success_rate_list.append(tmp_acc)
    average_local_attack_success_rate_acc = np.mean(np.array(local_attack_success_rate_list))

    local_clean_acc_list = []
    for i in range(args.num_mali):
        tmp_acc = gnn_evaluate_accuracy(test_clean_loader_list[i], client[i].model)
        print('Malicious client %d with clean data, clean accuracy: %.4f' % (i, tmp_acc))
        local_clean_acc_list.append(tmp_acc)
    average_local_clean_acc = np.mean(np.array(local_clean_acc_list))

    average_local_unchanged_acc = 0
    # local_unchanged_acc_list = []
    # for i in range(args.num_mali):
    #     tmp_acc = gnn_evaluate_accuracy(test_unchanged_loader_list[i], client[i].model)
    #     print('Malicious client %d with unchanged data, the unchanged clean accuracy: %.3f' % (i, tmp_acc))
    #     local_unchanged_acc_list.append(tmp_acc)
    # average_local_unchanged_acc = np.mean(np.array(local_unchanged_acc_list))

    transfer_attack_success_rate_list = []
    if args.num_workers - args.num_mali <= 0:
        # every client is malicious, so there is no clean client to transfer the trigger to
        average_transfer_attack_success_rate = -10000.0
    else:
        for i in range(args.num_mali):
            for j in range(args.num_workers - args.num_mali):
                tmp_acc = gnn_evaluate_accuracy(attack_loader_list[i], client[args.num_mali + j].model)
                print('Clean client %d with trigger %d: %.3f' % (args.num_mali + j, i, tmp_acc))
                transfer_attack_success_rate_list.append(tmp_acc)
        average_transfer_attack_success_rate = np.mean(np.array(transfer_attack_success_rate_list))

    return (average_all_clean_acc, average_local_attack_success_rate_acc, average_local_clean_acc,
            average_local_unchanged_acc, average_transfer_attack_success_rate)
if __name__ == '__main__':
    # main() expects parsed command-line args plus a logger exposing .log(dict)
    # (e.g. a wandb run); a minimal print-based stub is used here so the script
    # can run standalone. Swap in the project's real experiment logger as needed.
    class _PrintLogger:
        def log(self, metrics):
            print(metrics)

    main(args_parser(), _PrintLogger())