Skip to content

Commit

Permalink
Revert "lint dev (#23)"
Browse files Browse the repository at this point in the history
This reverts commit e25e0a3.
  • Loading branch information
gautamjajoo authored Sep 13, 2024
1 parent 6745a28 commit 465184a
Show file tree
Hide file tree
Showing 26 changed files with 323 additions and 522 deletions.
1 change: 0 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"python.REPL.enableREPLSmartSend": false,
"python.analysis.typeCheckingMode": "strict"
}
88 changes: 34 additions & 54 deletions src/algos/DisPFL.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
"""This module defines the DisPFL optimization strategy for distributed machine learning."""
"""
This module defines the DisPFLClient and DisPFLServer classes for distributed personalized federated learning.
"""

import copy
import math
import random
from collections import OrderedDict
from typing import Any, Dict, List

import numpy as np
import torch
from torch import Tensor, from_numpy, randperm, zeros_like # pylint: disable=no-name-in-module
from torch import nn
import torch.nn as nn
from torch import Tensor, from_numpy, numel, randperm, zeros_like

from algos.base_class import BaseClient, BaseServer


class CommProtocol:
"""
Communication protocol tags for the server and clients.
"""
# pylint: disable=too-few-public-methods

DONE = 0 # Used to signal the server that the client is done with local training
START = 1 # Used to signal by the server to start the current round
Expand All @@ -29,18 +34,11 @@ class DisPFLClient(BaseClient):
"""
Client class for DisPFL (Distributed Personalized Federated Learning).
"""

def __init__(self, config) -> None:
super().__init__(config)
self.params = None
self.mask = None
self.index = None
self.repr = None
self.config = config
self.tag = CommProtocol
self.model_save_path = (
f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
)
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
self.dense_ratio = self.config["dense_ratio"]
self.anneal_factor = self.config["anneal_factor"]
self.dis_gradient_check = self.config["dis_gradient_check"]
Expand All @@ -66,7 +64,7 @@ def local_test(self, **kwargs):
test_loss, acc = self.model_utils.test(
self.model, self._test_loader, self.loss_fn, self.device
)
# save the model if the accuracy is better than the best accuracy so far
# TODO save the model if the accuracy is better than the best accuracy so far
# if acc > self.best_acc:
# self.best_acc = acc
# self.model_utils.save_model(self.model, self.model_save_path)
Expand Down Expand Up @@ -96,8 +94,7 @@ def fire_mask(self, masks, round_num, total_round):
"""
weights = self.get_representation()
drop_ratio = (
self.anneal_factor / 2 *
(1 + np.cos((round_num * np.pi) / total_round))
self.anneal_factor / 2 * (1 + np.cos((round_num * np.pi) / total_round))
)
new_masks = copy.deepcopy(masks)
num_remove = {}
Expand All @@ -109,7 +106,7 @@ def fire_mask(self, masks, round_num, total_round):
torch.abs(weights[name]),
100000 * torch.ones_like(weights[name]),
)
_, idx = torch.sort(temp_weights.view(-1).to(self.device))
x, idx = torch.sort(temp_weights.view(-1).to(self.device))
new_masks[name].view(-1)[idx[: num_remove[name]]] = 0
return new_masks, num_remove

Expand Down Expand Up @@ -147,8 +144,7 @@ def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd):
"""
count_mask = copy.deepcopy(masks_lstrnd[self.index])
for k in count_mask.keys():
count_mask[k] = count_mask[k] - \
count_mask[k] # zero out by pruning
count_mask[k] = count_mask[k] - count_mask[k] # zero out by pruning
for clnt in nei_indexes:
count_mask[k] += masks_lstrnd[clnt][k].to(self.device) # mask
for k in count_mask.keys():
Expand Down Expand Up @@ -180,8 +176,7 @@ def send_representations(self, representation):
Set the model.
"""
for client_node in self.clients:
self.comm_utils.send_signal(
client_node, representation, self.tag.UPDATES)
self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
print(f"Node 1 sent average weight to {len(self.clients)} nodes")

def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5):
Expand Down Expand Up @@ -221,8 +216,7 @@ def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5
else:
rhs += n_ones
raw_probabilities[name] = (
np.sum(params[name].shape) /
np.prod(params[name].shape)
np.sum(params[name].shape) / np.prod(params[name].shape)
) ** self.config["erk_power_scale"]
divisor += raw_probabilities[name] * n_param
epsilon = rhs / divisor
Expand All @@ -232,8 +226,7 @@ def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5
is_epsilon_valid = False
for mask_name, mask_raw_prob in raw_probabilities.items():
if mask_raw_prob == max_prob:
print(
f"Sparsity of var:{mask_name} had to be set to 0.")
print(f"Sparsity of var:{mask_name} had to be set to 0.")
dense_layers.add(mask_name)
else:
is_epsilon_valid = True
Expand Down Expand Up @@ -288,20 +281,19 @@ def hamming_distance(self, mask_a, mask_b):
total = 0
for key in mask_a:
dis += torch.sum(
mask_a[key].int().to(
self.device) ^ mask_b[key].int().to(self.device)
mask_a[key].int().to(self.device) ^ mask_b[key].int().to(self.device)
)
total += mask_a[key].numel()
return dis, total

def _benefit_choose(
self,
round_idx, # pylint: disable=unused-argument
round_idx,
cur_clnt,
client_num_in_total,
client_num_per_round,
dist_local, # pylint: disable=unused-argument
total_dist, # pylint: disable=unused-argument
dist_local,
total_dist,
cs=False,
active_ths_rnd=None,
):
Expand Down Expand Up @@ -339,8 +331,7 @@ def model_difference(self, model_a, model_b):
Calculate the difference between two models.
"""
diff = sum(
[torch.sum(torch.square(model_a[name] - model_b[name]))
for name in model_a]
[torch.sum(torch.square(model_a[name] - model_b[name])) for name in model_a]
)
return diff

