-
Notifications
You must be signed in to change notification settings - Fork 2
/
aggregation.py
51 lines (34 loc) · 1.76 KB
/
aggregation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
class AggregationStrategy:
    """Interface for combining per-client model updates into a single update.

    Concrete strategies override :meth:`aggregate`; this base version only
    signals that it has not been implemented.
    """

    def aggregate(self, clients_diff, clients_data_len):
        """Combine the clients' updates into one.

        Args:
            clients_diff: Sequence of per-client updates. The subclasses in
                this module treat each entry as a mapping from parameter name
                to tensor.
            clients_data_len: Per-client weights (presumably local dataset
                sizes — TODO confirm with callers); not every strategy uses it.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
class SimpleAverageStrategy(AggregationStrategy):
    """Unweighted, element-wise mean of the client updates."""

    def aggregate(self, clients_diff, clients_data_len):
        """Return the per-parameter mean of every client's update.

        Args:
            clients_diff: Non-empty sequence of dicts mapping parameter name to
                tensor. The first client's keys define the output; every client
                is expected to provide the same keys.
            clients_data_len: Ignored; accepted so the signature matches the
                other strategies.

        Returns:
            Dict mapping each parameter name to the mean tensor. Integer-typed
            entries are promoted to floating point by the division.

        Raises:
            ValueError: if ``clients_diff`` is empty.
        """
        if not clients_diff:
            raise ValueError("clients_diff must contain at least one client update")
        clients_num = len(clients_diff)
        weight_accumulator = {
            name: torch.zeros_like(params) for name, params in clients_diff[0].items()
        }
        for client_diff in clients_diff:
            for name, params in client_diff.items():
                # Out-of-place add: for integer tensors (such as a BatchNorm
                # step counter) ``params / clients_num`` is floating point and
                # an in-place ``add_`` into the integer accumulator raises;
                # ``+`` promotes the result dtype instead.
                weight_accumulator[name] = weight_accumulator[name] + params / clients_num
        return weight_accumulator
class FedAverageStrategy(AggregationStrategy):
    """FedAvg: mean of the client updates weighted by ``clients_data_len``."""

    def aggregate(self, clients_diff, clients_data_len):
        """Return the weighted per-parameter mean of the client updates.

        Client ``i`` contributes ``clients_diff[i] * clients_data_len[i] / total``,
        where ``total`` is ``sum(clients_data_len)``.

        Args:
            clients_diff: Non-empty sequence of dicts mapping parameter name to
                tensor. The first client's keys define the output; every client
                is expected to provide the same keys.
            clients_data_len: Per-client weights, one per entry of
                ``clients_diff`` (presumably local sample counts — TODO confirm).

        Returns:
            Dict mapping each parameter name to the weighted-mean tensor.

        Raises:
            ValueError: if ``clients_diff`` is empty, if the two sequences
                differ in length, or if the weights do not sum to a positive
                value.
        """
        if not clients_diff:
            raise ValueError("clients_diff must contain at least one client update")
        if len(clients_diff) != len(clients_data_len):
            # A length mismatch would either raise IndexError mid-loop or, when
            # clients_data_len is longer, silently skew the normalisation.
            raise ValueError("clients_diff and clients_data_len must have the same length")
        total_len = sum(clients_data_len)
        if total_len <= 0:
            # Dividing by a zero total would silently yield NaN/inf tensors.
            raise ValueError("clients_data_len must sum to a positive value")
        weight_accumulator = {
            name: torch.zeros_like(params) for name, params in clients_diff[0].items()
        }
        for client_diff, data_len in zip(clients_diff, clients_data_len):
            for name, params in client_diff.items():
                # Out-of-place add so integer-typed entries are promoted rather
                # than failing on an in-place cast.
                weight_accumulator[name] = weight_accumulator[name] + (params * data_len / total_len)
        return weight_accumulator
class MaxAverageStrategy(AggregationStrategy):
    """Element-wise selection of the client value with the largest magnitude.

    Despite the name, no averaging happens: each output element is the signed
    value, among all clients, whose absolute value is largest.
    """

    def aggregate(self, clients_diff, clients_data_len):
        """Return, per element, the client value of greatest magnitude.

        Args:
            clients_diff: Non-empty sequence of dicts mapping parameter name to
                tensor. The first client's keys define the output; every client
                is expected to provide the same keys.
            clients_data_len: Ignored; accepted so the signature matches the
                other strategies.

        Returns:
            Dict mapping each parameter name to a tensor holding, element-wise,
            the value with the largest absolute value across clients. Ties in
            magnitude go to the later client.

        Raises:
            ValueError: if ``clients_diff`` is empty.
        """
        if not clients_diff:
            raise ValueError("clients_diff must contain at least one client update")
        weight_accumulator = {
            name: torch.zeros_like(params) for name, params in clients_diff[0].items()
        }
        for client_diff in clients_diff:
            for name, params in client_diff.items():
                current = weight_accumulator[name]
                # Take the client's value where its magnitude is at least the
                # running maximum. (The previous code had these branches swapped,
                # so the zero-initialised accumulator was always kept.)
                weight_accumulator[name] = torch.where(
                    torch.abs(params) >= torch.abs(current), params, current
                )
        return weight_accumulator