From e4cf4e5932c031ccac661b554c9582601975f609 Mon Sep 17 00:00:00 2001 From: gautamjajoo Date: Tue, 8 Oct 2024 12:40:15 -0700 Subject: [PATCH] fix formatting issues --- src/algos/DisPFL.py | 51 ++- src/algos/L2C.py | 93 +++-- src/algos/MetaL2C.py | 14 +- src/algos/base_class.py | 62 ++- src/algos/def_kt.py | 46 ++- src/algos/fedfomo.py | 21 +- src/algos/fl.py | 84 ++-- src/algos/fl_assigned.py | 142 ++++--- src/algos/fl_central.py | 167 ++++---- src/algos/fl_data_repr.py | 801 ++++++++++++++++++++++++-------------- src/algos/fl_grid.py | 13 +- src/algos/fl_isolated.py | 14 +- src/algos/fl_random.py | 1 + src/algos/fl_static.py | 116 ++++-- src/algos/fl_torus.py | 14 +- src/algos/fl_val.py | 4 +- src/algos/fl_weight.py | 22 +- src/algos/generator.py | 18 +- src/algos/isolated.py | 2 +- src/algos/swarm.py | 17 +- 20 files changed, 1110 insertions(+), 592 deletions(-) diff --git a/src/algos/DisPFL.py b/src/algos/DisPFL.py index 310eb92..dbc1197 100644 --- a/src/algos/DisPFL.py +++ b/src/algos/DisPFL.py @@ -4,7 +4,6 @@ import copy import math -import random from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple @@ -21,7 +20,9 @@ class CommProtocol: Communication protocol tags for the server and clients. """ - DONE: int = 0 # Used to signal the server that the client is done with local training + DONE: int = ( + 0 # Used to signal the server that the client is done with local training + ) START: int = 1 # Used to signal by the server to start the current round UPDATES: int = 2 # Used to send the updates from the server to the clients LAST_ROUND: int = 3 @@ -92,7 +93,9 @@ def set_representation(self, representation: OrderedDict[str, Tensor]) -> None: """ self.model.load_state_dict(representation) - def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]: + def fire_mask( + self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int + ) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]: """ Fire mask method for model pruning. """ @@ -114,7 +117,12 @@ def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round new_masks[name].view(-1)[idx[: num_remove[name]]] = 0 return new_masks, num_remove - def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int], gradient: Optional[Dict[str, Tensor]] = None) -> OrderedDict[str, Tensor]: + def regrow_mask( + self, + masks: OrderedDict[str, Tensor], + num_remove: Dict[str, int], + gradient: Optional[Dict[str, Tensor]] = None, + ) -> OrderedDict[str, Tensor]: """ Regrow mask method for model pruning. """ @@ -142,7 +150,12 @@ def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int new_masks[name].view(-1)[idx] = 1 return new_masks - def aggregate(self, nei_indexes: List[int], weights_lstrnd: List[OrderedDict[str, Tensor]], masks_lstrnd: List[OrderedDict[str, Tensor]]) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]: + def aggregate( + self, + nei_indexes: List[int], + weights_lstrnd: List[OrderedDict[str, Tensor]], + masks_lstrnd: List[OrderedDict[str, Tensor]], + ) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]: """ Aggregate the model weights. """ @@ -183,7 +196,13 @@ def send_representations(self, representation: OrderedDict[str, Tensor]) -> None self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) print(f"Node 1 sent average weight to {len(self.clients)} nodes") - def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[str]] = None, distribution: str = "ERK", sparse: float = 0.5) -> Dict[str, float]: + def calculate_sparsities( + self, + params: Dict[str, Tensor], + tabu: Optional[List[str]] = None, + distribution: str = "ERK", + sparse: float = 0.5, + ) -> Dict[str, float]: """ Calculate sparsities for model pruning. """ @@ -243,7 +262,9 @@ def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[st sparsities[name] = 1 - epsilon * raw_probabilities[name] return sparsities - def init_masks(self, params: Dict[str, Tensor], sparsities: Dict[str, float]) -> OrderedDict[str, Tensor]: + def init_masks( + self, params: Dict[str, Tensor], sparsities: Dict[str, float] + ) -> OrderedDict[str, Tensor]: """ Initialize masks for model pruning. """ @@ -277,7 +298,9 @@ def screen_gradient(self) -> Dict[str, Tensor]: return gradient - def hamming_distance(self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor]) -> Tuple[int, int]: + def hamming_distance( + self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor] + ) -> Tuple[int, int]: """ Calculate the Hamming distance between two masks. """ @@ -330,7 +353,9 @@ def _benefit_choose( ) return client_indexes - def model_difference(self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor]) -> Tensor: + def model_difference( + self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor] + ) -> Tensor: """ Calculate the difference between two models. """ @@ -388,9 +413,7 @@ def run_protocol(self) -> None: ) if self.num_users != self.config["neighbors"]: nei_indexes = np.append(nei_indexes, self.index) - print( - f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}" - ) + print(f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}") for tmp_idx in nei_indexes: if tmp_idx != self.index: @@ -472,7 +495,9 @@ def get_representation(self) -> OrderedDict[str, Tensor]: """ return self.model.state_dict() - def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None: + def send_representations( + self, representations: Dict[int, OrderedDict[str, Tensor]] + ) -> None: """ Set the model. """ diff --git a/src/algos/L2C.py b/src/algos/L2C.py index eba9a5a..7b6579a 100644 --- a/src/algos/L2C.py +++ b/src/algos/L2C.py @@ -1,5 +1,5 @@ -from collections import OrderedDict, defaultdict -from typing import Any, Optional, Union +from collections import defaultdict +from typing import Any from utils.communication.comm_utils import CommunicationManager import torch import numpy as np @@ -12,7 +12,9 @@ class L2CClient(BaseFedAvgClient): - def __init__(self, config: dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.init_collab_weights() self.sharing_mode: str = self.config["sharing"] @@ -79,7 +81,9 @@ def filter_out_worse_neighbors(self, num_neighbors_to_keep: int) -> None: weight_decay=self.config["alpha_weight_decay"], ) - def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]]) -> tuple[float, float]: + def learn_collab_weights( + self, models_update_wts: dict[int, dict[str, Tensor]] + ) -> tuple[float, float]: self.model.eval() alpha_loss: float = 0 correct: int = 0 @@ -92,7 +96,9 @@ def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]]) loss = self.loss_fn(output, target) loss.backward() - grad_dict: dict[str, Tensor] = {k: v.grad for k, v in self.model.named_parameters()} + grad_dict: dict[str, Tensor] = { + k: v.grad for k, v in self.model.named_parameters() + } collab_weights_grads: list[Tensor] = [] for id in self.neighbors_id_to_idx.keys(): @@ -100,9 +106,13 @@ def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]]) for key in grad_dict.keys(): if key not in self.model_keys_to_ignore: if self.sharing_mode == "updates": - cw_grad -= (models_update_wts[id][key] * grad_dict[key].cpu()).sum() + cw_grad -= ( + models_update_wts[id][key] * grad_dict[key].cpu() + ).sum() elif self.sharing_mode == "weights": - cw_grad += (models_update_wts[id][key] * grad_dict[key].cpu()).sum() + cw_grad += ( + models_update_wts[id][key] * grad_dict[key].cpu() + ).sum() else: raise ValueError("Unknown sharing mode") collab_weights_grads.append(cw_grad) @@ -125,9 +135,13 @@ def print_GPU_memory(self) -> None: r: int = torch.cuda.memory_reserved(0) a: int = torch.cuda.memory_allocated(0) f: int = r - a # free inside reserved - print(f"Client {self.node_id} :GPU memory: reserved {r}, allocated {a}, free {f}") + print( + f"Client {self.node_id} :GPU memory: reserved {r}, allocated {a}, free {f}" + ) - def get_collaborator_weights(self, reprs_dict: dict[int, Tensor]) -> dict[int, float]: + def get_collaborator_weights( + self, reprs_dict: dict[int, Tensor] + ) -> dict[int, float]: """ Returns the weights of the collaborators for the current round. """ @@ -170,14 +184,24 @@ def run_protocol(self) -> None: round_stats: dict[str, Any] = {"collab_weights": cw} self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) - round_stats["train_loss"], round_stats["train_acc"] = self.local_train(epochs_per_round) + round_stats["train_loss"], round_stats["train_acc"] = self.local_train( + epochs_per_round + ) repr: dict[str, Tensor] = self.get_representation() - self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT) + self.comm_utils.send( + dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT + ) - reprs: list[dict[str, Tensor]] = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPRS_SHARE) - reprs_dict: dict[int, dict[str, Tensor]] = {k: v for k, v in enumerate(reprs, 1)} + reprs: list[dict[str, Tensor]] = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.REPRS_SHARE + ) + reprs_dict: dict[int, dict[str, Tensor]] = { + k: v for k, v in enumerate(reprs, 1) + } - collab_weights_dict: dict[int, float] = self.get_collaborator_weights(reprs_dict) + collab_weights_dict: dict[int, float] = self.get_collaborator_weights( + reprs_dict + ) models_update_wts: dict[int, dict[str, Tensor]] = reprs_dict new_wts: dict[str, Tensor] = self.weighted_aggregate( @@ -185,31 +209,43 @@ def run_protocol(self) -> None: ) if self.sharing_mode == "updates": - new_wts = self.model_utils.substract_model_weights(self.prev_model, new_wts) + new_wts = self.model_utils.substract_model_weights( + self.prev_model, new_wts + ) self.set_model_weights(new_wts, self.model_keys_to_ignore) round_stats["test_acc"] = self.local_test() - round_stats["validation_loss"], round_stats["validation_acc"] = self.learn_collab_weights(models_update_wts) + round_stats["validation_loss"], round_stats["validation_acc"] = ( + self.learn_collab_weights(models_update_wts) + ) print(f"node {self.node_id} weight: {self.collab_weights}") if round == self.config["T_0"]: self.filter_out_worse_neighbors(self.config["target_users_after_T_0"]) - self.comm_utils.send(dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS) + self.comm_utils.send( + dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS + ) class L2CServer(BaseFedAvgServer): - def __init__(self, config: dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.set_model_parameters(config) - self.model_save_path: str = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path: str = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) def test(self) -> float: """ Test the model on the server """ - test_loss, acc = self.model_utils.test(self.model, self._test_loader, self.loss_fn, self.device) + test_loss, acc = self.model_utils.test( + self.model, self._test_loader, self.loss_fn, self.device + ) return acc def single_round(self) -> list[dict[str, Any]]: @@ -218,16 +254,24 @@ def single_round(self) -> list[dict[str, Any]]: """ for client_node in self.users: self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START) - self.log_utils.log_console("Server waiting for all clients to finish local training") + self.log_utils.log_console( + "Server waiting for all clients to finish local training" + ) - reprs: list[dict[str, Tensor]] = self.comm_utils.all_gather(self.tag.REPR_ADVERT) + reprs: list[dict[str, Tensor]] = self.comm_utils.all_gather( + self.tag.REPR_ADVERT + ) self.log_utils.log_console("Server received all clients models") self.send_representations(reprs) - round_stats: list[dict[str, Any]] = self.comm_utils.all_gather(self.tag.ROUND_STATS) + round_stats: list[dict[str, Any]] = self.comm_utils.all_gather( + self.tag.ROUND_STATS + ) self.log_utils.log_console("Server received all clients stats") self.log_utils.log_tb_round_stats(round_stats, ["collab_weights"], self.round) - self.log_utils.log_console(f"Round test acc {[stats['test_acc'] for stats in round_stats]}") + self.log_utils.log_console( + f"Round test acc {[stats['test_acc'] for stats in round_stats]}" + ) return round_stats @@ -247,4 +291,3 @@ def run_protocol(self) -> None: stats_dict["round_step"] = 1 self.log_utils.log_experiments_stats(stats_dict) self.plot_utils.plot_experiments_stats(stats_dict) - diff --git a/src/algos/MetaL2C.py b/src/algos/MetaL2C.py index a36c525..3b06fad 100644 --- a/src/algos/MetaL2C.py +++ b/src/algos/MetaL2C.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager import math import torch @@ -77,7 +77,9 @@ def forward(self, model_dict): class MetaL2CClient(BaseFedAvgClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.encoder = ModelEncoder(self.get_model_weights()) @@ -242,9 +244,7 @@ def run_protocol(self): ) # Collect the representations from all other nodes from the server - reprs = self.comm_utils.receive( - self.server_node, tag=self.tag.REPRS_SHARE - ) + reprs = self.comm_utils.receive(self.server_node, tag=self.tag.REPRS_SHARE) reprs_dict = {k: rep for k, (rep, _) in enumerate(reprs, 1)} # Aggregate the representations based on the collab wheigts @@ -296,7 +296,9 @@ def run_protocol(self): class MetaL2CServer(BaseFedAvgServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 4df105b..75645ca 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -28,24 +28,27 @@ get_dset_balanced_communities, get_dset_communities, ) -import torchvision.transforms as T # type: ignore +import torchvision.transforms as T # type: ignore import os from yolo import YOLOLoss + class BaseNode(ABC): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: self.comm_utils = comm_utils self.node_id = self.comm_utils.get_rank() if self.node_id == 0: - self.log_dir = config['log_path'] - config['log_path'] = f'{self.log_dir}/server' + self.log_dir = config["log_path"] + config["log_path"] = f"{self.log_dir}/server" try: - os.mkdir(config['log_path']) + os.mkdir(config["log_path"]) except FileExistsError: pass - config['load_existing'] = False + config["load_existing"] = False self.log_utils = LogUtils(config) self.log_utils.log_console("Config: {}".format(config)) self.plot_utils = PlotUtils(config) @@ -107,7 +110,7 @@ def set_model_parameters(self, config: Dict[str, Any]) -> None: lr=config["model_lr"], weight_decay=config.get("weight_decay", 0), ) - if config.get('dset') == "pascal": + if config.get("dset") == "pascal": self.loss_fn = YOLOLoss() else: self.loss_fn = torch.nn.CrossEntropyLoss() @@ -158,7 +161,9 @@ class BaseClient(BaseNode): Abstract class for all algorithms """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.server_node = 0 self.set_parameters(config) @@ -382,7 +387,9 @@ def local_test(self, **kwargs: Any) -> float | Tuple[float, float] | None: """ raise NotImplementedError - def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor: + def get_representation( + self, **kwargs: Any + ) -> OrderedDict[str, Tensor] | List[Tensor] | Tensor: """ Share the model representation """ @@ -391,7 +398,9 @@ def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor] | List[T def run_protocol(self) -> None: raise NotImplementedError - def print_data_summary(self, train_test: Any, test_dset: Any, val_dset: Optional[Any] = None) -> None: + def print_data_summary( + self, train_test: Any, test_dset: Any, val_dset: Optional[Any] = None + ) -> None: """ Print the data summary """ @@ -435,7 +444,9 @@ class BaseServer(BaseNode): Abstract class for orchestrator """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.num_users = config["num_users"] self.users = list(range(1, self.num_users + 1)) @@ -446,7 +457,9 @@ def set_data_parameters(self, config: Dict[str, Any]) -> None: batch_size = config["batch_size"] self._test_loader = DataLoader(test_dset, batch_size=batch_size) - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: + def aggregate( + self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any + ) -> OrderedDict[str, Tensor]: """ Aggregate the knowledge from the users """ @@ -485,7 +498,13 @@ class BaseFedAvgClient(BaseClient): """ Abstract class for FedAvg based algorithms """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol] = CommProtocol) -> None: + + def __init__( + self, + config: Dict[str, Any], + comm_utils: CommunicationManager, + comm_protocol: type[CommProtocol] = CommProtocol, + ) -> None: super().__init__(config, comm_utils) self.config = config self.model_save_path = "{}/saved_models/node_{}.pt".format( @@ -541,7 +560,9 @@ def get_model_weights(self) -> OrderedDict[str, Tensor]: """ return {k: v.cpu() for k, v in self.model.state_dict().items()} - def set_model_weights(self, model_wts: OrderedDict[str, Tensor], keys_to_ignore: List[str] = []) -> None: + def set_model_weights( + self, model_wts: OrderedDict[str, Tensor], keys_to_ignore: List[str] = [] + ) -> None: """ Set the model weights """ @@ -659,10 +680,17 @@ class BaseFedAvgServer(BaseServer): """ Abstract class for orchestrator """ - - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol] = CommProtocol) -> None: + + def __init__( + self, + config: Dict[str, Any], + comm_utils: CommunicationManager, + comm_protocol: type[CommProtocol] = CommProtocol, + ) -> None: super().__init__(config, comm_utils) self.tag = comm_protocol - def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]): + def send_representations( + self, representations: Dict[int, OrderedDict[str, Tensor]] + ): self.comm_utils.broadcast(representations) diff --git a/src/algos/def_kt.py b/src/algos/def_kt.py index e721965..4680723 100644 --- a/src/algos/def_kt.py +++ b/src/algos/def_kt.py @@ -19,7 +19,9 @@ class CommProtocol: Communication protocol tags for the server and clients """ - DONE: int = 0 # Used to signal the server that the client is done with local training + DONE: int = ( + 0 # Used to signal the server that the client is done with local training + ) START: int = 1 # Used to signal by the server to start the current round UPDATES: int = 2 # Used to send the updates from the server to the clients FINISH: int = 3 # Used to signal the server to finish the current round @@ -29,11 +31,16 @@ class DefKTClient(BaseClient): """ Client class for DefKT (Deep Mutual Learning with Knowledge Transfer) """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) self.server_node = 1 # leader node self.best_acc = 0.0 # Initialize best accuracy attribute if self.node_id == 1: @@ -83,7 +90,9 @@ def set_representation(self, representation: OrderedDict[str, Tensor]) -> None: """ self.model.load_state_dict(representation) - def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]) -> OrderedDict[str, Tensor]: + def fed_avg( + self, model_wts: List[OrderedDict[str, Tensor]] + ) -> OrderedDict[str, Tensor]: """ Federated averaging of model weights """ @@ -101,7 +110,9 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]) -> OrderedDict[str, avgd_wts[key] += coeff * local_wts[key].to(self.device) return avgd_wts - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]) -> OrderedDict[str, Tensor]: + def aggregate( + self, representation_list: List[OrderedDict[str, Tensor]] + ) -> OrderedDict[str, Tensor]: """ Aggregate the model weights """ @@ -116,7 +127,9 @@ def send_representations(self, representation: OrderedDict[str, Tensor]) -> None self.comm_utils.send(client_node, representation, tag=self.tag.UPDATES) print(f"Node 1 sent average weight to {len(self.clients)} nodes") - def single_round(self, self_repr: OrderedDict[str, Tensor]) -> OrderedDict[str, Tensor]: + def single_round( + self, self_repr: OrderedDict[str, Tensor] + ) -> OrderedDict[str, Tensor]: """ Runs a single training round """ @@ -157,11 +170,15 @@ def run_protocol(self) -> None: if self.status == "teacher": self.local_train() self_repr = self.get_representation() - self.comm_utils.send(dest=self.pair_id, data=self_repr, tag=self.tag.DONE) + self.comm_utils.send( + dest=self.pair_id, data=self_repr, tag=self.tag.DONE + ) print(f"Node {self.node_id} sent repr to student node {self.pair_id}") elif self.status == "student": teacher_repr = self.comm_utils.receive(self.pair_id, tag=self.tag.DONE) - print(f"Node {self.node_id} received repr from teacher node {self.pair_id}") + print( + f"Node {self.node_id} received repr from teacher node {self.pair_id}" + ) self.deep_mutual_train(teacher_repr) else: print(f"Node {self.node_id} do nothing") @@ -174,15 +191,22 @@ class DefKTServer(BaseServer): """ Server class for DefKT (Deep Mutual Learning with Knowledge Transfer) """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) self.best_acc = 0.0 # Initialize best accuracy attribute - def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None: + def send_representations( + self, representations: Dict[int, OrderedDict[str, Tensor]] + ) -> None: """ Send the model representations to the clients """ diff --git a/src/algos/fedfomo.py b/src/algos/fedfomo.py index bd106cf..18176ba 100644 --- a/src/algos/fedfomo.py +++ b/src/algos/fedfomo.py @@ -3,7 +3,6 @@ """ from collections import OrderedDict -from typing import List from torch import Tensor import torch import torch.nn as nn @@ -38,7 +37,9 @@ def __init__(self, config) -> None: super().__init__(config) self.config = config self.tag = CommProtocol - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) self.dense_ratio = self.config["dense_ratio"] self.anneal_factor = self.config["anneal_factor"] self.dis_gradient_check = self.config["dis_gradient_check"] @@ -124,9 +125,7 @@ def regrow_mask(self, masks, num_remove, gradient=None): torch.abs(gradient[name]).to(self.device), -100000 * torch.ones_like(gradient[name]).to(self.device), ) - _, idx = torch.sort( - temp.view(-1).to(self.device), descending=True - ) + _, idx = torch.sort(temp.view(-1).to(self.device), descending=True) new_masks[name].view(-1)[idx[: num_remove[name]]] = 1 else: temp = torch.where( @@ -199,7 +198,9 @@ def screen_gradient(self): log_probs = model.forward(x) loss = criterion(log_probs, labels.long()) loss.backward() - gradient = {name: param.grad.to("cpu") for name, param in self.model.named_parameters()} + gradient = { + name: param.grad.to("cpu") for name, param in self.model.named_parameters() + } return gradient def hamming_distance(self, mask_a, mask_b): @@ -325,9 +326,7 @@ def run_protocol(self): if self.num_users != self.config["neighbors"]: nei_indexs = np.append(nei_indexs, self.index) nei_indexs = np.sort(nei_indexs) - print( - f"Node {self.index}'s neighbors index: {[i + 1 for i in nei_indexs]}" - ) + print(f"Node {self.index}'s neighbors index: {[i + 1 for i in nei_indexs]}") weights_locals = self.update_weight( self.index, @@ -364,7 +363,9 @@ def __init__(self, config) -> None: self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) self.dense_ratio = self.config["dense_ratio"] self.num_users = self.config["num_users"] self.reprs = None diff --git a/src/algos/fl.py b/src/algos/fl.py index a8fc2d7..7e677ce 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -8,22 +8,27 @@ import os import time + class FedAvgClient(BaseClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config try: - config['log_path'] = f"{config['log_path']}/node_{self.node_id}" - os.makedirs(config['log_path']) + config["log_path"] = f"{config['log_path']}/node_{self.node_id}" + os.makedirs(config["log_path"]) except FileExistsError: - color_code = "\033[91m" # Red color - reset_code = "\033[0m" # Reset to default color - print(f"{color_code}Log directory for the node {self.node_id} already exists in {config['log_path']}") + color_code = "\033[91m" # Red color + reset_code = "\033[0m" # Reset to default color + print( + f"{color_code}Log directory for the node {self.node_id} already exists in {config['log_path']}" + ) print(f"Exiting to prevent accidental overwrite{reset_code}") sys.exit(1) - config['load_existing'] = False + config["load_existing"] = False self.client_log_utils = LogUtils(config) def local_train(self, round: int, **kwargs: Any): @@ -38,24 +43,33 @@ def local_train(self, round: int, **kwargs: Any): time_taken = end_time - start_time self.client_log_utils.log_console( - "Client {} finished training with loss {:.4f}, accuracy {:.4f}, time taken {:.2f} seconds".format(self.node_id, avg_loss, avg_accuracy, time_taken) + "Client {} finished training with loss {:.4f}, accuracy {:.4f}, time taken {:.2f} seconds".format( + self.node_id, avg_loss, avg_accuracy, time_taken + ) + ) + self.client_log_utils.log_summary( + "Client {} finished training with loss {:.4f}, accuracy {:.4f}, time taken {:.2f} seconds".format( + self.node_id, avg_loss, avg_accuracy, time_taken ) - self.client_log_utils.log_summary("Client {} finished training with loss {:.4f}, accuracy {:.4f}, time taken {:.2f} seconds".format(self.node_id, avg_loss, avg_accuracy, time_taken)) + ) - self.client_log_utils.log_tb(f"train_loss/client{self.node_id}", avg_loss, round) - self.client_log_utils.log_tb(f"train_accuracy/client{self.node_id}", avg_accuracy, round) + self.client_log_utils.log_tb( + f"train_loss/client{self.node_id}", avg_loss, round + ) + self.client_log_utils.log_tb( + f"train_accuracy/client{self.node_id}", avg_accuracy, round + ) def local_test(self, **kwargs: Any): """ Test the model locally, not to be used in the traditional FedAvg """ - pass def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]: """ Share the model weights """ - return self.model.state_dict() # type: ignore + return self.model.state_dict() # type: ignore def set_representation(self, representation: OrderedDict[str, Tensor]): """ @@ -71,18 +85,32 @@ def run_protocol(self): self.local_train(round) self.local_test() repr = self.get_representation() - - self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node)) + + self.client_log_utils.log_summary( + "Client {} sending done signal to {}".format( + self.node_id, self.server_node + ) + ) self.comm_utils.send(self.server_node, repr) - self.client_log_utils.log_summary("Client {} waiting to get new model from {}".format(self.node_id, self.server_node)) + self.client_log_utils.log_summary( + "Client {} waiting to get new model from {}".format( + self.node_id, self.server_node + ) + ) repr = self.comm_utils.receive(self.server_node) - self.client_log_utils.log_summary("Client {} received new model from {}".format(self.node_id, self.server_node)) + self.client_log_utils.log_summary( + "Client {} received new model from {}".format( + self.node_id, self.server_node + ) + ) self.set_representation(repr) # self.client_log_utils.log_summary("Round {} done for Client {}".format(round, self.node_id)) class FedAvgServer(BaseServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config @@ -119,14 +147,16 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): avgd_wts: OrderedDict[str, Tensor] = OrderedDict() for key in model_wts[0].keys(): - avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore + avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore # Move to GPU only after averaging for key in avgd_wts.keys(): avgd_wts[key] = avgd_wts[key].to(self.device) return avgd_wts - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any) -> OrderedDict[str, Tensor]: + def aggregate( + self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any + ) -> OrderedDict[str, Tensor]: """ Aggregate the model weights """ @@ -168,7 +198,7 @@ def single_round(self): # self.log_utils.log_console("Server received all clients done signal") avg_wts = self.aggregate(reprs) self.set_representation(avg_wts) - #Remove the signal file after confirming that all client paths have been created + # Remove the signal file after confirming that all client paths have been created if os.path.exists(self.folder_deletion_signal): os.remove(self.folder_deletion_signal) @@ -184,7 +214,15 @@ def run_protocol(self): loss, acc, time_taken = self.test() self.log_utils.log_tb(f"test_acc/clients", acc, round) self.log_utils.log_tb(f"test_loss/clients", loss, round) - self.log_utils.log_console("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) + self.log_utils.log_console( + "Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format( + round, acc, loss, time_taken + ) + ) # self.log_utils.log_summary("Round: {} test_acc:{:.4f}, test_loss:{:.4f}, time taken {:.2f} seconds".format(round, acc, loss, time_taken)) self.log_utils.log_console("Round {} complete".format(round)) - self.log_utils.log_summary("Round {} complete".format(round,)) + self.log_utils.log_summary( + "Round {} complete".format( + round, + ) + ) diff --git a/src/algos/fl_assigned.py b/src/algos/fl_assigned.py index d0afa3a..0c6261c 100644 --- a/src/algos/fl_assigned.py +++ b/src/algos/fl_assigned.py @@ -1,6 +1,6 @@ import numpy as np import math -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager from algos.base_class import BaseFedAvgClient, BaseFedAvgServer @@ -8,33 +8,44 @@ class FedAssClient(BaseFedAvgClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils=comm_utils) - + def get_collaborator_weights(self, num_collaborator, round): """ Returns the weights of the collaborators for the current round - """ - if self.config["strategy"] =="fixed": - collab_weights = {id: 1 for id in self.config["assigned_collaborators"][self.node_id]} - elif self.config["strategy"] =="direct_expo": - power = round % math.floor(math.log2(self.config["num_users"]-1)) - steps = math.pow(2,power) - collab_id = int(((self.node_id + steps) % self.config["num_users"]) + 1) + """ + if self.config["strategy"] == "fixed": + collab_weights = { + id: 1 for id in self.config["assigned_collaborators"][self.node_id] + } + elif self.config["strategy"] == "direct_expo": + power = round % math.floor(math.log2(self.config["num_users"] - 1)) + steps = math.pow(2, power) + collab_id = int(((self.node_id + steps) % self.config["num_users"]) + 1) collab_weights = {self.node_id: 1, collab_id: 1} - elif self.config["strategy"] =="random_among_assigned": - collab_weights = {k: 1 for k in np.random.choice(list(self.config["assigned_collaborators"][self.node_id]), size=num_collaborator, replace=False)} + elif self.config["strategy"] == "random_among_assigned": + collab_weights = { + k: 1 + for k in np.random.choice( + list(self.config["assigned_collaborators"][self.node_id]), + size=num_collaborator, + replace=False, + ) + } collab_weights[self.node_id] = 1 else: raise ValueError("Strategy not implemented") - + total = sum(collab_weights.values()) collab_weights = {id: w / total for id, w in collab_weights.items()} return collab_weights - + def get_representation(self): return self.get_model_weights() - + def run_protocol(self): print(f"Client {self.node_id} ready to start training") start_round = self.config.get("start_round", 0) @@ -42,69 +53,82 @@ def run_protocol(self): epochs_per_round = self.config["epochs_per_round"] for round in range(start_round, total_rounds): stats = {} - + # Wait on server to start the round self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) - + repr = self.get_representation() - self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT) + self.comm_utils.send( + dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT + ) + + # Collect the representations from all other nodes from the server + reprs = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.REPRS_SHARE + ) - # Collect the representations from all other nodes from the server - reprs = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPRS_SHARE) - # In the future this dict might be generated by the server to send only requested models - reprs_dict = {k:v for k,v in enumerate(reprs, 1)} - - num_collaborator = self.config[f"target_users_{'before' if round < self.config['T_0'] else 'after'}_T_0"] + reprs_dict = {k: v for k, v in enumerate(reprs, 1)} + + num_collaborator = self.config[ + f"target_users_{'before' if round < self.config['T_0'] else 'after'}_T_0" + ] # Aggregate the representations based on the collab weights collab_weights_dict = self.get_collaborator_weights(num_collaborator, round) - + collaborators = [k for k, w in collab_weights_dict.items() if w > 0] # If there are no collaborators, then the client does not update its model - if not (len(collaborators)==1 and collaborators[0]==self.node_id): + if not (len(collaborators) == 1 and collaborators[0] == self.node_id): # Since clients representations are also used to transmit knowledge # There is no need to fetch the server for the selected clients' knowledge models_wts = reprs_dict - - avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=self.model_keys_to_ignore) - + + avg_wts = self.weighted_aggregate( + models_wts, + collab_weights_dict, + keys_to_ignore=self.model_keys_to_ignore, + ) + # Average whole model by default self.set_model_weights(avg_wts, self.model_keys_to_ignore) - + stats["test_acc_before_training"] = self.local_test() - stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) - - # Test updated model + # Test updated model stats["test_acc_after_training"] = self.local_test() # Include collab weights in the stats collab_weight = np.zeros(self.config["num_users"]) - for k,v in collab_weights_dict.items(): - collab_weight[k-1] = v + for k, v in collab_weights_dict.items(): + collab_weight[k - 1] = v stats["collab_weights"] = collab_weight - self.comm_utils.send(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS) + self.comm_utils.send( + dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS + ) + class FedAssServer(BaseFedAvgServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils=comm_utils) # self.set_parameters() self.config = config self.set_model_parameters(config) - self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"], - self.node_id) - + self.model_save_path = "{}/saved_models/node_{}.pt".format( + self.config["results_path"], self.node_id + ) + def test(self) -> float: """ Test the model on the server """ - test_loss, acc = self.model_utils.test(self.model, - self._test_loader, - self.loss_fn, - self.device) + test_loss, acc = self.model_utils.test( + self.model, self._test_loader, self.loss_fn, self.device + ) # TODO save the model if the accuracy is better than the best accuracy so far if acc > self.best_acc: self.best_acc = acc @@ -115,28 +139,34 @@ def single_round(self): """ Runs the whole training procedure """ - + # Send signal to all clients to start local training for client_node in self.users: self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START) - self.log_utils.log_console("Server waiting for all clients to finish local training") - + self.log_utils.log_console( + "Server waiting for all clients to finish local training" + ) + # Collect models from all clients models = self.comm_utils.all_gather(self.tag.REPR_ADVERT) - self.log_utils.log_console("Server received all clients models") - + self.log_utils.log_console("Server received all clients models") + # Broadcast the models to all clients self.send_representations(models) - + # Collect round stats from all clients - round_stats = self.comm_utils.all_gather(self.tag.ROUND_STATS) - self.log_utils.log_console("Server received all clients stats") + round_stats = self.comm_utils.all_gather(self.tag.ROUND_STATS) + self.log_utils.log_console("Server received all clients stats") # Log the round stats on tensorboard except the collab weights self.log_utils.log_tb_round_stats(round_stats, ["collab_weights"], self.round) - self.log_utils.log_console(f"Round acc TALT {[stats['test_acc_after_training'] for stats in round_stats]}") - self.log_utils.log_console(f"Round acc TBLT {[stats['test_acc_before_training'] for stats in round_stats]}") + self.log_utils.log_console( + f"Round acc TALT {[stats['test_acc_after_training'] for stats in round_stats]}" + ) + self.log_utils.log_console( + f"Round acc TBLT {[stats['test_acc_before_training'] for stats in round_stats]}" + ) return round_stats @@ -155,6 +185,6 @@ def run_protocol(self): stats.append(round_stats) stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats) - stats_dict["round_step"] = 1 + stats_dict["round_step"] = 1 self.log_utils.log_experiments_stats(stats_dict) - self.plot_utils.plot_experiments_stats(stats_dict) \ No newline at end of file + self.plot_utils.plot_experiments_stats(stats_dict) diff --git a/src/algos/fl_central.py b/src/algos/fl_central.py index 82f9b07..159032e 100644 --- a/src/algos/fl_central.py +++ b/src/algos/fl_central.py @@ -1,7 +1,7 @@ from collections import OrderedDict import torch from torch import Tensor -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager from torch.utils.data import DataLoader, Subset @@ -10,37 +10,43 @@ from algos.base_class import BaseClient, BaseServer from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays + class CommProtocol(object): """ Communication protocol tags for the server and clients """ - SEND_DATA = 0 # Used to signal by the server to start + + SEND_DATA = 0 # Used to signal by the server to start SHARE_DATA = 1 SEND_MODEL = 2 SHARE_MODEL = 3 - ROUND_STATS = 4 # Used to signal the server the client is done + ROUND_STATS = 4 # Used to signal the server the client is done + class CentralizedCLient(BaseClient): - - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"], - self.node_id) - + self.model_save_path = "{}/saved_models/node_{}.pt".format( + self.config["results_path"], self.node_id + ) + self.central_client = sorted(self.communities[self.node_id])[0] self.is_central_client = self.node_id == self.central_client - + def local_train(self, epochs, data_loader): """ Train the model locally """ avg_loss, avg_acc = 0, 0 for epoch in range(epochs): - tr_loss, tr_acc = self.model_utils.train(self.model, self.optim, - data_loader, self.loss_fn, - self.device) + tr_loss, tr_acc = self.model_utils.train( + self.model, self.optim, data_loader, self.loss_fn, self.device + ) avg_loss += tr_loss avg_acc += tr_acc @@ -48,42 +54,41 @@ def local_train(self, epochs, data_loader): avg_acc /= epochs return avg_loss, avg_acc - + def local_test(self, **kwargs): """ Test the model locally, not to be used in the traditional FedAvg """ - test_loss, acc = self.model_utils.test(self.model, - self._test_loader, - self.loss_fn, - self.device) + test_loss, acc = self.model_utils.test( + self.model, self._test_loader, self.loss_fn, self.device + ) if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) return acc - + def get_model_weights(self) -> OrderedDict[str, Tensor]: """ Share the model weights (on the cpu) """ - return {k: v.cpu() for k, v in self.model.state_dict().items()} + return {k: v.cpu() for k, v in self.model.state_dict().items()} # return {k: v.cpu() for k, v in self.model.module.state_dict().items()} - + def set_model_weights(self, model_wts: OrderedDict[str, Tensor], keys_to_ignore=[]): """ Set the model weights """ model_wts = copy.copy(model_wts) - + if len(keys_to_ignore) > 0: for key in keys_to_ignore: model_wts.pop(key) - + for key in model_wts.keys(): model_wts[key] = model_wts[key].to(self.device) - - self.model.load_state_dict(model_wts, strict= len(keys_to_ignore) == 0) - + + self.model.load_state_dict(model_wts, strict=len(keys_to_ignore) == 0) + def mask_last_layer(self): wts = self.get_model_weights() keys = self.model_utils.get_last_layer_keys(wts) @@ -95,44 +100,61 @@ def mask_last_layer(self): def freeze_model_except_last_layer(self): wts = self.get_model_weights() keys = self.model_utils.get_last_layer_keys(wts) - + for name, param in self.model.module.named_parameters(): if name not in keys: param.requires_grad = False - + def unfreeze_model(self): for param in self.model.module.parameters(): - param.requires_grad = True + param.requires_grad = True + + def run_protocol(self): - def run_protocol(self): - # self.comm_utils.send_signal(dest=self.server_node, data=self.train_indices, tag=self.tag.SEND_DATA) - self.comm_utils.send(dest=self.server_node, data=(self.central_client, self.train_dset), tag=self.tag.SEND_DATA) - + self.comm_utils.send( + dest=self.server_node, + data=(self.central_client, self.train_dset), + tag=self.tag.SEND_DATA, + ) + global_dloader = None if self.is_central_client: - global_dset = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.SHARE_DATA) - #train_dset = Subset(self.dset_obj.train_dset, train_indices) + global_dset = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.SHARE_DATA + ) + # train_dset = Subset(self.dset_obj.train_dset, train_indices) batch_size = self.config["batch_size"] - global_dloader = DataLoader(global_dset, batch_size=batch_size, shuffle=True) + global_dloader = DataLoader( + global_dset, batch_size=batch_size, shuffle=True + ) - start_round = self.config.get("start_round", 0) total_rounds = self.config["rounds"] epochs_per_round = self.config["epochs_per_round"] - + for round in range(start_round, total_rounds): round_stats = {} - + # Train locally if self.is_central_client: - round_stats["train_loss"], round_stats["train_acc"] = self.local_train(epochs_per_round, global_dloader) + round_stats["train_loss"], round_stats["train_acc"] = self.local_train( + epochs_per_round, global_dloader + ) global_model = self.get_model_weights() - self.comm_utils.send(dest=self.server_node, data=(self.communities[self.node_id], global_model), tag=self.tag.SEND_MODEL) - self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.SHARE_MODEL) + self.comm_utils.send( + dest=self.server_node, + data=(self.communities[self.node_id], global_model), + tag=self.tag.SEND_MODEL, + ) + self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.SHARE_MODEL + ) else: round_stats["train_loss"], round_stats["train_acc"] = 0.0, 0.0 - global_model = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.SHARE_MODEL) + global_model = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.SHARE_MODEL + ) self.set_model_weights(global_model) # Test model @@ -142,46 +164,55 @@ def run_protocol(self): # self.freeze_model_except_last_layer() # self.local_train(1, self.dloader) round_stats["test_acc"] = self.local_test() - + # if self.node_id == self.config["central_client"]: # self.set_model_weights(global_model) - # self.unfreeze_model() - + # self.unfreeze_model() + # send stats to server - self.comm_utils.send(dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS) - -# Define a dataset class that takes a list of dataset as input + self.comm_utils.send( + dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS + ) + + +# Define a dataset class that takes a list of dataset as input class CentralizedServer(BaseServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config self.set_model_parameters(config) self.tag = CommProtocol - self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"], - self.node_id) + self.model_save_path = "{}/saved_models/node_{}.pt".format( + self.config["results_path"], self.node_id + ) + def run_protocol(self): self.log_utils.log_console("Starting centralised learning") - + # clients_samples = self.comm_utils.wait_for_all_clients(self.clients, self.tag.SEND_DATA) # samples = set() # for client_samples in clients_samples: # samples.update(client_samples) # samples = list(samples) - + clients_dset = self.comm_utils.all_gather(self.tag.SEND_DATA) - + # Regroup the datasets per central client (one central client per community) dset_per_central_client = {} for central_client, dset in clients_dset: if central_client not in dset_per_central_client: dset_per_central_client[central_client] = [] dset_per_central_client[central_client].append(dset) - + for central_client, dsets in dset_per_central_client.items(): dset = torch.utils.data.ConcatDataset(dsets) - self.comm_utils.send(dest=central_client, data=dset, tag=self.tag.SHARE_DATA) - + self.comm_utils.send( + dest=central_client, data=dset, tag=self.tag.SHARE_DATA + ) + self.log_utils.log_console("Starting random P2P collaboration") start_round = self.config.get("start_round", 0) total_round = self.config["rounds"] @@ -191,20 +222,24 @@ def run_protocol(self): for round in range(start_round, total_round): self.round = round self.log_utils.log_console("Starting round {}".format(round)) - - central_models = self.comm_utils.receive(node_ids=list(dset_per_central_client.keys()), tag=self.tag.SEND_MODEL) + + central_models = self.comm_utils.receive( + node_ids=list(dset_per_central_client.keys()), tag=self.tag.SEND_MODEL + ) for clients, model in central_models: for client in clients: - self.comm_utils.send(dest=client, data=model, tag=self.tag.SHARE_MODEL) - + self.comm_utils.send( + dest=client, data=model, tag=self.tag.SHARE_MODEL + ) + self.log_utils.log_console("Server waiting for all clients to finish") - + round_stats = self.comm_utils.all_gather(tag=self.tag.ROUND_STATS) stats.append(round_stats) - + print(f"Round test acc {[stats['test_acc'] for stats in round_stats]}") - + stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats) - stats_dict["round_step"] = 1 + stats_dict["round_step"] = 1 self.log_utils.log_experiments_stats(stats_dict) - self.plot_utils.plot_experiments_stats(stats_dict) \ No newline at end of file + self.plot_utils.plot_experiments_stats(stats_dict) diff --git a/src/algos/fl_data_repr.py b/src/algos/fl_data_repr.py index 1bdc773..5f5f767 100644 --- a/src/algos/fl_data_repr.py +++ b/src/algos/fl_data_repr.py @@ -3,12 +3,13 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager from algos.base_class import BaseFedAvgClient, BaseFedAvgServer from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays from torch.utils.data import DataLoader, Dataset + # from algos.modules import (DeepInversionFeatureHook, kl_loss_fn, kl_loss_pw_fn, # total_variation_loss, DistCorrelation) from sklearn.cluster import AffinityPropagation, MeanShift @@ -33,24 +34,37 @@ class CommProtocol(object): CONS_SHARE = 11 # Server shares consensus with clients -TWO_STEP_STRAT= ["CTAR_KL", "LTLR_KL", "CTLR_KL", - "train_loss_inv", "train_loss_sm", - "euclidean_pairwise_KL", - "dist_corr_AR", "dist_corr_LR", - "log_all_metrics"] +TWO_STEP_STRAT = [ + "CTAR_KL", + "LTLR_KL", + "CTLR_KL", + "train_loss_inv", + "train_loss_sm", + "euclidean_pairwise_KL", + "dist_corr_AR", + "dist_corr_LR", + "log_all_metrics", +] CONS_STEP = ["vote_1hop"] -SIM_SHARING = ["affinity_propagation_clustering", "mean_shift_clustering", - "club","sim_of_sim", "sim_averaging"] +SIM_SHARING = [ + "affinity_propagation_clustering", + "mean_shift_clustering", + "club", + "sim_of_sim", + "sim_averaging", +] class FedDataRepClient(BaseFedAvgClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils=comm_utils, comm_protocol=CommProtocol) self.tag = CommProtocol self.require_second_step = config["similarity_metric"] in TWO_STEP_STRAT - self.with_sim_sharing = self.config.get("consensus", "") in SIM_SHARING + self.with_sim_sharing = self.config.get("consensus", "") in SIM_SHARING if self.config.get("sim_running_average", 0) > 0: self.running_average = {} @@ -69,34 +83,38 @@ def set_reprs_parameters(self): self.lowest_inv_loss_inputs = None def get_dream_reprs(self, num_samples): - - #img_in = torch.randn([1] + self.config["inp_shape"]).to(self.device) - #reprs_shape = [num_samples] + list(self.model.module.get_activation_shape(self.config["reprs_position"], img_in)[1:]) - + + # img_in = torch.randn([1] + self.config["inp_shape"]).to(self.device) + # reprs_shape = [num_samples] + list(self.model.module.get_activation_shape(self.config["reprs_position"], img_in)[1:]) + if self.config["reprs_position"] == 0: reprs_shape = [num_samples] + self.config["inp_shape"] else: pos = self.config["reprs_position"] - reprs_shape = [num_samples] + [max((pos-1) * 64, 64), int(self.config["inp_shape"][1]/2**(max(pos-2, 0))), int(self.config["inp_shape"][2]/2**(max(pos-2, 0)))] - #reprs_shape = [128, 128, 16, 16] - + reprs_shape = [num_samples] + [ + max((pos - 1) * 64, 64), + int(self.config["inp_shape"][1] / 2 ** (max(pos - 2, 0))), + int(self.config["inp_shape"][2] / 2 ** (max(pos - 2, 0))), + ] + # reprs_shape = [128, 128, 16, 16] + inputs = torch.randn(reprs_shape).to(self.device).requires_grad_(True) inputs.retain_grad() opt = torch.optim.Adam([inputs], lr=self.config["inv_lr"]) out = None - + loss_r_feature_layers = [] for module in self.model.modules(): if isinstance(module, nn.BatchNorm2d): loss_r_feature_layers.append(DeepInversionFeatureHook(module)) - + round_lowest_inv_loss = np.inf for epoch in range(self.config["inv_epochs"] + 1): self.model.zero_grad() opt.zero_grad() - - #torch.cuda.empty_cache() + + # torch.cuda.empty_cache() acts = self.model(inputs, position=self.config["reprs_position"]) probs = torch.softmax(acts, dim=1) entropy = -torch.sum(probs * torch.log(probs + self.EPS), dim=1).mean() @@ -109,48 +127,50 @@ def get_dream_reprs(self, num_samples): ) loss = ( self.config["alpha_preds"] * entropy - + self.config["alpha_tv"] * total_variation_loss(inputs)#.to(entropy.device) - + self.config["alpha_l2"] * torch.linalg.norm(inputs)#.to(entropy.device) + + self.config["alpha_tv"] + * total_variation_loss(inputs) # .to(entropy.device) + + self.config["alpha_l2"] + * torch.linalg.norm(inputs) # .to(entropy.device) + self.config["alpha_f"] * loss_r_feature ) - if self.node_id ==1: + if self.node_id == 1: if torch.isnan(loss).any(): print("NaN loss encountered. Exiting training.") break - + loss.backward() opt.step() if self.node_id == 1: if epoch % 100 == 0: print("Epoch {} loss {}".format(epoch, loss.item())) - + if epoch == self.config["inv_epochs"]: out = acts.clone().detach().cpu() - + # Keep track of lowest round loss if loss.item() < round_lowest_inv_loss: round_lowest_inv_loss = loss.item() - + # Keep track of lowest overall loss if loss.item() < self.lowest_inv_loss: self.lowest_inv_loss = loss.item() if self.config.get("dreams_keep_best", False): self.lowest_inv_loss_inputs = inputs.clone().detach() - + for item in loss_r_feature_layers: item.close() - + self.round_stats["lowest_inv_loss"] = round_lowest_inv_loss out = F.log_softmax(out, dim=1) - + # Keep traxk of reprs class distribution class_count = np.zeros(out.shape[1]) temp = np.bincount(np.argmax(out.numpy(), axis=1)) - class_count[:temp.shape[0]] = temp - self.round_stats["inv class distribution"] = class_count + class_count[: temp.shape[0]] = temp + self.round_stats["inv class distribution"] = class_count - if self.config.get("dreams_keep_best", False): + if self.config.get("dreams_keep_best", False): return self.lowest_inv_loss_inputs.cpu(), out else: return inputs.detach().cpu(), out @@ -191,7 +211,9 @@ def get_representation(self): reprs = batch_x if reprs is None else torch.cat((reprs, batch_x), dim=0) out = batch_out if out is None else torch.cat((out, batch_out), dim=0) - reprs_y = labels if reprs_y is None else torch.cat((reprs_y, labels), dim=0) + reprs_y = ( + labels if reprs_y is None else torch.cat((reprs_y, labels), dim=0) + ) reprs = reprs[:num_repr_samples] out = out[:num_repr_samples] @@ -206,17 +228,26 @@ def get_representation(self): def get_second_step_representation(self, reprs_dict): if not self.require_second_step: return None - - if self.config["similarity_metric"] in ["train_loss_inv", "train_loss_sm", - "CTLR_KL", "CTAR_KL", - "dist_corr_AR", "dist_corr_LR", - "log_all_metrics", ]: + + if self.config["similarity_metric"] in [ + "train_loss_inv", + "train_loss_sm", + "CTLR_KL", + "CTAR_KL", + "dist_corr_AR", + "dist_corr_LR", + "log_all_metrics", + ]: return self.compute_softlabels(reprs_dict) elif self.config["similarity_metric"] in ["euclidean_pairwise_KL", "LTLR_KL"]: # LTLR is contained in CTCR from the point of view of the client receiving this representation return self.compute_CTCR_KL(reprs_dict) else: - raise ValueError("Similarity metric {} not implemented".format(self.config["similarity_metric"])) + raise ValueError( + "Similarity metric {} not implemented".format( + self.config["similarity_metric"] + ) + ) def flatten_repr(self, repr): params = [] @@ -230,21 +261,27 @@ def flatten_repr(self, repr): def compute_pseudo_grad_norm(self, prev_wts, new_wts): return np.linalg.norm(self.flatten_repr(prev_wts) - self.flatten_repr(new_wts)) - + # === SIMILARITY === # - def get_collaborators_similarity(self, round, reprs_dict, second_step_reprs_dict=None): + def get_collaborators_similarity( + self, round, reprs_dict, second_step_reprs_dict=None + ): sim_metric = self.config["similarity_metric"] if sim_metric == "CTCR_KL": sim_dict = self.compute_CTCR_KL(reprs_dict) elif sim_metric == "euclidean": sim_dict = self.compute_euclidean(reprs_dict) elif sim_metric.startswith("train_loss"): - sim_dict = self.compute_training_loss_similarity(second_step_reprs_dict, sim_metric) + sim_dict = self.compute_training_loss_similarity( + second_step_reprs_dict, sim_metric + ) elif sim_metric == "euclidean_pairwise_KL": sim_dict = self.compute_euclidan_pairwise_KL(second_step_reprs_dict) elif sim_metric.startswith("euclidean_on"): - sim_dict = self.compute_euclidean_sim_on_top_sim_profile(second_step_reprs_dict) + sim_dict = self.compute_euclidean_sim_on_top_sim_profile( + second_step_reprs_dict + ) elif sim_metric == "CTAR_KL": sim_dict = self.compute_CTAR_KL(second_step_reprs_dict) elif sim_metric == "LTLR_KL": @@ -255,36 +292,46 @@ def get_collaborators_similarity(self, round, reprs_dict, second_step_reprs_dict sim_dict = self.compute_AR_correlation(second_step_reprs_dict) elif sim_metric == "dist_corr_LR": sim_dict = self.compute_LR_correlation(second_step_reprs_dict) - elif sim_metric == "log_all_metrics": # Should be combined + elif sim_metric == "log_all_metrics": # Should be combined self.log_all_metrics(reprs_dict, second_step_reprs_dict) sim_dict = {self.node_id: 0} else: raise ValueError("Similarity metric {} not implemented".format(sim_metric)) - + num_round_avg = self.config.get("sim_running_average", 0) if num_round_avg > 0: - num_round_exclude, num_round_exclude_after_T0 = self.config.get("sim_exclude_first", (0,0)) + num_round_exclude, num_round_exclude_after_T0 = self.config.get( + "sim_exclude_first", (0, 0) + ) t_0 = self.config.get("T_0", None) - exclude_round = round < num_round_exclude or (t_0 and t_0 <= round and round < t_0 + num_round_exclude_after_T0) + exclude_round = round < num_round_exclude or ( + t_0 and t_0 <= round and round < t_0 + num_round_exclude_after_T0 + ) if self.node_id == 1: - print(f"Is round {round} excluded from similarity averaging ? {exclude_round}") + print( + f"Is round {round} excluded from similarity averaging ? {exclude_round}" + ) - for k,v in sim_dict.items(): + for k, v in sim_dict.items(): if not exclude_round: if k not in self.running_average: self.running_average[k] = [v] else: - self.running_average[k] = self.running_average[k][-num_round_avg:] + [v] - - if self.node_id == 1 and k==1: - print(f"Similarity average based on {len(self.running_average[k])} rounds ") + self.running_average[k] = self.running_average[k][ + -num_round_avg: + ] + [v] + + if self.node_id == 1 and k == 1: + print( + f"Similarity average based on {len(self.running_average[k])} rounds " + ) # If not previous round included return 0 sim_dict[k] = np.mean(self.running_average.get(k, [0])) self.log_clients_stats(sim_dict, "Similarity after running average") return sim_dict - + # === First representation step === # def compute_softlabels(self, reprs_dict): @@ -351,9 +398,9 @@ def get_sim_matrix(self, reprs_dict): sim_matrix = np.zeros((len(clients_ids), len(clients_ids))) for k1, v in reprs_dict.items(): for k2, v2 in v.items(): - sim_matrix[k1-1, k2-1] = v2 - return sim_matrix - + sim_matrix[k1 - 1, k2 - 1] = v2 + return sim_matrix + def compute_CTAR_KL(self, second_step_reprs_dict): # Compute the KL divergence between each two clients from their # softlabels on every clients' representations @@ -389,25 +436,31 @@ def compute_CTAR_KL(self, second_step_reprs_dict): return KL_dict def compute_training_loss_similarity(self, second_step_reprs_dict, sim_type): - sl_clients_on_own_data = {id: soft_label_dict[self.node_id] for id, soft_label_dict in second_step_reprs_dict.items()} - loss_dict = {id: self.loss_fn(softlabels, self.reprs_y).item() for id, softlabels in sl_clients_on_own_data.items()} - + sl_clients_on_own_data = { + id: soft_label_dict[self.node_id] + for id, soft_label_dict in second_step_reprs_dict.items() + } + loss_dict = { + id: self.loss_fn(softlabels, self.reprs_y).item() + for id, softlabels in sl_clients_on_own_data.items() + } + self.log_clients_stats(loss_dict, "Train loss LR") - - if sim_type == "train_loss_inv": # Take 1 - loss/total_losses + + if sim_type == "train_loss_inv": # Take 1 - loss/total_losses total = sum(loss_dict.values()) - sim_dict = {id: 1-v/total for id,v in loss_dict.items()} - elif sim_type == "train_loss_sm": # Take softmax of train loss + sim_dict = {id: 1 - v / total for id, v in loss_dict.items()} + elif sim_type == "train_loss_sm": # Take softmax of train loss loss_tensor = torch.zeros(len(loss_dict)) for id, loss in loss_dict.items(): - loss_tensor[id-1] = loss + loss_tensor[id - 1] = loss soft_loss = 1 - F.softmax(loss_tensor, dim=0) - sim_dict = {id : v.item() for id,v in enumerate(soft_loss, 1)} + sim_dict = {id: v.item() for id, v in enumerate(soft_loss, 1)} else: raise ValueError("Similarity type {} not implemented".format(sim_type)) - + self.log_clients_stats(sim_dict, "Train loss similarity") - + return sim_dict def compute_euclidan_pairwise_KL(self, second_step_reprs_dict): @@ -499,7 +552,7 @@ def compute_LR_correlation(self, second_step_reprs_dict): return sim_dict # === Client Selection === # - ''' + """ def select_top_k(self, collab_similarity, k, round, total_rounds): # Remove the nodes that are not in the same community @@ -719,29 +772,41 @@ def select_top_k(self, collab_similarity, k, round, total_rounds): for key in collab_similarity.keys() } return collab_weights, proba_dist - ''' + """ def select_collaborator(self, collab_similarity, k, round, total_rounds): - + # Remove the nodes that are not in the same community - collab_similarity = {key: value for key, value in collab_similarity.items() if key in self.communities[self.node_id]} - - # Similarity <= 0, means no collaboration - collab_similarity = {key: value for key, value in collab_similarity.items() if value > 0 and key != self.node_id} + collab_similarity = { + key: value + for key, value in collab_similarity.items() + if key in self.communities[self.node_id] + } - if k==0: + # Similarity <= 0, means no collaboration + collab_similarity = { + key: value + for key, value in collab_similarity.items() + if value > 0 and key != self.node_id + } + + if k == 0: selected_collab = [self.node_id] proba_dist = {self.node_id: 1} else: strategy = self.config.get("selection_strategy") temp = self.config.get("selection_temperature", 1) - + if strategy == "highest": - sorted_collab = sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True) + sorted_collab = sorted( + collab_similarity.items(), key=lambda item: item[1], reverse=True + ) selected_collab = [key for key, _ in sorted_collab][:k] proba_dist = {key: 1 for key in selected_collab} elif strategy == "lowest": - sorted_collab = sorted(collab_similarity.items(), key=lambda item: item[1], reverse=False) + sorted_collab = sorted( + collab_similarity.items(), key=lambda item: item[1], reverse=False + ) selected_collab = [key for key, _ in sorted_collab][:k] proba_dist = {key: 1 for key in selected_collab} elif strategy.endswith("sim_sampling"): @@ -752,105 +817,185 @@ def select_collaborator(self, collab_similarity, k, round, total_rounds): # temp *= np.exp(-round/total_rounds) if strategy == "sim_sampling": total = sum(collab_similarity.values()) - proba_dist = {key: sim/total for key, sim in collab_similarity.items()} + proba_dist = { + key: sim / total for key, sim in collab_similarity.items() + } elif strategy == "lower_exp_sim_sampling": - proba_dist = {key: np.exp(-value/temp) for key, value in collab_similarity.items()} - elif strategy == "higher_exp_sim_sampling": - proba_dist = {key: np.exp(value/temp) for key, value in collab_similarity.items()} + proba_dist = { + key: np.exp(-value / temp) + for key, value in collab_similarity.items() + } + elif strategy == "higher_exp_sim_sampling": + proba_dist = { + key: np.exp(value / temp) + for key, value in collab_similarity.items() + } elif strategy == "lower_lin_sim_sampling": total = sum(collab_similarity.values()) - proba_dist = {key:1 - value/total for key, value in collab_similarity.items()} - elif strategy == "higher_lin_sim_sampling": + proba_dist = { + key: 1 - value / total + for key, value in collab_similarity.items() + } + elif strategy == "higher_lin_sim_sampling": total = sum(collab_similarity.values()) - proba_dist = {key:value/total for key, value in collab_similarity.items()} + proba_dist = { + key: value / total for key, value in collab_similarity.items() + } else: - raise ValueError("Selection strategy {} not implemented".format(strategy)) + raise ValueError( + "Selection strategy {} not implemented".format(strategy) + ) proba_dist[self.node_id] = 0 total = sum(proba_dist.values()) - proba_dist = {key: value/total for key, value in proba_dist.items()} + proba_dist = {key: value / total for key, value in proba_dist.items()} items = list(proba_dist.items()) - selected_collab = list(np.random.choice([key for key, _ in items], k, p=[value for _, value in items], replace=False)) + selected_collab = list( + np.random.choice( + [key for key, _ in items], + k, + p=[value for _, value in items], + replace=False, + ) + ) elif strategy.endswith("top_x"): - + top_x = self.config.get("num_users_top_x", 0) - + if strategy == "growing_schedulded_top_x": - top_x = 1 + int(top_x * round/total_rounds) + top_x = 1 + int(top_x * round / total_rounds) self.round_stats["top_x"] = top_x # Get the top most similar clients - sorted_collab = [id for id, _ in sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True)] + sorted_collab = [ + id + for id, _ in sorted( + collab_similarity.items(), + key=lambda item: item[1], + reverse=True, + ) + ] top_x_collab = sorted_collab[:top_x] - + proba_dist = {key: 1 for key in top_x_collab} num_collab = min(k, len(top_x_collab)) - selected_collab = list(np.random.choice(top_x_collab, num_collab, replace=False)) + selected_collab = list( + np.random.choice(top_x_collab, num_collab, replace=False) + ) elif strategy == "eps_greedy": eps = self.config.get("eps_greedy", 0.1) if np.random.rand() < eps: - collab = [id for id in range(1, self.config["num_users"]+1) if id != self.node_id] + collab = [ + id + for id in range(1, self.config["num_users"] + 1) + if id != self.node_id + ] selected_collab = list(np.random.choice(collab, k, replace=False)) else: - sorted_collab = [id for id, _ in sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True)] + sorted_collab = [ + id + for id, _ in sorted( + collab_similarity.items(), + key=lambda item: item[1], + reverse=True, + ) + ] selected_collab = sorted_collab[:k] - proba_dist = {key: 1/k for key in selected_collab} + proba_dist = {key: 1 / k for key in selected_collab} elif strategy == "xth": - sorted_collab = [id for id, _ in sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True)] - + sorted_collab = [ + id + for id, _ in sorted( + collab_similarity.items(), + key=lambda item: item[1], + reverse=True, + ) + ] + xth = self.config.get("num_users_top_x", 0) - + if xth >= len(sorted_collab): - raise ValueError("xth {} must be smaller than the number of clients {}".format(xth, len(sorted_collab))) - - selected_collab = [sorted_collab[xth-1]] - proba_dist = {key: 1 if key == selected_collab[0] else 0 for key in sorted_collab} + raise ValueError( + "xth {} must be smaller than the number of clients {}".format( + xth, len(sorted_collab) + ) + ) + + selected_collab = [sorted_collab[xth - 1]] + proba_dist = { + key: 1 if key == selected_collab[0] else 0 for key in sorted_collab + } elif strategy == "uniform_rdm": - selected_collab = list(np.random.choice(list(collab_similarity.keys()), k, replace=False)) - proba_dist = {key: 1/k for key in selected_collab} + selected_collab = list( + np.random.choice(list(collab_similarity.keys()), k, replace=False) + ) + proba_dist = {key: 1 / k for key in selected_collab} else: - raise ValueError("Selection strategy {} not implemented".format(strategy)) - + raise ValueError( + "Selection strategy {} not implemented".format(strategy) + ) + selected_collab.append(self.node_id) - - collab_weights = {key: 1/len(selected_collab) if key in selected_collab else 0 for key in range(1, self.config["num_users"]+1)} + + collab_weights = { + key: 1 / len(selected_collab) if key in selected_collab else 0 + for key in range(1, self.config["num_users"] + 1) + } return collab_weights, proba_dist - + # === Consensus === # def get_consensus_similarity(self, collab_similarity, num_client_to_select): consensus_strategy = self.config.get("consensus", "") - if consensus_strategy =="": + if consensus_strategy == "": return collab_similarity if consensus_strategy == "sim_averaging": collab_similarity = self.weighted_sim_averaging(collab_similarity) elif consensus_strategy.endswith("clustering"): - collab_similarity = self.cluster_points(collab_similarity, num_client_to_select, consensus_strategy) + collab_similarity = self.cluster_points( + collab_similarity, num_client_to_select, consensus_strategy + ) elif consensus_strategy == "sim_of_sim": - collab_similarity = self.compute_euclidean_sim_on_top_sim_profile(collab_similarity) + collab_similarity = self.compute_euclidean_sim_on_top_sim_profile( + collab_similarity + ) elif consensus_strategy == "vote_1hop": collab_similarity = self.vote_consensus(collab_similarity) elif consensus_strategy == "club": collab_similarity = self.club_consensus(collab_similarity) else: - raise ValueError("Consensus strategy {} not implemented".format(consensus_strategy)) - + raise ValueError( + "Consensus strategy {} not implemented".format(consensus_strategy) + ) + self.log_clients_stats(collab_similarity, f"{consensus_strategy} similarity") return collab_similarity - + def weighted_sim_averaging(self, collab_similarity): if self.config["similarity_metric"] == "train_loss_inv": own_dict = collab_similarity[self.node_id] - + # Keep only k highest similarity - top_a_averaging =self.config.get("sim_consensus_top_a", self.config["num_users"]-1) - if top_a_averaging < self.config["num_users"]-1: - sorted_collab = sorted(own_dict.items(), key=lambda item: item[1], reverse=True) - sorted_collab = [(key,v) for key, v in sorted_collab if key != self.node_id][:top_a_averaging] - - filtered_own_weights = {key: value for key, value in own_dict.items() if key in sorted_collab} + top_a_averaging = self.config.get( + "sim_consensus_top_a", self.config["num_users"] - 1 + ) + if top_a_averaging < self.config["num_users"] - 1: + sorted_collab = sorted( + own_dict.items(), key=lambda item: item[1], reverse=True + ) + sorted_collab = [ + (key, v) for key, v in sorted_collab if key != self.node_id + ][:top_a_averaging] + + filtered_own_weights = { + key: value + for key, value in own_dict.items() + if key in sorted_collab + } total = sum(filtered_own_weights.values()) - self.log_clients_stats({key:value/total for key, value in filtered_own_weights.items()}, "Trust weights for consensus") + self.log_clients_stats( + {key: value / total for key, value in filtered_own_weights.items()}, + "Trust weights for consensus", + ) - new_dict = {id: 0 for id in own_dict.keys()} total = 0 for c_id, c_dict in collab_similarity.items(): @@ -861,91 +1006,115 @@ def weighted_sim_averaging(self, collab_similarity): new_dict[c1_id] += c_conf * own_dict[c_id] else: new_dict[c1_id] += c_conf * c1_score - total+=c_conf - collab_similarity = {key: v/total for key,v in new_dict.items()} - - for key,v in collab_similarity.items(): + total += c_conf + collab_similarity = {key: v / total for key, v in new_dict.items()} + + for key, v in collab_similarity.items(): if v > 1: print("Client {} collab {} sim {}".format(self.node_id, key, v)) else: - raise ValueError("Similarity consensus not implemented for {}".format(self.config["similarity_metric"])) - + raise ValueError( + "Similarity consensus not implemented for {}".format( + self.config["similarity_metric"] + ) + ) + return collab_similarity - + def vote_consensus(self, collab_similarity): - - sorted_collab = [id for id, _ in sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True) if id != self.node_id] - + + sorted_collab = [ + id + for id, _ in sorted( + collab_similarity.items(), key=lambda item: item[1], reverse=True + ) + if id != self.node_id + ] + num_voter, num_vote_per_voter = self.config["vote_consensus"] - + voters = sorted_collab[:num_voter] # + 1 to avoid receiving recommendation only about myself - self.comm_utils.send(dest=self.server_node, data=sorted_collab[:num_vote_per_voter+1], tag=self.tag.CONS_ADVERT) - vote_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.CONS_SHARE) - + self.comm_utils.send( + dest=self.server_node, + data=sorted_collab[: num_vote_per_voter + 1], + tag=self.tag.CONS_ADVERT, + ) + vote_dict = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.CONS_SHARE + ) + candidate_list = [] for voter_id, votes in vote_dict.items(): if voter_id in voters: - candidate_list += [v for v in votes if v != self.node_id][:num_vote_per_voter] - + candidate_list += [v for v in votes if v != self.node_id][ + :num_vote_per_voter + ] + candidate_count = {} - for i in candidate_list: + for i in candidate_list: candidate_count[i] = candidate_count.get(i, 0) + 1 - + self.log_clients_stats(candidate_count, "Vote count") - #sorted_collab = [id for id, value in sorted(candidate_count.items(), key=lambda item: item[1], reverse=True) if id != self.node_id] + # sorted_collab = [id for id, value in sorted(candidate_count.items(), key=lambda item: item[1], reverse=True) if id != self.node_id] sum_vote = sum(candidate_count.values()) - collab_similarity = {key: value/sum_vote for key, value in candidate_count.items()} + collab_similarity = { + key: value / sum_vote for key, value in candidate_count.items() + } return collab_similarity - + def club_consensus(self, collab_similarity): num_clients = self.config["num_users"] # Get similarity matrix collab_sim = np.zeros((num_clients, num_clients)) for k1, v in collab_similarity.items(): for k2, v2 in v.items(): - collab_sim[k1-1, k2-1] = v2 - - np.fill_diagonal(collab_sim, -np.inf) # Remove self similarity - sorted_clients = collab_sim.argsort(axis=1)[:,::-1] # Assume max is best - + collab_sim[k1 - 1, k2 - 1] = v2 + + np.fill_diagonal(collab_sim, -np.inf) # Remove self similarity + sorted_clients = collab_sim.argsort(axis=1)[:, ::-1] # Assume max is best + def accept_client(requester, member): - return requester in sorted_clients[member:,:self.config["club_k_accepted"]] - + return ( + requester in sorted_clients[member:, : self.config["club_k_accepted"]] + ) + WEAK_LINK = 0.5 STRONG_LINK = 1 # Compute club graph edges = np.zeros_like(collab_sim) for i in range(num_clients): for j in range(self.config.get("club_top_k", 1)): - if accept_client(i, sorted_clients[i,j]): - edges[i, sorted_clients[i,j]] = STRONG_LINK - edges[sorted_clients[i,j], i] = STRONG_LINK + if accept_client(i, sorted_clients[i, j]): + edges[i, sorted_clients[i, j]] = STRONG_LINK + edges[sorted_clients[i, j], i] = STRONG_LINK else: - edges[i, sorted_clients[i,j]] = WEAK_LINK - - own_idx = self.node_id-1 + edges[i, sorted_clients[i, j]] = WEAK_LINK + + own_idx = self.node_id - 1 self.round_stats["club"] = edges[own_idx] - + # Compute cluster with DFS visited = [False] * num_clients + def dfs(v, cluster): visited[v] = True cluster.append(v) for i in range(num_clients): if edges[v, i] == STRONG_LINK and not visited[i]: dfs(i, cluster) - + own_cluster = [] dfs(own_idx, own_cluster) - # If top 1 is in cluster, sample at random from cluster if sorted_clients[own_idx, 0] in own_cluster: # Assume k+1 clients in cluster (guaranteed with k=1) - return {id+1: 1 if id in own_cluster else 0 for id in collab_similarity.keys()} - + return { + id + 1: 1 if id in own_cluster else 0 for id in collab_similarity.keys() + } + else: own_cluster_wo_self = [id for id in own_cluster if id is not own_idx] # For weak link top link, find path to cluster following weak links @@ -955,13 +1124,13 @@ def dfs(v, cluster): pointing_to = [] dfs(sorted_clients[own_idx, 0], cluster) while len(cluster) == 1 and not all(visited): - next_client = cluster.pop() + next_client = cluster.pop() pointing_to.append(next_client) dfs(sorted_clients[next_client, 0], cluster) - + if cluster[0] not in pointing_to: - pointing_to+= cluster - + pointing_to += cluster + weak_link_strategy = self.config["club_weak_link_strategy"] if weak_link_strategy == "own_cluster_and_pointing_to": potential = own_cluster_wo_self + pointing_to @@ -973,52 +1142,57 @@ def dfs(v, cluster): elif weak_link_strategy == "pointing_to": potential = pointing_to else: - raise ValueError("Weak link strategy {} not implemented".format(weak_link_strategy)) - + raise ValueError( + "Weak link strategy {} not implemented".format(weak_link_strategy) + ) + # Shift index to client id - potential = [id+1 for id in potential] - + potential = [id + 1 for id in potential] + return {id: 1 if id in potential else 0 for id in collab_similarity.keys()} - + def cluster_points(self, collab_similarity, k, method): num_clients = self.config["num_users"] # Get similarity matrix collab_sim = np.zeros((num_clients, num_clients)) - + for k1, v in collab_similarity.items(): for k2, v2 in v.items(): - collab_sim[k1-1, k2-1] = v2 - self_idx = self.node_id-1 - + collab_sim[k1 - 1, k2 - 1] = v2 + self_idx = self.node_id - 1 + if method == "affinity_propagation_clustering": - + np.fill_diagonal(collab_sim, np.median(collab_sim)) - + # Compute clusters - affinity = 'precomputed' if self.config["affinity_precomputed"] else 'euclidean' - clustering = AffinityPropagation(random_state=5, affinity=affinity).fit(collab_sim) + affinity = ( + "precomputed" if self.config["affinity_precomputed"] else "euclidean" + ) + clustering = AffinityPropagation(random_state=5, affinity=affinity).fit( + collab_sim + ) pred = np.array(clustering.labels_) - #center_indices = clustering.cluster_centers_indices_ - + # center_indices = clustering.cluster_centers_indices_ + elif method == "mean_shift_clustering": clustering = MeanShift().fit(collab_sim) pred = np.array(clustering.labels_) else: raise ValueError("Clustering method {} not implemented".format(method)) - + self.round_stats[method] = pred # Get clients in the same cluster own_cluster_idx = pred[self_idx] own_cluster = [idx for idx in list(np.where(pred == own_cluster_idx)[0])] - + own_cluster_wo_self = [id for id in own_cluster if id != self_idx] - - potential_collab = own_cluster_wo_self + potential_collab = own_cluster_wo_self # num_collab = np.round(self.config.get("ap_cluster_collab", 1) * len(own_cluster_wo_self)) - #cluster_leader = [idx for idx in clustering.cluster_centers_indices_ if idx in own_cluster][0] + # cluster_leader = [idx for idx in clustering.cluster_centers_indices_ if idx in own_cluster][0] # Select top num_collab clients to collaborate with # Ranking is based on their similarity with leader @@ -1026,55 +1200,63 @@ def cluster_points(self, collab_similarity, k, method): # cluster_with_leader_sim = [(idx, collab_sim[cluster_leader, idx]) for idx in own_cluster_wo_self] # sorted_collab = [id for id, value in sorted(cluster_with_leader_sim, key=lambda item: item[1], reverse=True)] # potential_collab = sorted_collab[:num_collab] - + # If alone in a cluster, collaborate with closest neighbors instead if len(potential_collab) < 1: - + # +1 to be able to select k clients even if self is in top k sorted_collab = np.argsort(collab_sim[self_idx])[::-1] - potential_collab = (sorted_collab[:k+1]).tolist() + potential_collab = (sorted_collab[: k + 1]).tolist() potential_collab = [id for id in potential_collab if id != self_idx][:k] - print("Client {} has no cluster, found {} closest instead".format(self.node_id, len(potential_collab))) - + print( + "Client {} has no cluster, found {} closest instead".format( + self.node_id, len(potential_collab) + ) + ) + # Shift index to client id - potential_collab = [id+1 for id in potential_collab] - + potential_collab = [id + 1 for id in potential_collab] + # TODO Instead of 1 return similarity with leader/cluster - sim_dict = {id: 1 if id in potential_collab else 0 for id in collab_similarity.keys()} - - - #selected_collab = list(np.random.choice(potential_collab, k, replace=False)) - + sim_dict = { + id: 1 if id in potential_collab else 0 for id in collab_similarity.keys() + } + + # selected_collab = list(np.random.choice(potential_collab, k, replace=False)) + # From matrix idx to client idx - #selected_collab = [id+1 for id in selected_collab] - #proba_dist = {id: 1 for id in selected_collab} - - return sim_dict #, proba_dist - + # selected_collab = [id+1 for id in selected_collab] + # proba_dist = {id: 1 for id in selected_collab} + + return sim_dict # , proba_dist + def compute_euclidean_sim_on_top_sim_profile(self, similarity_profile_per_client): clients_ids = sorted(similarity_profile_per_client[self.node_id].keys()) - - own_idx = self.node_id-1 + + own_idx = self.node_id - 1 sim_matrix = self.get_sim_matrix(similarity_profile_per_client) - - # TODO Might want to remove similarity of other clients to themselves and to self + + # TODO Might want to remove similarity of other clients to themselves and to self sim_matrix -= sim_matrix[own_idx] - euclidean_dist = (sim_matrix ** 2).sum(axis=1) - - sim_dict = {id+1: 1-euclidean_dist[id]/euclidean_dist.sum() for id in range(len(clients_ids))} - - self.log_clients_stats(sim_dict,"Similarity of similarity") + euclidean_dist = (sim_matrix**2).sum(axis=1) + + sim_dict = { + id + 1: 1 - euclidean_dist[id] / euclidean_dist.sum() + for id in range(len(clients_ids)) + } + + self.log_clients_stats(sim_dict, "Similarity of similarity") return sim_dict # === Logging === # - + def log_all_metrics(self, reprs_dict, second_step_reprs_dict): # Computing will automatically log all metrics - self.compute_CTAR_KL(second_step_reprs_dict) - self.compute_training_loss_similarity(second_step_reprs_dict, "train_loss_inv") - self.compute_training_loss_similarity(second_step_reprs_dict, "train_loss_sm") - self.compute_CTLR_KL(second_step_reprs_dict) - self.compute_CTCR_KL(reprs_dict) + self.compute_CTAR_KL(second_step_reprs_dict) + self.compute_training_loss_similarity(second_step_reprs_dict, "train_loss_inv") + self.compute_training_loss_similarity(second_step_reprs_dict, "train_loss_sm") + self.compute_CTLR_KL(second_step_reprs_dict) + self.compute_CTCR_KL(reprs_dict) self.compute_euclidean(reprs_dict) def log_clients_stats(self, client_dict, stat_name): @@ -1092,88 +1274,143 @@ def run_protocol(self): collab_weights_dict = {} for round in range(start_round, total_rounds): self.round_stats = {} - + # Wait on server to start the round self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) - + if round == start_round: warmup_epochs = self.config.get("warmup_epochs", epochs_per_round) if warmup_epochs > 0: warmup_loss, warmup_acc = self.local_train(warmup_epochs) - print("Client {} warmup loss {} acc {}".format(self.node_id, warmup_loss, warmup_acc)) - + print( + "Client {} warmup loss {} acc {}".format( + self.node_id, warmup_loss, warmup_acc + ) + ) + repr = self.get_representation() - self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REP1_ADVERT) - - # Collect the representations from all other nodes from the server - reprs = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPS1_SHARE) - reprs_dict = {k:v for k,v in enumerate(reprs, 1)} - + self.comm_utils.send( + dest=self.server_node, data=repr, tag=self.tag.REP1_ADVERT + ) + + # Collect the representations from all other nodes from the server + reprs = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.REPS1_SHARE + ) + reprs_dict = {k: v for k, v in enumerate(reprs, 1)} + second_step_reprs_dict = None if self.require_second_step: reprs2 = self.get_second_step_representation(reprs_dict) - self.comm_utils.send(dest=self.server_node, data=reprs2, tag=self.tag.REP2_ADVERT) - second_step_reprs_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPS2_SHARE) - - similarity_dict = self.get_collaborators_similarity(round, reprs_dict, second_step_reprs_dict) - + self.comm_utils.send( + dest=self.server_node, data=reprs2, tag=self.tag.REP2_ADVERT + ) + second_step_reprs_dict = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.REPS2_SHARE + ) + + similarity_dict = self.get_collaborators_similarity( + round, reprs_dict, second_step_reprs_dict + ) + if self.with_sim_sharing: - self.comm_utils.send(dest=self.server_node, data=similarity_dict, tag=self.tag.SIM_ADVERT) - similarity_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.SIM_SHARE) - - is_num_collab_changed = round == self.config['T_0'] and self.config["target_users_after_T_0"] != self.config["target_users_before_T_0"] - is_selection_round = round == start_round or (round % self.config["rounds_per_selection"] == 0) or is_num_collab_changed - + self.comm_utils.send( + dest=self.server_node, data=similarity_dict, tag=self.tag.SIM_ADVERT + ) + similarity_dict = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.SIM_SHARE + ) + + is_num_collab_changed = ( + round == self.config["T_0"] + and self.config["target_users_after_T_0"] + != self.config["target_users_before_T_0"] + ) + is_selection_round = ( + round == start_round + or (round % self.config["rounds_per_selection"] == 0) + or is_num_collab_changed + ) + if is_selection_round: num_collaborator_before = self.config["target_users_before_T_0"] num_collaborator_after = self.config["target_users_after_T_0"] - # num_collaborator = self.config[f"target_clients_{'before' if round < self.config['T_0'] else 'after'}_T_0"] - is_before_T0 = round < self.config['T_0'] + # num_collaborator = self.config[f"target_clients_{'before' if round < self.config['T_0'] else 'after'}_T_0"] + is_before_T0 = round < self.config["T_0"] if num_collaborator_before == 0 and is_before_T0: - + similarity_dict = self.get_consensus_similarity(similarity_dict, 0) # Run selection to log mock selection etc (if not must handle change in server communication) - #self.select_collaborator(similarity_dict, num_collaborator_after, round, total_rounds) - collab_weights_dict, proba_dist = {self.node_id: 1}, {self.node_id: 1} + # self.select_collaborator(similarity_dict, num_collaborator_after, round, total_rounds) + collab_weights_dict, proba_dist = {self.node_id: 1}, { + self.node_id: 1 + } else: - num_collab = num_collaborator_before if is_before_T0 else num_collaborator_after - similarity_dict = self.get_consensus_similarity(similarity_dict, num_collab) - collab_weights_dict, proba_dist = self.select_collaborator(similarity_dict, num_collab, round, total_rounds) - + num_collab = ( + num_collaborator_before + if is_before_T0 + else num_collaborator_after + ) + similarity_dict = self.get_consensus_similarity( + similarity_dict, num_collab + ) + collab_weights_dict, proba_dist = self.select_collaborator( + similarity_dict, num_collab, round, total_rounds + ) + self.log_clients_stats(proba_dist, "Selection probability") - - selected_collaborators = [key for key, value in collab_weights_dict.items() if value > 0] - self.comm_utils.send(dest=self.server_node, data=(selected_collaborators, self.get_model_weights()), tag=self.tag.C_SELECTION) - models_wts = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.KNLDG_SHARE) - - avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, self.model_keys_to_ignore) + + selected_collaborators = [ + key for key, value in collab_weights_dict.items() if value > 0 + ] + self.comm_utils.send( + dest=self.server_node, + data=(selected_collaborators, self.get_model_weights()), + tag=self.tag.C_SELECTION, + ) + models_wts = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.KNLDG_SHARE + ) + + avg_wts = self.weighted_aggregate( + models_wts, collab_weights_dict, self.model_keys_to_ignore + ) # Average whole model by default self.set_model_weights(avg_wts, self.model_keys_to_ignore) - + self.round_stats["test_acc_before_training"] = self.local_test() - + prev_wts = self.get_model_weights() - self.round_stats["train_loss"], self.round_stats["train_acc"] = self.local_train(epochs_per_round) + self.round_stats["train_loss"], self.round_stats["train_acc"] = ( + self.local_train(epochs_per_round) + ) new_wts = self.get_model_weights() - self.round_stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts) + self.round_stats["pseudo grad norm"] = self.compute_pseudo_grad_norm( + prev_wts, new_wts + ) - # Test updated model + # Test updated model self.round_stats["test_acc_after_training"] = self.local_test() - + self.log_clients_stats(collab_weights_dict, "Collaborator weights") - - self.comm_utils.send(dest=self.server_node, data=self.round_stats, tag=self.tag.ROUND_STATS) + + self.comm_utils.send( + dest=self.server_node, data=self.round_stats, tag=self.tag.ROUND_STATS + ) + class FedDataRepServer(BaseFedAvgServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils=comm_utils, comm_protocol=CommProtocol) # self.set_parameters() self.tag = CommProtocol self.config = config self.require_second_step = config["similarity_metric"] in TWO_STEP_STRAT - self.with_sim_sharing = self.config.get("consensus","") in SIM_SHARING + self.with_sim_sharing = self.config.get("consensus", "") in SIM_SHARING self.with_cons_step = self.config.get("consensus", "") in CONS_STEP @@ -1182,9 +1419,7 @@ def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> def send_models_selected(self, collaborator_selection, models_wts): for client_id, selected_clients in collaborator_selection.items(): wts = {key: models_wts[key] for key in selected_clients} - self.comm_utils.send( - dest=client_id, data=wts, tag=self.tag.KNLDG_SHARE - ) + self.comm_utils.send(dest=client_id, data=wts, tag=self.tag.KNLDG_SHARE) def single_round(self): """ @@ -1192,9 +1427,7 @@ def single_round(self): """ # Start local training - self.comm_utils.broadcast( - data=None, tag=self.tag.ROUND_START - ) + self.comm_utils.broadcast(data=None, tag=self.tag.ROUND_START) self.log_utils.log_console( "Server waiting for all clients to finish local training" ) @@ -1209,31 +1442,23 @@ def single_round(self): imgs = rep[0][:16, :3] self.log_utils.log_image(imgs, f"client{client+1}", self.round) - self.comm_utils.broadcast( - data=reprs, tag=self.tag.REPS1_SHARE - ) + self.comm_utils.broadcast(data=reprs, tag=self.tag.REPS1_SHARE) if self.require_second_step: # Collect the representations from all other nodes from the server reprs2 = self.comm_utils.all_gather(self.tag.REP2_ADVERT) reprs2 = {idx: reprs for idx, reprs in enumerate(reprs2, 1)} - self.comm_utils.broadcast( - data=reprs2, tag=self.tag.REPS2_SHARE - ) + self.comm_utils.broadcast(data=reprs2, tag=self.tag.REPS2_SHARE) if self.with_sim_sharing: sim_dicts = self.comm_utils.all_gather(self.tag.SIM_ADVERT) sim_dicts = {k: v for k, v in enumerate(sim_dicts, 1)} - self.comm_utils.broadcast( - data=sim_dicts, tag=self.tag.SIM_SHARE - ) + self.comm_utils.broadcast(data=sim_dicts, tag=self.tag.SIM_SHARE) if self.with_cons_step: consensus = self.comm_utils.all_gather(self.tag.CONS_ADVERT) consensus_dict = {idx: cons for idx, cons in enumerate(consensus, 1)} - self.comm_utils.broadcast( - data=consensus_dict, tag=self.tag.CONS_SHARE - ) + self.comm_utils.broadcast(data=consensus_dict, tag=self.tag.CONS_SHARE) data = self.comm_utils.all_gather(self.tag.C_SELECTION) collaborator_selection = { diff --git a/src/algos/fl_grid.py b/src/algos/fl_grid.py index 643fe6c..ea97659 100644 --- a/src/algos/fl_grid.py +++ b/src/algos/fl_grid.py @@ -1,5 +1,4 @@ import numpy as np -import math class GridTopology: @@ -26,17 +25,17 @@ def get_selected_ids(self, node_id, config): if node_id <= num_users - grid_size: selected_ids.append(node_id + grid_size) - if(num_users == 1): + if num_users == 1: selected_ids = [1] - elif(num_users == 2): - if(node_id == 1): + elif num_users == 2: + if node_id == 1: selected_ids = [2] else: selected_ids = [1] - elif(num_users == 3): - if(node_id == 1): + elif num_users == 3: + if node_id == 1: selected_ids = [2, 3] - elif(node_id == 2): + elif node_id == 2: selected_ids = [1] else: selected_ids = [1] diff --git a/src/algos/fl_isolated.py b/src/algos/fl_isolated.py index 3816621..a25d65c 100644 --- a/src/algos/fl_isolated.py +++ b/src/algos/fl_isolated.py @@ -1,6 +1,6 @@ from algos.base_class import BaseClient, BaseServer from utils.stats_utils import from_rounds_stats_per_client_per_round_to_dict_arrays -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager @@ -14,7 +14,9 @@ class CommProtocol(object): class FedIsoClient(BaseClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol @@ -83,13 +85,13 @@ def run_protocol(self): round_stats["test_acc"], ) ) - self.comm_utils.send( - dest=self.server_node, data=stats, tag=self.tag.DONE - ) + self.comm_utils.send(dest=self.server_node, data=stats, tag=self.tag.DONE) class FedIsoServer(BaseServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config diff --git a/src/algos/fl_random.py b/src/algos/fl_random.py index 5a67e2e..ab2a1d2 100644 --- a/src/algos/fl_random.py +++ b/src/algos/fl_random.py @@ -4,6 +4,7 @@ class RandomTopology: """Method docstring: Returns selected IDs based on some criteria.""" + def get_selected_ids(self, node_id, config, reprs_dict, communities): within_community_sampling = config.get("within_community_sampling", 1) diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 5c2b6e3..a90b84f 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -21,10 +21,15 @@ class FedStaticClient(BaseFedAvgClient): """ Federated Static Client Class. """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) - def get_collaborator_weights(self, reprs_dict: Dict[int, Any], rnd: int) -> Dict[int, float]: + def get_collaborator_weights( + self, reprs_dict: Dict[int, Any], rnd: int + ) -> Dict[int, float]: """ Returns the weights of the collaborators for the current round. """ @@ -47,18 +52,22 @@ def get_collaborator_weights(self, reprs_dict: Dict[int, Any], rnd: int) -> Dict own_aggr_weight, rnd, total_rounds ) - collab_weights[idx] = self._calculate_collab_weight(idx, own_aggr_weight, selected_ids) + collab_weights[idx] = self._calculate_collab_weight( + idx, own_aggr_weight, selected_ids + ) return collab_weights - def _decay_within_sampling(self, strategy: str, p: float, rnd: int, total_rounds: int) -> float: + def _decay_within_sampling( + self, strategy: str, p: float, rnd: int, total_rounds: int + ) -> float: """ Applies the within-community sampling decay strategy. """ if strategy == "linear_inc": - p *= (rnd / total_rounds) + p *= rnd / total_rounds elif strategy == "linear_dec": - p *= (1 - rnd / total_rounds) + p *= 1 - rnd / total_rounds elif strategy == "exp_inc": alpha = np.log((1 - p) / p) p *= np.exp(alpha * rnd / total_rounds) @@ -76,7 +85,9 @@ def _select_ids_based_on_algo(self, algo: str) -> List[int]: """ if algo == "random": topology = RandomTopology() - return topology.get_selected_ids(self.node_id, self.config, self.reprs_dict, self.communities) + return topology.get_selected_ids( + self.node_id, self.config, self.reprs_dict, self.communities + ) if algo == "ring": topology = RingTopology() return topology.get_selected_ids(self.node_id, self.config) @@ -88,7 +99,9 @@ def _select_ids_based_on_algo(self, algo: str) -> List[int]: return topology.get_selected_ids(self.node_id, self.config) return [] - def _apply_aggr_weight_strategy(self, weight: float, rnd: int, total_rounds: int) -> float: + def _apply_aggr_weight_strategy( + self, weight: float, rnd: int, total_rounds: int + ) -> float: """ Applies the aggregation weight strategy. """ @@ -98,15 +111,26 @@ def _apply_aggr_weight_strategy(self, weight: float, rnd: int, total_rounds: int target_weight = 0.5 if strategy == "linear": target_round = total_rounds // 2 - weight = 1 - (init_weight + (target_weight - init_weight) * (min(1, rnd / target_round))) + weight = 1 - ( + init_weight + + (target_weight - init_weight) * (min(1, rnd / target_round)) + ) elif strategy == "log": alpha = 0.05 - weight = 1 - (init_weight + (target_weight - init_weight) * (np.log(alpha * (rnd / total_rounds) + 1) / np.log(alpha + 1))) + weight = 1 - ( + init_weight + + (target_weight - init_weight) + * (np.log(alpha * (rnd / total_rounds) + 1) / np.log(alpha + 1)) + ) else: - raise ValueError(f"Aggregation weight strategy {strategy} not implemented") + raise ValueError( + f"Aggregation weight strategy {strategy} not implemented" + ) return weight - def _calculate_collab_weight(self, idx: int, own_aggr_weight: float, selected_ids: List[int]) -> float: + def _calculate_collab_weight( + self, idx: int, own_aggr_weight: float, selected_ids: List[int] + ) -> float: """ Calculates the collaborator weight. """ @@ -156,7 +180,9 @@ def flatten_repr(self, repr_dict: Dict[str, torch.Tensor]) -> torch.Tensor: params = [repr_dict[key].view(-1) for key in repr_dict.keys()] return torch.cat(params) - def compute_pseudo_grad_norm(self, prev_wts: Dict[str, torch.Tensor], new_wts: Dict[str, torch.Tensor]) -> float: + def compute_pseudo_grad_norm( + self, prev_wts: Dict[str, torch.Tensor], new_wts: Dict[str, torch.Tensor] + ) -> float: """ Computes the pseudo gradient norm. """ @@ -169,7 +195,9 @@ def run_protocol(self) -> None: print(f"Client {self.node_id} ready to start training") start_round = self.config.get("start_round", 0) if start_round != 0: - raise NotImplementedError("Start round different from 0 not implemented yet") + raise NotImplementedError( + "Start round different from 0 not implemented yet" + ) total_rounds = self.config["rounds"] epochs_per_round = self.config["epochs_per_round"] for rnd in range(start_round, total_rounds): @@ -183,13 +211,19 @@ def run_protocol(self) -> None: # Train locally and send the representation to the server if not self.config.get("local_train_after_aggr", False): - stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) + stats["train_loss"], stats["train_acc"] = self.local_train( + epochs_per_round + ) repr_dict = self.get_representation() - self.comm_utils.send(dest=self.server_node, data=repr_dict, tag=self.tag.REPR_ADVERT) + self.comm_utils.send( + dest=self.server_node, data=repr_dict, tag=self.tag.REPR_ADVERT + ) # Collect the representations from all other nodes from the server - reprs = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPRS_SHARE) + reprs = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.REPRS_SHARE + ) reprs_dict = {k: v for k, v in enumerate(reprs, 1)} # Aggregate the representations based on the collaborator weights @@ -202,10 +236,17 @@ def run_protocol(self) -> None: if inter_commu_last_layer_to_aggr is not None and len( set(self.communities[self.node_id]).intersection(active_collab) ) != len(active_collab): - layer_idx = self.model_utils.models_layers_idx[self.config["model"]][inter_commu_last_layer_to_aggr] - layers_to_ignore = self.model_keys_to_ignore + list(list(models_wts.values())[0].keys())[layer_idx + 1:] - - avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore) + layer_idx = self.model_utils.models_layers_idx[self.config["model"]][ + inter_commu_last_layer_to_aggr + ] + layers_to_ignore = ( + self.model_keys_to_ignore + + list(list(models_wts.values())[0].keys())[layer_idx + 1 :] + ) + + avg_wts = self.weighted_aggregate( + models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore + ) self.set_model_weights(avg_wts, layers_to_ignore) if self.config.get("train_only_fc", False): @@ -219,9 +260,13 @@ def run_protocol(self) -> None: # Train locally and send the representation to the server if self.config.get("local_train_after_aggr", False): prev_wts = self.get_model_weights() - stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) + stats["train_loss"], stats["train_acc"] = self.local_train( + epochs_per_round + ) new_wts = self.get_model_weights() - stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts) + stats["pseudo grad norm"] = self.compute_pseudo_grad_norm( + prev_wts, new_wts + ) stats["test_acc_after_training"] = self.local_test() collab_weight = np.zeros(self.config["num_users"]) @@ -229,24 +274,33 @@ def run_protocol(self) -> None: collab_weight[k - 1] = v stats["Collaborator weights"] = collab_weight - self.comm_utils.send(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS) + self.comm_utils.send( + dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS + ) class FedStaticServer(BaseFedAvgServer): """ Federated Static Server Class. """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.set_model_parameters(config) - self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + self.model_save_path = ( + f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" + ) def test(self) -> float: """ Test the model on the server. """ - _, acc = self.model_utils.test(self.model, self._test_loader, self.loss_fn, self.device) + _, acc = self.model_utils.test( + self.model, self._test_loader, self.loss_fn, self.device + ) if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) @@ -258,7 +312,9 @@ def single_round(self) -> List[Dict[str, Any]]: """ for client_node in self.users: self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START) - self.log_utils.log_console("Server waiting for all clients to finish local training") + self.log_utils.log_console( + "Server waiting for all clients to finish local training" + ) models = self.comm_utils.all_gather(self.tag.REPR_ADVERT) self.log_utils.log_console("Server received all clients models") @@ -267,7 +323,9 @@ def single_round(self) -> List[Dict[str, Any]]: clients_round_stats = self.comm_utils.all_gather(self.tag.ROUND_STATS) self.log_utils.log_console("Server received all clients stats") - self.log_utils.log_tb_round_stats(clients_round_stats, ["Collaborator weights"], self.round) + self.log_utils.log_tb_round_stats( + clients_round_stats, ["Collaborator weights"], self.round + ) self.log_utils.log_console( f"Round test acc before local training {[stats['test_acc_before_training'] for stats in clients_round_stats]}" diff --git a/src/algos/fl_torus.py b/src/algos/fl_torus.py index eb8592a..bdfe59c 100644 --- a/src/algos/fl_torus.py +++ b/src/algos/fl_torus.py @@ -12,7 +12,7 @@ def get_selected_ids(self, node_id, config): print(grid_size) selected_ids = [] - + num_rows = math.ceil(num_users / grid_size) # Left @@ -51,17 +51,17 @@ def get_selected_ids(self, node_id, config): # keep sampling identical across nodes (if same seed) selected_ids = list(set(selected_ids)) - if(num_users == 1): + if num_users == 1: selected_ids = [1] - elif(num_users == 2): - if(node_id == 1): + elif num_users == 2: + if node_id == 1: selected_ids = [2] else: selected_ids = [1] - elif(num_users == 3): - if(node_id == 1): + elif num_users == 3: + if node_id == 1: selected_ids = [2, 3] - elif(node_id == 2): + elif node_id == 2: selected_ids = [1] else: selected_ids = [1] diff --git a/src/algos/fl_val.py b/src/algos/fl_val.py index 810c8a3..c74a60d 100644 --- a/src/algos/fl_val.py +++ b/src/algos/fl_val.py @@ -190,9 +190,7 @@ def single_round(self): ) # Collect models from all clients - models = self.comm_utils.wait_for_all_clients( - self.users, self.tag.REPR_ADVERT - ) + models = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT) self.log_utils.log_console("Server received all clients models") # Broadcast the models to all clients diff --git a/src/algos/fl_weight.py b/src/algos/fl_weight.py index 4a3d8fe..1976512 100644 --- a/src/algos/fl_weight.py +++ b/src/algos/fl_weight.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Any, Dict, List +from typing import Any, Dict from utils.communication.comm_utils import CommunicationManager from torch import Tensor, cat import torch.nn as nn @@ -9,7 +9,9 @@ class FedWeightClient(BaseFedAvgClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config @@ -94,9 +96,7 @@ def get_k_higest_sim(self, sim_dict, k): selected_users.append(self.node_id) collaborator_dict = { - client_id: ( - 1 / len(selected_users) if client_id in selected_users else 0 - ) + client_id: (1 / len(selected_users) if client_id in selected_users else 0) for client_id in sim_dict.keys() } @@ -117,9 +117,7 @@ def run_protocol(self): self.round_stats = {} - self.comm_utils.receive( - node_ids=self.server_node, tag=self.tag.ROUND_START - ) + self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) repr = self.get_representation() warmup = self.config["warmup_epochs"] @@ -177,7 +175,9 @@ def run_protocol(self): class FedWeightServer(BaseFedAvgServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config @@ -191,9 +191,7 @@ def single_round(self): # Send signal to all users to start local training for client_node in self.users: - self.comm_utils.send( - dest=client_node, data=None, tag=self.tag.ROUND_START - ) + self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START) self.log_utils.log_console( "Server waiting for all users to finish local training" ) diff --git a/src/algos/generator.py b/src/algos/generator.py index 824f378..33f9427 100644 --- a/src/algos/generator.py +++ b/src/algos/generator.py @@ -115,7 +115,14 @@ def clone(self, copy_params: bool = True) -> "DeepGenerator": class DCGAN_Generator(nn.Module): """Generator from DCGAN: https://arxiv.org/abs/1511.06434""" - def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 3, img_size: Union[int, List[int], Tuple[int, int]] = 64, slope: float = 0.2): + def __init__( + self, + nz: int = 100, + ngf: int = 64, + nc: int = 3, + img_size: Union[int, List[int], Tuple[int, int]] = 64, + slope: float = 0.2, + ): super(DCGAN_Generator, self).__init__() self.nz = nz if isinstance(img_size, (list, tuple)): @@ -220,7 +227,9 @@ class Discriminator(nn.Module): def __init__(self, nc: int = 3, img_size: int = 32): super().__init__() - def discriminator_block(in_filters: int, out_filters: int, bn: bool = True) -> nn.ModuleList: + def discriminator_block( + in_filters: int, out_filters: int, bn: bool = True + ) -> nn.ModuleList: block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), @@ -239,10 +248,7 @@ def discriminator_block(in_filters: int, out_filters: int, bn: bool = True) -> n # The height and width of downsampled image ds_size = img_size // 2**4 - self.adv_layer = nn.Sequential( - nn.Linear(128 * ds_size**2, 1), - nn.Sigmoid() - ) + self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size**2, 1), nn.Sigmoid()) def forward(self, img: torch.Tensor) -> torch.Tensor: out = self.model(img) diff --git a/src/algos/isolated.py b/src/algos/isolated.py index 267781d..687a630 100644 --- a/src/algos/isolated.py +++ b/src/algos/isolated.py @@ -18,7 +18,7 @@ def __init__(self, config) -> None: def set_training_data(self, config): train_dset = self.dset_obj.train_dset - test_dset = self.dset_obj.test_dset + self.dset_obj.test_dset samples_per_user = config["samples_per_user"] batch_size = config["batch_size"] client_idx = self.node_id diff --git a/src/algos/swarm.py b/src/algos/swarm.py index 5349e3f..abee7e4 100644 --- a/src/algos/swarm.py +++ b/src/algos/swarm.py @@ -3,8 +3,6 @@ from utils.communication.comm_utils import CommunicationManager from torch import Tensor, cat import torch.nn as nn -import random -import os from algos.base_class import BaseClient, BaseServer import numpy as np @@ -21,7 +19,9 @@ class CommProtocol(object): class SWARMClient(BaseClient): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol @@ -88,7 +88,9 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]) -> OrderedDict: avgd_wts[key] += coeff * local_wts[key].to(self.device) return avgd_wts - def aggregate(self, representation_list: List[OrderedDict[str, Tensor]]) -> OrderedDict: + def aggregate( + self, representation_list: List[OrderedDict[str, Tensor]] + ) -> OrderedDict: """ Aggregate the model weights """ @@ -134,7 +136,8 @@ def run_protocol(self) -> None: dest=self.server_node, data=self_repr, tag=self.tag.DONE ) print("Node {} waiting signal from node 1".format(self.node_id)) - repr = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.UPDATES + repr = self.comm_utils.receive( + node_ids=self.server_node, tag=self.tag.UPDATES ) self.set_representation(repr) @@ -144,7 +147,9 @@ def run_protocol(self) -> None: class SWARMServer(BaseServer): - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + def __init__( + self, config: Dict[str, Any], comm_utils: CommunicationManager + ) -> None: super().__init__(config, comm_utils) # self.set_parameters() self.config = config