Expand All @@ -355,7 +346,7 @@ def run_protocol(self):
self.params, sparse=self.dense_ratio
) # calculate sparsity to create masks
self.mask = self.init_masks(self.params, sparsities) # mask_per_local
dist_locals = np.zeros(shape=self.num_users)
dist_locals = np.zeros(shape=(self.num_users))
self.index = self.node_id - 1
masks_lstrnd = [self.mask for i in range(self.num_users)]
weights_lstrnd = [
Expand All @@ -366,8 +357,7 @@ def run_protocol(self):
]
for epoch in range(start_epochs, total_epochs):
# wait for signal to start round
active_ths_rnd = self.comm_utils.wait_for_signal(
src=0, tag=self.tag.START)
active_ths_rnd = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
if epoch != 0:
[weights_lstrnd, masks_lstrnd] = self.comm_utils.wait_for_signal(
src=0, tag=self.tag.LAST_ROUND
Expand Down Expand Up @@ -395,7 +385,8 @@ def run_protocol(self):
if self.num_users != self.config["neighbors"]:
nei_indexes = np.append(nei_indexes, self.index)
print(
f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}")
f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}"
)

for tmp_idx in nei_indexes:
if tmp_idx != self.index:
Expand All @@ -421,8 +412,7 @@ def run_protocol(self):
)
else:
new_repr = copy.deepcopy(weights_lstrnd[self.index])
w_per_globals[self.index] = copy.deepcopy(
weights_lstrnd[self.index])
w_per_globals[self.index] = copy.deepcopy(weights_lstrnd[self.index])
model_diff = self.model_difference(new_repr, self.repr)
print(f"Node {self.node_id} model_diff{model_diff}")
self.comm_utils.send_signal(
Expand All @@ -434,42 +424,34 @@ def run_protocol(self):
# locally train
print(f"Node {self.node_id} local train")
self.local_train()
_, acc = self.local_test()
loss, acc = self.local_test()
print(f"Node {self.node_id} local test: {acc}")
repr = self.get_representation()
if not self.config["static"]:
if not self.dis_gradient_check:
gradient = self.screen_gradient()
self.mask, num_remove = self.fire_mask(
self.mask, epoch, total_epochs)
self.mask, num_remove = self.fire_mask(self.mask, epoch, total_epochs)
self.mask = self.regrow_mask(self.mask, num_remove, gradient)
self.comm_utils.send_signal(
dest=0, data=copy.deepcopy(repr), tag=self.tag.SHARE_WEIGHTS
)

# test updated model
self.set_representation(repr)
_, acc = self.local_test()
loss, acc = self.local_test()
self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)


class DisPFLServer(BaseServer):
"""
Server class for DisPFL (Distributed Personalized Federated Learning).
"""

def __init__(self, config) -> None:
super().__init__(config)
self.best_acc = 0
self.round = 0
self.masks = 0
self.reprs = 0
self.config = config
self.set_model_parameters(config)
self.tag = CommProtocol
self.model_save_path = (
f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
)
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
self.dense_ratio = self.config["dense_ratio"]
self.num_users = self.config["num_users"]

