-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevofed_parallel.py
175 lines (143 loc) · 7.75 KB
/
evofed_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import chex
import jax
import jax.numpy as jnp # JAX NumPy
import numpy as np
import tensorflow_datasets as tfds # TFDS for MNIST
import wandb
from evosax import NetworkMapper
from backprop import sl
from args import get_args
from utils import helpers, evo
from evosax import NetworkMapper, ParameterReshaper, FitnessShaper
from flax.core import FrozenDict
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
def cosine(x, y):
    """Return the cosine similarity between ``x`` and ``y``.

    Computed as ``sum(x * y) / (||x||_2 * ||y||_2)`` over all elements, so
    both inputs must be broadcast-compatible arrays. The result lies in
    ``[-1, 1]``; larger means more similar.

    The previous implementation divided by ``||x||_2 * ||x||_2`` (the norm of
    ``x`` twice), which is only a cosine when both inputs happen to have the
    same norm. The denominator now uses the norm of ``y`` as well.

    There is no guard against zero-norm inputs: a zero vector produces a
    division by zero (``nan``/``inf``), as before.
    """
    dot = jnp.sum(x * y)
    x_norm = jnp.sqrt(jnp.sum(x ** 2))
    y_norm = jnp.sqrt(jnp.sum(y ** 2))
    return dot / (x_norm * y_norm)
def cosine2(x, y):
    """Cosine similarity of ``x`` and ``y``: dot product over the product of L2 norms.

    All elements of both arrays take part in the sums. Zero-norm inputs are
    not guarded against and lead to a division by zero.
    """
    inner = jnp.sum(x * y)
    norm_x = jnp.sqrt(jnp.sum(x ** 2))
    norm_y = jnp.sqrt(jnp.sum(y ** 2))
    return inner / (norm_x * norm_y)
def l2(x, y):
    """Return the *negated* Euclidean distance between ``x`` and ``y``.

    The sign is flipped so the value can be used as a score to maximise:
    identical inputs give 0 and the score decreases as they move apart.
    """
    delta = x - y
    squared_sum = jnp.sum(delta ** 2)
    return -1 * jnp.sqrt(squared_sum)
def l1(x, y):
    """Return the negated L1 (sum of absolute differences) distance between ``x`` and ``y``.

    Negated so that closer inputs score higher; identical inputs give 0.
    """
    abs_delta = jnp.abs(x - y)
    return -1 * jnp.sum(abs_delta)
def pnorm(x, y, p):
    """Return the negated p-norm distance ``-(sum(|x - y| ** p)) ** (1 / p)``.

    ``p`` is used both as the element-wise exponent and, inverted, as the
    outer root. The result is negated so that closer inputs score higher.
    """
    abs_delta = jnp.abs(x - y)
    powered_sum = jnp.sum(abs_delta ** p)
    return -1 * powered_sum ** (1 / p)
def max_dist(x, y):
    """Blend the largest absolute difference with the L2 score of ``x`` and ``y``.

    Returns ``-0.02 * max(|x - y|) + 0.98 * l2(x, y)``. ``l2`` already
    returns a negated distance, so both terms are non-positive and a larger
    result means the inputs are closer.
    """
    largest_gap = jnp.max(jnp.abs(x - y))
    l2_score = l2(x, y)
    return -1 * 0.02 * largest_gap + 0.98 * l2_score
# TODO: unfinished `l2_std(x, y)` variant; the draft body was `return l2(x, y) + ...`
# and the second term was never written.
def sparsify(array, percentage):
    """Zero out the smallest-magnitude entries of ``array``.

    The magnitudes of all elements are sorted and the value found at
    position ``int(n * percentage)`` is used as a threshold; every element
    whose absolute value is strictly below that threshold is replaced by 0.
    The output has the same shape as ``array``.

    Args:
        array: array to sparsify (any shape).
        percentage: fraction of entries to drop, expected in ``[0, 1]``.
            Must be a concrete Python number because it is turned into a
            static index with ``int``.

    Returns:
        An array shaped like ``array`` with the small entries zeroed.

    Edge cases handled here that the original did not handle:
        * ``percentage == 1.0`` used to index one past the end of the sorted
          magnitudes. The index is now clamped to the last element, so the
          threshold becomes the largest magnitude and only entries tied for
          the maximum survive.
        * An empty ``array`` used to index position 0 of an empty array; it
          is now returned unchanged.
        * A negative ``percentage`` is clamped to index 0 (nothing below the
          minimum magnitude, so nothing is dropped) instead of wrapping
          around from the end.
    """
    magnitudes = jnp.sort(jnp.abs(array.flatten()))
    count = len(magnitudes)
    if count == 0:
        return array
    # Clamp so the threshold lookup is always a valid position.
    cut = max(0, min(int(count * percentage), count - 1))
    threshold = magnitudes[cut]
    return jnp.where(jnp.abs(array) < threshold, 0, array)
