diff --git a/src/algos/MetaL2C.py b/src/algos/MetaL2C.py index 8328fbc..a36c525 100644 --- a/src/algos/MetaL2C.py +++ b/src/algos/MetaL2C.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List +from utils.communication.comm_utils import CommunicationManager import math import torch import numpy as np @@ -76,8 +77,8 @@ def forward(self, model_dict): class MetaL2CClient(BaseFedAvgClient): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) self.encoder = ModelEncoder(self.get_model_weights()) self.encoder_optim = optim.SGD( @@ -219,8 +220,8 @@ def run_protocol(self): round_stats = {} # Wait on server to start the round - avg_alpha = self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.ROUND_START + avg_alpha = self.comm_utils.receive( + self.server_node, tag=self.tag.ROUND_START ) # Load result of AllReduce of previous round if avg_alpha is not None: @@ -234,15 +235,15 @@ def run_protocol(self): ) repr = self.get_representation() ks_artifact = self.get_knowledge_sharing_artifact() - self.comm_utils.send_signal( + self.comm_utils.send( dest=self.server_node, data=(repr, ks_artifact), tag=self.tag.REPR_ADVERT, ) # Collect the representations from all other nodes from the server - reprs = self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.REPRS_SHARE + reprs = self.comm_utils.receive( + self.server_node, tag=self.tag.REPRS_SHARE ) reprs_dict = {k: rep for k, (rep, _) in enumerate(reprs, 1)} @@ -286,7 +287,7 @@ def run_protocol(self): # AllReduce is computed by the server alpha = self.encoder.state_dict() - self.comm_utils.send_signal( + self.comm_utils.send( dest=self.server_node, data=(round_stats, alpha), tag=self.tag.ROUND_STATS, @@ -295,8 +296,8 @@ def run_protocol(self): class MetaL2CServer(BaseFedAvgServer): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) # self.set_parameters() self.config = config self.set_model_parameters(config) @@ -321,7 +322,7 @@ def single_round(self, avg_alpha): # Send signal to all clients to start local training for client_node in self.users: - self.comm_utils.send_signal( + self.comm_utils.send( dest=client_node, data=avg_alpha, tag=self.tag.ROUND_START ) self.log_utils.log_console( @@ -329,16 +330,14 @@ def single_round(self, avg_alpha): ) # Collect representations (from all clients - reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.REPR_ADVERT) + reprs = self.comm_utils.all_gather(self.tag.REPR_ADVERT) self.log_utils.log_console("Server received all clients models") # Broadcast the representations to all clients self.send_representations(reprs) # Collect round stats from all clients - round_stats_and_alphas = self.comm_utils.wait_for_all_clients( - self.users, self.tag.ROUND_STATS - ) + round_stats_and_alphas = self.comm_utils.all_gather(self.tag.ROUND_STATS) alphas = [alpha for _, alpha in round_stats_and_alphas] round_stats = [stats for stats, _ in round_stats_and_alphas] diff --git a/src/algos/base_class.py b/src/algos/base_class.py index a57288c..4df105b 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -485,7 +485,7 @@ class BaseFedAvgClient(BaseClient): """ Abstract class for FedAvg based algorithms """ - def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol]) -> None: + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager, comm_protocol: type[CommProtocol] = CommProtocol) -> None: super().__init__(config, comm_utils) self.config = config self.model_save_path = "{}/saved_models/node_{}.pt".format( diff --git a/src/algos/fl_assigned.py b/src/algos/fl_assigned.py index db87990..d0afa3a 100644 --- a/src/algos/fl_assigned.py +++ b/src/algos/fl_assigned.py @@ -1,38 +1,40 @@ import numpy as np import math - +from typing import Any, Dict, List +from utils.communication.comm_utils import CommunicationManager from algos.base_class import BaseFedAvgClient, BaseFedAvgServer from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays class FedAssClient(BaseFedAvgClient): - def __init__(self, config) -> None: - super().__init__(config) - - def get_collaborator_weights(self, round): + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils=comm_utils) + + def get_collaborator_weights(self, num_collaborator, round): """ Returns the weights of the collaborators for the current round - """ - if self.config["strategy"] == "fixed": - collab_weights = { - id: 1 for id in self.config["assigned_collaborators"][self.node_id] - } - elif self.config["strategy"] == "direct_expo": - power = round % math.floor(math.log2(self.config["num_users"] - 1)) - steps = math.pow(2, power) - collab_id = int(((self.node_id + steps) % self.config["num_users"]) + 1) + """ + if self.config["strategy"] =="fixed": + collab_weights = {id: 1 for id in self.config["assigned_collaborators"][self.node_id]} + elif self.config["strategy"] =="direct_expo": + power = round % math.floor(math.log2(self.config["num_users"]-1)) + steps = math.pow(2,power) + collab_id = int(((self.node_id + steps) % self.config["num_users"]) + 1) collab_weights = {self.node_id: 1, collab_id: 1} + elif self.config["strategy"] =="random_among_assigned": + collab_weights = {k: 1 for k in np.random.choice(list(self.config["assigned_collaborators"][self.node_id]), size=num_collaborator, replace=False)} + collab_weights[self.node_id] = 1 else: raise ValueError("Strategy not implemented") - + total = sum(collab_weights.values()) collab_weights = {id: w / total for id, w in collab_weights.items()} return collab_weights - + def get_representation(self): return self.get_model_weights() - + def run_protocol(self): print(f"Client {self.node_id} ready to start training") start_round = self.config.get("start_round", 0) @@ -40,84 +42,70 @@ def run_protocol(self): epochs_per_round = self.config["epochs_per_round"] for round in range(start_round, total_rounds): stats = {} - + # Wait on server to start the round - self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.ROUND_START - ) - + self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) + repr = self.get_representation() - self.comm_utils.send_signal( - dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT - ) + self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT) - # Collect the representations from all other nodes from the server - reprs = self.comm_utils.wait_for_signal( - src=self.server_node, tag=self.tag.REPRS_SHARE - ) - - # In the future this dict might be generated by the server to send - # only requested models - reprs_dict = {k: v for k, v in enumerate(reprs, 1)} + # Collect the representations from all other nodes from the server + reprs = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPRS_SHARE) + + # In the future this dict might be generated by the server to send only requested models + reprs_dict = {k:v for k,v in enumerate(reprs, 1)} + + num_collaborator = self.config[f"target_users_{'before' if round < self.config['T_0'] else 'after'}_T_0"] # Aggregate the representations based on the collab weights - collab_weights_dict = self.get_collaborator_weights(round) - + collab_weights_dict = self.get_collaborator_weights(num_collaborator, round) + collaborators = [k for k, w in collab_weights_dict.items() if w > 0] - # If there are no collaborators, then the client does not update - # its model - if not (len(collaborators) == 1 and collaborators[0] == self.node_id): + # If there are no collaborators, then the client does not update its model + if not (len(collaborators)==1 and collaborators[0]==self.node_id): # Since clients representations are also used to transmit knowledge - # There is no need to fetch the server for the selected - # clients' knowledge + # There is no need to fetch the server for the selected clients' knowledge models_wts = reprs_dict - - avg_wts = self.weighted_aggregate( - models_wts, - collab_weights_dict, - keys_to_ignore=self.model_keys_to_ignore, - ) - + + avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=self.model_keys_to_ignore) + # Average whole model by default self.set_model_weights(avg_wts, self.model_keys_to_ignore) - + stats["test_acc_before_training"] = self.local_test() + stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round) - # Test updated model + + # Test updated model stats["test_acc_after_training"] = self.local_test() # Include collab weights in the stats collab_weight = np.zeros(self.config["num_users"]) - for k, v in collab_weights_dict.items(): - collab_weight[k - 1] = v + for k,v in collab_weights_dict.items(): + collab_weight[k-1] = v stats["collab_weights"] = collab_weight - - self.comm_utils.send_signal( - dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS - ) - + self.comm_utils.send(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS) class FedAssServer(BaseFedAvgServer): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils=comm_utils) # self.set_parameters() self.config = config self.set_model_parameters(config) - self.model_save_path = "{}/saved_models/node_{}.pt".format( - self.config["results_path"], self.node_id - ) - + self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"], + self.node_id) + def test(self) -> float: """ Test the model on the server """ - test_loss, acc = self.model_utils.test( - self.model, self._test_loader, self.loss_fn, self.device - ) - # TODO save the model if the accuracy is better than the best accuracy - # so far + test_loss, acc = self.model_utils.test(self.model, + self._test_loader, + self.loss_fn, + self.device) + # TODO save the model if the accuracy is better than the best accuracy so far if acc > self.best_acc: self.best_acc = acc self.model_utils.save_model(self.model, self.model_save_path) @@ -127,40 +115,28 @@ def single_round(self): """ Runs the whole training procedure """ - + # Send signal to all clients to start local training for client_node in self.users: - self.comm_utils.send_signal( - dest=client_node, data=None, tag=self.tag.ROUND_START - ) - self.log_utils.log_console( - "Server waiting for all clients to finish local training" - ) - + self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START) + self.log_utils.log_console("Server waiting for all clients to finish local training") + # Collect models from all clients - models = self.comm_utils.wait_for_all_clients( - self.users, self.tag.REPR_ADVERT - ) - self.log_utils.log_console("Server received all clients models") - + models = self.comm_utils.all_gather(self.tag.REPR_ADVERT) + self.log_utils.log_console("Server received all clients models") + # Broadcast the models to all clients self.send_representations(models) - + # Collect round stats from all clients - round_stats = self.comm_utils.wait_for_all_clients( - self.users, self.tag.ROUND_STATS - ) - self.log_utils.log_console("Server received all clients stats") + round_stats = self.comm_utils.all_gather(self.tag.ROUND_STATS) + self.log_utils.log_console("Server received all clients stats") # Log the round stats on tensorboard except the collab weights self.log_utils.log_tb_round_stats(round_stats, ["collab_weights"], self.round) - self.log_utils.log_console( - f"Round acc TALT {[stats['test_acc_after_training'] for stats in round_stats]}" - ) - self.log_utils.log_console( - f"Round acc TBLT {[stats['test_acc_before_training'] for stats in round_stats]}" - ) + self.log_utils.log_console(f"Round acc TALT {[stats['test_acc_after_training'] for stats in round_stats]}") + self.log_utils.log_console(f"Round acc TBLT {[stats['test_acc_before_training'] for stats in round_stats]}") return round_stats @@ -179,6 +155,6 @@ def run_protocol(self): stats.append(round_stats) stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats) - stats_dict["round_step"] = 1 + stats_dict["round_step"] = 1 self.log_utils.log_experiments_stats(stats_dict) - self.plot_utils.plot_experiments_stats(stats_dict) + self.plot_utils.plot_experiments_stats(stats_dict) \ No newline at end of file diff --git a/src/algos/fl_data_repr.py b/src/algos/fl_data_repr.py index 9c38643..1bdc773 100644 --- a/src/algos/fl_data_repr.py +++ b/src/algos/fl_data_repr.py @@ -3,6 +3,9 @@ import torch.nn as nn import torch.nn.functional as F +from typing import Any, Dict, List +from utils.communication.comm_utils import CommunicationManager + from algos.base_class import BaseFedAvgClient, BaseFedAvgServer from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays from torch.utils.data import DataLoader, Dataset @@ -42,8 +45,8 @@ class CommProtocol(object): class FedDataRepClient(BaseFedAvgClient): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils=comm_utils, comm_protocol=CommProtocol) self.tag = CommProtocol self.require_second_step = config["similarity_metric"] in TWO_STEP_STRAT @@ -666,7 +669,7 @@ def select_top_k(self, collab_similarity, k, round, total_rounds): voters = sorted_collab[:num_voter] # + 1 to avoid receiving recommendation only about myself - self.comm_utils.send_signal( + self.comm_utils.send( dest=self.server_node, data=sorted_collab[: num_vote_per_voter + 1], tag=self.tag.CONS_ADVERT, @@ -769,7 +772,7 @@ def select_collaborator(self, collab_similarity, k, round, total_rounds): selected_collab = list(np.random.choice([key for key, _ in items], k, p=[value for _, value in items], replace=False)) elif strategy.endswith("top_x"): - top_x = self.config.get("num_clients_top_x", 0) + top_x = self.config.get("num_users_top_x", 0) if strategy == "growing_schedulded_top_x": top_x = 1 + int(top_x * round/total_rounds) @@ -794,7 +797,7 @@ def select_collaborator(self, collab_similarity, k, round, total_rounds): elif strategy == "xth": sorted_collab = [id for id, _ in sorted(collab_similarity.items(), key=lambda item: item[1], reverse=True)] - xth = self.config.get("num_clients_top_x", 0) + xth = self.config.get("num_users_top_x", 0) if xth >= len(sorted_collab): raise ValueError("xth {} must be smaller than the number of clients {}".format(xth, len(sorted_collab))) @@ -877,8 +880,8 @@ def vote_consensus(self, collab_similarity): voters = sorted_collab[:num_voter] # + 1 to avoid receiving recommendation only about myself - self.comm_utils.send_signal(dest=self.server_node, data=sorted_collab[:num_vote_per_voter+1], tag=self.tag.CONS_ADVERT) - vote_dict = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.CONS_SHARE) + self.comm_utils.send(dest=self.server_node, data=sorted_collab[:num_vote_per_voter+1], tag=self.tag.CONS_ADVERT) + vote_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.CONS_SHARE) candidate_list = [] for voter_id, votes in vote_dict.items(): @@ -1091,7 +1094,7 @@ def run_protocol(self): self.round_stats = {} # Wait on server to start the round - self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.ROUND_START) + self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START) if round == start_round: warmup_epochs = self.config.get("warmup_epochs", epochs_per_round) @@ -1100,23 +1103,23 @@ def run_protocol(self): print("Client {} warmup loss {} acc {}".format(self.node_id, warmup_loss, warmup_acc)) repr = self.get_representation() - self.comm_utils.send_signal(dest=self.server_node, data=repr, tag=self.tag.REP1_ADVERT) + self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REP1_ADVERT) # Collect the representations from all other nodes from the server - reprs = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPS1_SHARE) + reprs = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPS1_SHARE) reprs_dict = {k:v for k,v in enumerate(reprs, 1)} second_step_reprs_dict = None if self.require_second_step: reprs2 = self.get_second_step_representation(reprs_dict) - self.comm_utils.send_signal(dest=self.server_node, data=reprs2, tag=self.tag.REP2_ADVERT) - second_step_reprs_dict = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPS2_SHARE) + self.comm_utils.send(dest=self.server_node, data=reprs2, tag=self.tag.REP2_ADVERT) + second_step_reprs_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPS2_SHARE) similarity_dict = self.get_collaborators_similarity(round, reprs_dict, second_step_reprs_dict) if self.with_sim_sharing: - self.comm_utils.send_signal(dest=self.server_node, data=similarity_dict, tag=self.tag.SIM_ADVERT) - similarity_dict = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.SIM_SHARE) + self.comm_utils.send(dest=self.server_node, data=similarity_dict, tag=self.tag.SIM_ADVERT) + similarity_dict = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.SIM_SHARE) is_num_collab_changed = round == self.config['T_0'] and self.config["target_users_after_T_0"] != self.config["target_users_before_T_0"] is_selection_round = round == start_round or (round % self.config["rounds_per_selection"] == 0) or is_num_collab_changed @@ -1140,8 +1143,8 @@ def run_protocol(self): self.log_clients_stats(proba_dist, "Selection probability") selected_collaborators = [key for key, value in collab_weights_dict.items() if value > 0] - self.comm_utils.send_signal(dest=self.server_node, data=(selected_collaborators, self.get_model_weights()), tag=self.tag.C_SELECTION) - models_wts = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.KNLDG_SHARE) + self.comm_utils.send(dest=self.server_node, data=(selected_collaborators, self.get_model_weights()), tag=self.tag.C_SELECTION) + models_wts = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.KNLDG_SHARE) avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, self.model_keys_to_ignore) @@ -1160,11 +1163,11 @@ def run_protocol(self): self.log_clients_stats(collab_weights_dict, "Collaborator weights") - self.comm_utils.send_signal(dest=self.server_node, data=self.round_stats, tag=self.tag.ROUND_STATS) + self.comm_utils.send(dest=self.server_node, data=self.round_stats, tag=self.tag.ROUND_STATS) class FedDataRepServer(BaseFedAvgServer): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils=comm_utils, comm_protocol=CommProtocol) # self.set_parameters() self.tag = CommProtocol self.config = config @@ -1179,7 +1182,7 @@ def __init__(self, config) -> None: def send_models_selected(self, collaborator_selection, models_wts): for client_id, selected_clients in collaborator_selection.items(): wts = {key: models_wts[key] for key in selected_clients} - self.comm_utils.send_signal( + self.comm_utils.send( dest=client_id, data=wts, tag=self.tag.KNLDG_SHARE ) @@ -1189,15 +1192,15 @@ def single_round(self): """ # Start local training - self.comm_utils.send_signal_to_all_clients( - self.users, data=None, tag=self.tag.ROUND_START + self.comm_utils.broadcast( + data=None, tag=self.tag.ROUND_START ) self.log_utils.log_console( "Server waiting for all clients to finish local training" ) # Collect representation from all clients - reprs = self.comm_utils.wait_for_all_clients(self.users, self.tag.REP1_ADVERT) + reprs = self.comm_utils.all_gather(self.tag.REP1_ADVERT) self.log_utils.log_console("Server received all clients reprs") if self.config["representation"] == "dreams": @@ -1206,39 +1209,33 @@ def single_round(self): imgs = rep[0][:16, :3] self.log_utils.log_image(imgs, f"client{client+1}", self.round) - self.comm_utils.send_signal_to_all_clients( - self.users, data=reprs, tag=self.tag.REPS1_SHARE + self.comm_utils.broadcast( + data=reprs, tag=self.tag.REPS1_SHARE ) if self.require_second_step: # Collect the representations from all other nodes from the server - reprs2 = self.comm_utils.wait_for_all_clients( - self.users, self.tag.REP2_ADVERT - ) + reprs2 = self.comm_utils.all_gather(self.tag.REP2_ADVERT) reprs2 = {idx: reprs for idx, reprs in enumerate(reprs2, 1)} - self.comm_utils.send_signal_to_all_clients( - self.users, data=reprs2, tag=self.tag.REPS2_SHARE + self.comm_utils.broadcast( + data=reprs2, tag=self.tag.REPS2_SHARE ) if self.with_sim_sharing: - sim_dicts = self.comm_utils.wait_for_all_clients( - self.users, self.tag.SIM_ADVERT - ) + sim_dicts = self.comm_utils.all_gather(self.tag.SIM_ADVERT) sim_dicts = {k: v for k, v in enumerate(sim_dicts, 1)} - self.comm_utils.send_signal_to_all_clients( - self.users, data=sim_dicts, tag=self.tag.SIM_SHARE + self.comm_utils.broadcast( + data=sim_dicts, tag=self.tag.SIM_SHARE ) if self.with_cons_step: - consensus = self.comm_utils.wait_for_all_clients( - self.users, self.tag.CONS_ADVERT - ) + consensus = self.comm_utils.all_gather(self.tag.CONS_ADVERT) consensus_dict = {idx: cons for idx, cons in enumerate(consensus, 1)} - self.comm_utils.send_signal_to_all_clients( - self.users, data=consensus_dict, tag=self.tag.CONS_SHARE + self.comm_utils.broadcast( + data=consensus_dict, tag=self.tag.CONS_SHARE ) - data = self.comm_utils.wait_for_all_clients(self.users, self.tag.C_SELECTION) + data = self.comm_utils.all_gather(self.tag.C_SELECTION) collaborator_selection = { idx: select for idx, (select, _) in enumerate(data, 1) } @@ -1248,9 +1245,7 @@ def single_round(self): self.send_models_selected(collaborator_selection, models_wts) # Collect round stats from all clients - clients_round_stats = self.comm_utils.wait_for_all_clients( - self.users, self.tag.ROUND_STATS - ) + clients_round_stats = self.comm_utils.all_gather(self.tag.ROUND_STATS) self.log_utils.log_console("Server received all clients stats") # Log the round stats on tensorboard diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 2b03f1f..1c07136 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -13,6 +13,21 @@ Tuple[Union[int, str, float, bool, None], ...], Optional[List[int]]]] +def assign_colab(clients): + groups = [1, 2] + + dict = {} + client = 1 + while client <= clients: + for size in groups: + group = [] + for i in range(size): + group.append(client) + client += 1 + for c in group: + dict[c] = group + return dict + # Algorithm Configuration iid_dispfl_clients_new: ConfigType = { @@ -256,5 +271,130 @@ "exp_keys": [], } +metaL2C_cifar10: ConfigType = { + "algo": "metal2c", + "sharing": "weights", #"updates" + "exp_id": "", + + # Client selection + "target_users_before_T_0": 0, + "target_users_after_T_0": 1, + "T_0": 2, + "K_0": 0, # number of peers to keep as neighbors at T_0 (!) inverse that in L2C paper + "T_0": 250, # round after wich only K_0 peers are kept + "alpha_lr": 0.1, + "alpha_weight_decay": 0.01, + + "epochs_per_round": 5, + "rounds": 3, + "model": "resnet18", + "average_last_layer": False, + "model_lr": 1e-4, + "batch_size": 64, + "optimizer": "sgd", + "weight_decay": 5e-4, + + # params for model + "position": 0, + "inp_shape": [128, 3, 32, 32], + + "exp_keys": [] +} + +fedass: ConfigType = { + "algo": "fedass", + "exp_id": "", + "num_rep": 1, + "load_existing": False, + + # Clients selection + "strategy": "random_among_assigned", # fixed, direct_expo + "assigned_collaborators": assign_colab(3), + "target_users_before_T_0": 0, + "target_users_after_T_0": 1, + "T_0": 10, # round after wich only target_users_after_T_0 peers are kept + + # Learning setup + "rounds": 10, + "epochs_per_round": 5, + "model": "resnet10", + # "pretrained": True, + # "train_only_fc": True, + "model_lr": 1e-4, + "batch_size": 16, + + # params for model + "position": 0, + "exp_keys": ["strategy"] +} + +feddatarepr: ConfigType = { + "algo": "feddatarepr", + "exp_id": "try2", + "num_rep": 1, + "load_existing": False, + + # Similarity params + "representation": "train_data", # "test_data", "train_data", "dreams" + "num_repr_samples": 16, + # "CTLR_KL" Collaborator is Teacher using Learner Representation + # "CTCR_KL" Collaborator is Teacher using Collaborator Representation - Default row + # "LTLR_KL" Collaborator is Learner using Learner Representation - Default column + # "CTAR_KL" Collaborator is Teacher using ALL Representations (from every other client) + # "train_loss_inv" : 1-loss/total + # "train_loss_sm": 1-softmax(losses) + "similarity_metric": "train_loss_inv", + + # Memory params + "sim_running_average": 10, + "sim_exclude_first": (5, 5), # (first rounds, first rounds after T0) + + # Clients selection + "target_users_before_T_0": 0, #feddatarepr_users-1, + "target_users_after_T_0": 1, + "T_0": 10, # round after wich only target_users_after_T_0 peers are kept + # highest, lowest, [lower_exp]_sim_sampling, top_x, xth, uniform_rdm + "selection_strategy": "uniform_rdm",#"uniform_rdm", + #"eps_greedy": 0.1, + # "num_users_top_x" : 1, # Ideally: size community-1 + # "selection_temperature": 0.5, # For all strategy with temperature + + + # Consensus params + # "sim_averaging", "sim_of_sim", "vote_1hop", "affinity_propagation_clustering", "mean_shift_clustering", "club" + "consensus":"mean_shift_clustering",# "affinity_propagation_clustering", + # "affinity_precomputed": False, # If False similarity row are treated as data points and not as similarity values + # "club_weak_link_strategy": "own_cluster_and_pointing_to", #"own_cluster_and_pointing_to", pointing_to, own_cluster + # "vote_consensus": (2,2), #( num_voter, num_vote_per_voter) + # "sim_consensus_top_a": 3, + + #"community_type": "dataset", + #"num_communities": len(domainnet_classes), + + # Learning setup + "warmup_epochs": 5, + "epochs_per_round": 5, + "rounds_per_selection": 1, # Number of rounds before selecting new collaborator(s) + "rounds": 10, + "model": "resnet10", + "average_last_layer": True, + "mask_finetune_last_layer": False, + "model_lr": 1e-4, + "batch_size": 16, + + # Dreams params + # "reprs_position": 0, + # "inp_shape": [3, 32, 32] , + # "inv_lr": 1e-1, + # "inv_epochs": 500, + # "alpha_preds": 0.1, + # "alpha_tv": 2.5e-3, + # "alpha_l2": 1e-7, + # "alpha_f": 10.0, + #"dreams_keep_best": False, # Use reprs with lowest loss + + "exp_keys": ["similarity_metric", "selection_strategy", "consensus"] +} + # Assign the current configuration -current_config: ConfigType = traditional_fl +current_config: ConfigType = feddatarepr diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 35614da..89ea029 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -1,9 +1,31 @@ # System Configuration # TODO: Set up multiple non-iid configurations here. The goal of a separate system config # is to simulate different real-world scenarios without changing the algorithm configuration. -from typing import Dict, List +from typing import TypeAlias, Dict, List, Union, Tuple, Optional # from utils.config_utils import get_sliding_window_support, get_device_ids +ConfigType: TypeAlias = Dict[str, Union[ + str, + float, + int, + bool, + List[str], + List[int], + List[float], + List[bool], + Tuple[Union[int, str, float, bool, None], ...], + Optional[List[int]]]] + +sliding_window_8c_4cpc_support = { + "1": [0, 1, 2, 3], + "2": [1, 2, 3, 4], + "3": [2, 3, 4, 5], + "4": [3, 4, 5, 6], + "5": [4, 5, 6, 7], + "6": [5, 6, 7, 8], + "7": [6, 7, 8, 9], + "8": [7, 8, 9, 0], +} def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]: """ @@ -67,7 +89,7 @@ def get_digit_five_support(num_users:int, domains:List[str]=DIGIT_FIVE): "comm": { "type": "MPI" }, - "num_users": 4, + "num_users": 3, # "experiment_path": "./experiments/", "dset": "cifar10", "dump_dir": "./expt_dump/", @@ -125,6 +147,27 @@ def get_digit_five_support(num_users:int, domains:List[str]=DIGIT_FIVE): "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal" } +mpi_metaL2C_support_sys_config = { + "comm": { + "type": "MPI" + }, + "seed": 1, + "num_users": 3, + # "experiment_path": "./experiments/", + "dset": "cifar10", + "dump_dir": "./expt_dump/", + "dpath": "./datasets/imgs/cifar10/", + "load_existing": False, + "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), + "train_label_distribution": "support", # Either "iid", "non_iid" "support", + "test_label_distribution": "support", # Either "iid" "support", + "support" : sliding_window_8c_4cpc_support, + "samples_per_user": 32, + "test_samples_per_user": 32, + "validation_prop": 0.05, + "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal" +} + mpi_digitfive_sys_config = { "comm": { "type": "MPI" @@ -208,4 +251,4 @@ def get_digit_five_support(num_users:int, domains:List[str]=DIGIT_FIVE): } # current_config = grpc_system_config -current_config = mpi_system_config +current_config:ConfigType = mpi_system_config diff --git a/src/utils/comm_utils.py b/src/utils/[outdated]comm_utils.py similarity index 100% rename from src/utils/comm_utils.py rename to src/utils/[outdated]comm_utils.py diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index f7e89f3..800d345 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -60,7 +60,7 @@ def receive(self, node_ids: str|int|List[str|int], tag:int=0) -> Any: else: return self.comm.receive(node_ids) - def broadcast(self, data: Any): + def broadcast(self, data: Any, tag:int=0): self.comm.broadcast(data) def all_gather(self, tag:int=0):