Expand All @@ -484,8 +466,7 @@ def send_representations(self, representations):
Set the model.
"""
for client_node in self.users:
self.comm_utils.send_signal(
client_node, representations, self.tag.UPDATES)
self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
self.log_utils.log_console(
f"Server sent {len(representations)} representations to node {client_node}"
)
Expand All @@ -497,7 +478,7 @@ def test(self) -> float:
test_loss, acc = self.model_utils.test(
self.model, self._test_loader, self.loss_fn, self.device
)
# save the model if the accuracy is better than the best accuracy so far
# TODO save the model if the accuracy is better than the best accuracy so far
if acc > self.best_acc:
self.best_acc = acc
self.model_utils.save_model(self.model, self.model_save_path)
Expand Down Expand Up @@ -553,6 +534,5 @@ def run_protocol(self):

self.single_round(epoch, active_ths_rnd)

accs = self.comm_utils.wait_for_all_clients(
self.users, self.tag.FINISH)
accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
self.log_utils.log_console(f"Round {epoch} done; acc {accs}")
40 changes: 9 additions & 31 deletions src/algos/L2C.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Module docstring: This module implements the L2C algorithm for federated learning."""
from collections import OrderedDict, defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from utils.communication.comm_utils import CommunicationManager
import torch
import numpy as np
import torch
from torch import Tensor, cat, tensor, optim
import torch.nn as nn
import torch.nn.functional as F
from torch import optim, Tensor
from torch.nn import Module, Linear, ReLU

from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays
from algos.base_class import BaseFedAvgClient, BaseFedAvgServer
Expand Down Expand Up @@ -102,16 +100,9 @@ def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]])
for key in grad_dict.keys():
if key not in self.model_keys_to_ignore:
if self.sharing_mode == "updates":
cw_grad -= (
models_update_wts[id][key] *
grad_dict[key].cpu()
).sum()
elif self.sharing_mode == "weights":
cw_grad += (
models_update_wts[id][key] *
grad_dict[key].cpu()
).sum()
cw_grad -= (models_update_wts[id][key] * grad_dict[key].cpu()).sum()
elif self.sharing_mode == "weights":
cw_grad += (models_update_wts[id][key] * grad_dict[key].cpu()).sum()
else:
raise ValueError("Unknown sharing mode")
collab_weights_grads.append(cw_grad)
Expand Down Expand Up @@ -199,15 +190,9 @@ def run_protocol(self) -> None:
self.set_model_weights(new_wts, self.model_keys_to_ignore)
round_stats["test_acc"] = self.local_test()
round_stats["validation_loss"], round_stats["validation_acc"] = self.learn_collab_weights(models_update_wts)
round_stats["validation_loss"], round_stats["validation_acc"] = (
self.learn_collab_weights(models_update_wts)
)
print("node {} weight: {}".format(
self.node_id, self.collab_weights))

# Lower the number of neighbors
print(f"node {self.node_id} weight: {self.collab_weights}")
if round == self.config["T_0"]:

self.filter_out_worse_neighbors(self.config["target_users_after_T_0"])

self.comm_utils.send(dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS)
Expand Down Expand Up @@ -241,14 +226,8 @@ def single_round(self) -> list[dict[str, Any]]:

round_stats: list[dict[str, Any]] = self.comm_utils.all_gather(self.tag.ROUND_STATS)
self.log_utils.log_console("Server received all clients stats")

# Log the round stats on tensorboard
self.log_utils.log_tb_round_stats(
round_stats, ["collab_weights"], self.round)

self.log_utils.log_console(
f"Round test acc {[stats['test_acc'] for stats in round_stats]}"
)
self.log_utils.log_tb_round_stats(round_stats, ["collab_weights"], self.round)
self.log_utils.log_console(f"Round test acc {[stats['test_acc'] for stats in round_stats]}")

return round_stats

Expand All @@ -264,8 +243,7 @@ def run_protocol(self) -> None:
round_stats: list[dict[str, Any]] = self.single_round()
stats.append(round_stats)

stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(
stats)
stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats)
stats_dict["round_step"] = 1
self.log_utils.log_experiments_stats(stats_dict)
self.plot_utils.plot_experiments_stats(stats_dict)
Expand Down
Loading

0 comments on commit 465184a

Please sign in to comment.