def quantize(array, min_val, max_val, n_bits):
    """Map ``array`` onto integer-valued levels of a uniform ``n_bits`` grid.

    The interval ``[min_val, max_val]`` is divided into ``2 ** n_bits - 1``
    equal steps; each element is shifted by ``min_val``, divided by the step
    and rounded. The range is supplied by the caller rather than taken from
    the data, and values outside it are not clipped, so they map to levels
    outside ``[0, 2 ** n_bits - 1]``.
    """
    n_steps = 2 ** n_bits - 1
    step = (max_val - min_val) / n_steps
    scaled = (array - min_val) / step
    return scaled.round()
def dequantize(array, min_val, max_val, n_bits):
    """Invert ``quantize``: turn level indices back into values in the original range.

    Uses the same step size, ``(max_val - min_val) / (2 ** n_bits - 1)``, and
    returns ``array * step + min_val``.
    """
    n_steps = 2 ** n_bits - 1
    step = (max_val - min_val) / n_steps
    return array * step + min_val
def pfun(x):
    """Identity: hand back ``x`` untouched.

    Exists so it can be wrapped by ``jax.pmap`` to place an already
    device-shaped array onto the devices without computing anything.
    """
    result = x
    return result
def vfun(x, y):
    """Ignore ``x`` and return ``y``.

    Mapping this over an index array ``x`` with ``y`` unmapped yields one
    copy of ``y`` per index, which is how it is used with ``jax.vmap`` /
    ``jax.pmap`` in this file.
    """
    del x  # only present to provide the mapped axis
    return y
# Number of devices attached to this process. TaskManager.run uses it as the
# leading (pmap) axis when splitting clients across devices.
num_devices = jax.local_device_count()
class TaskManager:
    """Run a federated experiment whose server model is updated by an evolution strategy.

    Each round (see ``run``):
      1. the current global train state is copied to every client, laid out
         as ``(num_devices, n_clients // num_devices, ...)``;
      2. clients are trained with ``sl.train_epoch_pmap``;
      3. a population is sampled from the strategy and each member is scored
         against each client's flattened parameters with ``l2`` (negated
         Euclidean distance);
      4. scores are shaped per client, averaged over all clients and fed to
         ``strategy.tell``;
      5. the strategy mean becomes the new global parameters, which are
         evaluated on the test set and logged to wandb.

    An active wandb run is required: the constructor renames the run and
    reads hyper-parameters from ``wandb.config``.
    """

    def __init__(self, rng: chex.PRNGKey, args):
        """Load data, build the network/train state and initialise the strategy.

        Args:
            rng: PRNG key supplied by the caller. NOTE: it is currently
                overwritten with ``PRNGKey(0)`` before use (see the review
                note in the body), so it has no effect.
            args: configuration object. Attributes read here: ``dataset``,
                ``algo``, ``dist``, ``batch_size``, ``n_clients``, ``seed``,
                ``percentage``, ``rank_factor``, ``quantize_bits``,
                ``pop_size``, ``centered_rank``, ``z_score``, ``w_decay``,
                ``maximize``. It is also forwarded whole to
                ``evo.get_strategy_and_params``.

        Also read from ``wandb.config``: ``lr``, ``momentum``,
        ``network_name``, ``network_config``, ``n_rounds``, ``batch_size``.
        """
        wandb.run.name = '{}-{}-{} b{} c{} s{} p{} r{} q{} -- {}' \
            .format(args.dataset, args.algo,
                    args.dist,
                    args.batch_size, args.n_clients,
                    args.seed,
                    args.percentage,
                    args.rank_factor,
                    args.quantize_bits,
                    wandb.run.id)
        wandb.run.save()

        # self.train_ds, self.test_ds = sl.get_datasets_non_iid(args.dataset, args.n_clients) \
        #     if args.dist == 'NON-IID' else sl.get_datasets_iid(args.dataset, args.n_clients)
        # NOTE(review): the meaning of the literal 2 is defined by
        # sl.get_fed_datasets_pmap and is not visible from this file.
        self.train_ds, self.test_ds = sl.get_fed_datasets_pmap(
            args.dataset, args.n_clients, 2, args.dist == 'IID')

        # NOTE(review): this discards the `rng` argument, so model and
        # strategy initialisation are identical for every seed. Kept as-is
        # because changing it would alter experiment results; confirm whether
        # the fixed key is intentional.
        rng = jax.random.PRNGKey(0)
        rng, init_rng = jax.random.split(rng)

        self.learning_rate = wandb.config.lr
        self.momentum = wandb.config.momentum
        network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)
        self.state = sl.create_train_state(init_rng, network, self.learning_rate, self.momentum)

        # Two reshapers over the same parameter tree: one to flatten client
        # networks, one to rebuild the global network from the strategy mean.
        self.param_reshaper = ParameterReshaper(self.state.params, n_devices=1)
        self.test_param_reshaper = ParameterReshaper(self.state.params, n_devices=1)

        self.strategy, self.es_params = evo.get_strategy_and_params(
            args.pop_size, self.param_reshaper.total_params, args)
        self.fit_shaper = FitnessShaper(centered_rank=args.centered_rank, z_score=args.z_score,
                                        w_decay=args.w_decay, maximize=args.maximize,
                                        rank_factor=args.rank_factor)

        # Start the search distribution centred on the freshly initialised
        # network parameters rather than on the strategy's own default mean.
        server = self.strategy.initialize(init_rng, self.es_params)
        self.server = server.replace(mean=self.test_param_reshaper.network_to_flat(self.state.params))
        del init_rng  # Must not be used anymore.

        # `jax.tree_leaves` was a deprecated top-level alias that newer JAX
        # releases no longer provide; `jax.tree_util.tree_leaves` is the
        # stable spelling and exists in old and new versions alike.
        self.param_count = sum(x.size for x in jax.tree_util.tree_leaves(self.state.params))
        self.num_epochs = wandb.config.n_rounds
        self.batch_size = wandb.config.batch_size
        self.n_clients = args.n_clients

        # min_cut = 10000
        # self.X = jnp.array([train['image'][:min_cut] for train in self.train_ds])
        # self.y = jnp.array([train['label'][:min_cut] for train in self.train_ds])
        self.X = self.train_ds['image']
        self.y = self.train_ds['label']
        self.args = args
        self.n_bits = args.quantize_bits

    def run(self, rng: chex.PRNGKey):
        """Execute rounds ``0 .. num_epochs`` (inclusive) of training and logging.

        Args:
            rng: PRNG key; split every round into keys for client training,
                population sampling and evaluation.

        Side effects: reshapes ``self.X`` / ``self.y`` in place to a
        per-device layout (so this method is meant to be called once),
        updates ``self.server`` and ``self.state`` every round, and logs one
        wandb record per round.

        ``n_clients`` must be divisible by the local device count, otherwise
        the reshape below fails.
        """
        clients_per_device = self.args.n_clients // num_devices

        # Split the leading client axis into (device, client-on-device) and
        # push the data through an identity pmap so it lives on the devices.
        self.X = jax.pmap(pfun)(self.X.reshape(num_devices, clients_per_device, *self.X.shape[1:]))
        self.y = jax.pmap(pfun)(self.y.reshape(num_devices, clients_per_device, *self.y.shape[1:]))

        for epoch in range(0, self.num_epochs + 1):
            rng, input_rng, rng_ask = jax.random.split(rng, 3)

            # Copy the global state once per client: the inner vmap adds the
            # client axis, the outer pmap adds the device axis.
            clients = jax.pmap(vfun, in_axes=(0, None))(
                jnp.arange(num_devices),
                jax.vmap(vfun, in_axes=(0, None))(jnp.arange(clients_per_device), self.state))
            # clients, _, _ = jax.vmap(sl.train_epoch, in_axes=(None, 0, 0, None, None))(self.state,
            #                                                                           self.X,
            #                                                                           self.y,
            #                                                                           self.batch_size, input_rng)
            clients, loss, acc = sl.train_epoch_pmap(clients, self.X, self.y, self.batch_size, input_rng)

            # Flatten every client's trained parameters into one vector.
            target_server = jax.pmap(jax.vmap(self.param_reshaper.network_to_flat))(clients.params)

            x, self.server = self.strategy.ask(rng_ask, self.server, self.es_params)

            # Score every population member (innermost vmap over x) against
            # every client (middle vmap) on every device (pmap); x itself is
            # shared, not mapped, at the two outer levels.
            fitness = jax.pmap(
                jax.vmap(jax.vmap(l2, in_axes=(0, None)), in_axes=(None, 0)),
                in_axes=(None, 0))(x, target_server)
            # Shape the scores separately for each client.
            fitness = jax.pmap(jax.vmap(self.fit_shaper.apply, in_axes=(None, 0)),
                               in_axes=(None, 0))(x, fitness)
            # fitness = jax.vmap(sparsify, in_axes=(0, None))(fitness, self.args.percentage)
            # fitness = jax.vmap(quantize, in_axes=(0, None, None, None))(fitness, -0.5, 0.5, self.n_bits)
            # Average over clients on each device, then over devices.
            fitness = jax.pmap(lambda f: f.mean(0))(fitness).mean(0)
            # fitness = dequantize(fitness, -0.5, 0.5, self.n_bits)
            # fitness = sparsify(fitness, self.args.percentage)

            self.server = self.strategy.tell(x, fitness, self.server, self.es_params)
            # The strategy mean is the new global model.
            self.state = self.state.replace(
                params=FrozenDict(self.test_param_reshaper.reshape_single_net(self.server.mean)))

            rng, eval_rng = jax.random.split(rng)
            test_loss, test_accuracy = sl.eval_model(self.state.params, self.test_ds, eval_rng)
            wandb.log({
                'Round': epoch,
                'Test Loss': test_loss,
                'Global Accuracy': test_accuracy,
                # 'Communication': epoch * 2 * self.args.pop_size,
                # 'Communication': epoch * 2 * self.args.pop_size * (1 - self.args.percentage) * (1 + np.log2(self.args.pop_size)),
                # 'Communication': epoch * 4 * self.args.pop_size * (1 - self.args.percentage) * (np.log2(self.args.pop_size * np.sqrt((1 - self.args.percentage) * 1/self.args.rank_factor))),
                'Communication': epoch * 2 * self.args.pop_size * (1 - self.args.percentage) * ((self.n_bits + np.log2(self.args.pop_size))/ 32),
            })
def run():
    """Entry point: parse arguments, start a wandb run and execute one experiment.

    The parsed command-line arguments seed ``wandb.config``, the file named
    by ``args.config`` is merged on top, and the merged ``wandb.config`` is
    what the rest of the program sees as ``args``. The PRNG key derived from
    ``args.seed`` is split into one key for ``TaskManager`` construction and
    one for its ``run`` method.
    """
    print(jax.devices())

    cli_args = get_args()
    file_config = helpers.load_config(cli_args.config)
    wandb.init(project='evofed-publish', config=cli_args)
    wandb.config.update(file_config)
    args = wandb.config

    base_key = jax.random.PRNGKey(args.seed)
    _, rng_init, rng_run = jax.random.split(base_key, 3)

    manager = TaskManager(rng_init, args)
    manager.run(rng_run)
# Run a single experiment when executed as a script.
if __name__ == '__main__':
    run()
    # Alternative launcher (disabled): run as a wandb sweep agent instead.
    # wandb.agent('y1lh8ou0', function=run, project='evofed', count=10)