diff --git a/src/algos/fl_grid.py b/src/algos/fl_grid.py
index 51e1682..4ed7cc1 100644
--- a/src/algos/fl_grid.py
+++ b/src/algos/fl_grid.py
@@ -1,285 +1,36 @@
-from collections import OrderedDict
-from typing import Any, Dict, List
-import torch
-import torch.nn as nn
-import random
 import numpy as np
-from math import ceil
+import math
 
-from algos.base_class import BaseFedAvgClient, BaseFedAvgServer
-from collections import defaultdict
-from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays
 
+class GridTopology:
+    def get_selected_ids(self, node_id, config):
+        grid_size = int(math.sqrt(config["num_clients"]))
 
-class FedGridClient(BaseFedAvgClient):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-
-    def get_collaborator_weights(self, reprs_dict, round):
-        """
-        Returns the weights of the collaborators for the current round
-        """
-        total_rounds = self.config["rounds"]
-        within_community_sampling = self.config.get("within_community_sampling",1)
-        p_within_decay = self.config.get("p_within_decay",None)
-        if p_within_decay is not None:
-            if p_within_decay == "linear_inc":
-                within_community_sampling = within_community_sampling * (round / total_rounds)
-            elif p_within_decay == "linear_dec":
-                within_community_sampling = within_community_sampling * (1 - round / total_rounds)
-            elif p_within_decay == "exp_inc":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log((1-within_community_sampling)/within_community_sampling)
-                within_community_sampling = within_community_sampling * np.exp(alpha * round / total_rounds)
-            elif p_within_decay == "exp_dec":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log(within_community_sampling/(1-within_community_sampling))
-                within_community_sampling = within_community_sampling * np.exp(- alpha * round / total_rounds)
-            elif p_within_decay == "log_inc":
-                alpha = np.exp(1/within_community_sampling)-1
-                within_community_sampling = within_community_sampling * np.log2(1 + alpha * round / total_rounds)
-
-        self.grid_size = int(self.config["num_clients"]**0.5)
-
-        self.num_clients = self.config["num_clients"]
+        num_clients = config["num_clients"]
 
         selected_ids = []
         # Left
-        if self.node_id % self.grid_size != 1:
-            selected_ids.append(self.node_id - 1)
+        if node_id % grid_size != 1:
+            selected_ids.append(node_id - 1)
 
         # Right
-        if self.node_id % self.grid_size != 0 and self.node_id < self.num_clients:
-            selected_ids.append(self.node_id + 1)
+        if node_id % grid_size != 0 and node_id < num_clients:
+            selected_ids.append(node_id + 1)
 
         # Top
-        if self.node_id > self.grid_size:
-            selected_ids.append(self.node_id - self.grid_size)
+        if node_id > grid_size:
+            selected_ids.append(node_id - grid_size)
 
         # Bottom
-        if self.node_id <= self.num_clients - self.grid_size:
-            selected_ids.append(self.node_id + self.grid_size)
+        if node_id <= num_clients - grid_size:
+            selected_ids.append(node_id + grid_size)
 
-        num_clients_to_select = self.config["num_clients_to_select"]
+        num_clients_to_select = config["num_clients_to_select"]
 
         # Force self node id to be selected, not removed before sampling to keep sampling identical across nodes (if same seed)
         selected_collabs = np.random.choice(selected_ids, size=min(num_clients_to_select, len(selected_ids)), replace=False)
-        selected_ids = list(selected_collabs) + [self.node_id]
+        selected_ids = list(selected_collabs) + [node_id]
 
         print("Selected collabs:" + str(selected_ids))
-
-        collab_weights = defaultdict(lambda: 0.0)
-        for idx in selected_ids:
-            own_aggr_weight = self.config.get("own_aggr_weight", 1/len(selected_ids))
-
-            aggr_weight_strategy = self.config.get("aggr_weight_strategy", None)
-            if aggr_weight_strategy is not None:
-                init_weight = 0.1
-                target_weight = 0.5
-                if aggr_weight_strategy == "linear":
-                    target_round = total_rounds // 2
-                    own_aggr_weight = 1 - (init_weight + (target_weight - init_weight) * (min(1,round / target_round)))
-                elif aggr_weight_strategy == "log":
-                    alpha = 0.05
-                    own_aggr_weight = 1 - (init_weight + (target_weight-init_weight) * (np.log(alpha*(round/total_rounds)+1)/np.log(alpha+1)))
-                else:
-                    raise ValueError(f"Aggregation weight strategy {aggr_weight_strategy} not implemented")
-
-            if self.node_id == 1 and idx == 1:
-                print(f"Collaborator {idx} weight: {own_aggr_weight}")
-            if idx == self.node_id:
-                collab_weights[idx] = own_aggr_weight
-            else:
-                collab_weights[idx] = (1 - own_aggr_weight) / (len(selected_ids) - 1)
-
-        return collab_weights
-
-    def get_representation(self):
-        return self.get_model_weights()
-
-    def mask_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-        key = [k for k in keys if "weight" in k][0]
-        weight = torch.zeros_like(wts[key])
-        weight[self.classes_of_interest] = wts[key][self.classes_of_interest]
-        self.model.load_state_dict({key: weight.to(self.device)}, strict=False)
-
-    def freeze_model_except_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-
-        for name, param in self.model.named_parameters():
-            if name not in keys:
-                param.requires_grad = False
-
-    def unfreeze_model(self):
-        for param in self.model.parameters():
-            param.requires_grad = True
-
-    def flatten_repr(self,repr):
-        params = []
-
-        for key in repr.keys():
-            params.append(repr[key].view(-1))
-
-        params = torch.cat(params)
-
-        return params
-
-    def compute_pseudo_grad_norm(self, prev_wts, new_wts):
-        return np.linalg.norm(self.flatten_repr(prev_wts) - self.flatten_repr(new_wts))
-    def run_protocol(self):
-        print(f"Client {self.node_id} ready to start training")
-        start_round = self.config.get("start_round", 0)
-        if start_round != 0:
-            raise NotImplementedError("Start round different from 0 not implemented yet")
-        total_rounds = self.config["rounds"]
-        epochs_per_round = self.config["epochs_per_round"]
-        for round in range(start_round, total_rounds):
-            stats = {}
-
-            # Wait on server to start the round
-            self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.ROUND_START)
-
-            if self.config.get("finetune_last_layer", False):
-                self.freeze_model_except_last_layer()
-
-            # Train locally and send the representation to the server
-            if not self.config.get("local_train_after_aggr", False):
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-
-            repr = self.get_representation()
-            self.comm_utils.send_signal(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT)
-
-            # Collect the representations from all other nodes from the server
-            reprs = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPRS_SHARE)
-
-            # In the future this dict might be generated by the server to send only requested models
-            reprs_dict = {k:v for k,v in enumerate(reprs, 1)}
-
-            # Aggregate the representations based on the collab weights
-            collab_weights_dict = self.get_collaborator_weights(reprs_dict, round)
-
-            # Since clients representations are also used to transmit knowledge
-            # There is no need to fetch the server for the selected clients' knowledge
-            models_wts = reprs_dict
-
-            layers_to_ignore = self.model_keys_to_ignore
-            active_collab = set([k for k,v in collab_weights_dict.items() if v > 0])
-            inter_commu_last_layer_to_aggr = self.config.get("inter_commu_layer", None)
-            # If partial merging is on and some client selected client is outside the community, ignore layers after specified layer
-            if inter_commu_last_layer_to_aggr is not None and len(set(self.communities[self.node_id]).intersection(active_collab)) != len(active_collab):
-                layer_idx = self.model_utils.models_layers_idx[self.config["model"]][inter_commu_last_layer_to_aggr]
-                layers_to_ignore = self.model_keys_to_ignore + list(list(models_wts.values())[0].keys())[layer_idx+1:]
-
-            avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore)
-
-            # Average whole model by default
-            self.set_model_weights(avg_wts, layers_to_ignore)
-
-            if self.config.get("train_only_fc", False):
-
-                self.mask_last_layer()
-                self.freeze_model_except_last_layer()
-                self.local_train(1)
-                self.unfreeze_model()
-
-            stats["test_acc_before_training"] = self.local_test()
-
-            # Train locally and send the representation to the server
-            if self.config.get("local_train_after_aggr", False):
-
-                prev_wts = self.get_model_weights()
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-                new_wts = self.get_model_weights()
-
-                stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts)
-
-            # Test updated model
-            stats["test_acc_after_training"] = self.local_test()
-
-            # Include collab weights in the stats
-            collab_weight = np.zeros(self.config["num_clients"])
-            for k,v in collab_weights_dict.items():
-                collab_weight[k-1] = v
-            stats["Collaborator weights"] = collab_weight
-
-            self.comm_utils.send_signal(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS)
-
-class FedGridServer(BaseFedAvgServer):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-        # self.set_parameters()
-        self.config = config
-        self.set_model_parameters(config)
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"],
-                                                                   self.node_id)
-
-    def test(self) -> float:
-        """
-        Test the model on the server
-        """
-        test_loss, acc = self.model_utils.test(self.model,
-                                               self._test_loader,
-                                               self.loss_fn,
-                                               self.device)
-        # TODO save the model if the accuracy is better than the best accuracy so far
-        if acc > self.best_acc:
-            self.best_acc = acc
-            self.model_utils.save_model(self.model, self.model_save_path)
-        return acc
-
-    def single_round(self):
-        """
-        Runs the whole training procedure
-        """
-
-        # Send signal to all clients to start local training
-        for client_node in self.clients:
-            self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.ROUND_START)
-        self.log_utils.log_console("Server waiting for all clients to finish local training")
-
-        # Collect models from all clients
-        models = self.comm_utils.wait_for_all_clients(self.clients, self.tag.REPR_ADVERT)
-        self.log_utils.log_console("Server received all clients models")
-
-        # Broadcast the models to all clients
-        self.send_representations(models)
-
-        # Collect round stats from all clients
-        clients_round_stats = self.comm_utils.wait_for_all_clients(self.clients, self.tag.ROUND_STATS)
-        self.log_utils.log_console("Server received all clients stats")
-
-        # Log the round stats on tensorboard except the collab weights
-        self.log_utils.log_tb_round_stats(clients_round_stats, ["Collaborator weights"], self.round)
-
-        # for stats in clients_round_stats:
-        #     print(f"Collaborator weights: {stats['Collaborator weights']}")
-
-        self.log_utils.log_console(f"Round test acc before local training {[stats['test_acc_before_training'] for stats in clients_round_stats]}")
-        self.log_utils.log_console(f"Round test acc after local training {[stats['test_acc_after_training'] for stats in clients_round_stats]}")
-
-        return clients_round_stats
-
-    def run_protocol(self):
-        self.log_utils.log_console("Starting static grid P2P collaboration")
-        start_round = self.config.get("start_round", 0)
-        total_round = self.config["rounds"]
-
-        # List of list stats per round
-        stats = []
-        for round in range(start_round, total_round):
-            self.round = round
-            self.log_utils.log_console("Starting round {}".format(round))
-
-            round_stats = self.single_round()
-            stats.append(round_stats)
-
-        stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats)
-        stats_dict["round_step"] = 1
-        self.log_utils.log_experiments_stats(stats_dict)
-        self.plot_utils.plot_experiments_stats(stats_dict)
-
-
+        return selected_ids
\ No newline at end of file
diff --git a/src/algos/fl_random.py b/src/algos/fl_random.py
index 20bb81d..93b6baa 100644
--- a/src/algos/fl_random.py
+++ b/src/algos/fl_random.py
@@ -1,264 +1,21 @@
-from collections import OrderedDict
-from typing import Any, Dict, List
-import torch
-import torch.nn as nn
 import random
 import numpy as np
 
-from algos.base_class import BaseFedAvgClient, BaseFedAvgServer
-from collections import defaultdict
-from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays
 
+class RandomTopology:
+    def get_selected_ids(self, node_id, config, round, reprs_dict, communities):
+        within_community_sampling = config.get("within_community_sampling",1)
 
-class FedRanClient(BaseFedAvgClient):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-
-    def get_collaborator_weights(self, reprs_dict, round):
-        """
-        Returns the weights of the collaborators for the current round
-        """
-        total_rounds = self.config["rounds"]
-        within_community_sampling = self.config.get("within_community_sampling",1)
-        p_within_decay = self.config.get("p_within_decay",None)
-        if p_within_decay is not None:
-            if p_within_decay == "linear_inc":
-                within_community_sampling = within_community_sampling * (round / total_rounds)
-            elif p_within_decay == "linear_dec":
-                within_community_sampling = within_community_sampling * (1 - round / total_rounds)
-            elif p_within_decay == "exp_inc":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log((1-within_community_sampling)/within_community_sampling)
-                within_community_sampling = within_community_sampling * np.exp(alpha * round / total_rounds)
-            elif p_within_decay == "exp_dec":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log(within_community_sampling/(1-within_community_sampling))
-                within_community_sampling = within_community_sampling * np.exp(- alpha * round / total_rounds)
-            elif p_within_decay == "log_inc":
-                alpha = np.exp(1/within_community_sampling)-1
-                within_community_sampling = within_community_sampling * np.log2(1 + alpha * round / total_rounds)
-
-        if random.random() <= within_community_sampling or len(self.communities) == 1:
+        if random.random() <= within_community_sampling or len(communities) == 1:
             # Consider only neighbors (clients in the same community)
-            indices = [id for id in sorted(list(reprs_dict.keys())) if id in self.communities[self.node_id]]
+            indices = [id for id in sorted(list(reprs_dict.keys())) if id in communities[node_id]]
         else:
             # Consider clients from other communities
-            indices = [id for id in sorted(list(reprs_dict.keys())) if id not in self.communities[self.node_id]]
+            indices = [id for id in sorted(list(reprs_dict.keys())) if id not in communities[node_id]]
 
-        num_clients_to_select = self.config[f"target_clients_{'before' if round < self.config['T_0'] else 'after'}_T_0"]
+        num_clients_to_select = config[f"target_clients_{'before' if round < config['T_0'] else 'after'}_T_0"]
         selected_ids = random.sample(indices, min(num_clients_to_select + 1, len(indices)))
 
         # Force self node id to be selected, not removed before sampling to keep sampling identical across nodes (if same seed)
-        selected_ids = [self.node_id] + [id for id in selected_ids if id != self.node_id][:num_clients_to_select]
-
-        collab_weights = defaultdict(lambda: 0.0)
-        for idx in selected_ids:
-            own_aggr_weight = self.config.get("own_aggr_weight", 1/len(selected_ids))
-
-            aggr_weight_strategy = self.config.get("aggr_weight_strategy", None)
-            if aggr_weight_strategy is not None:
-                init_weight = 0.1
-                target_weight = 0.5
-                if aggr_weight_strategy == "linear":
-                    target_round = total_rounds // 2
-                    own_aggr_weight = 1 - (init_weight + (target_weight - init_weight) * (min(1,round / target_round)))
-                elif aggr_weight_strategy == "log":
-                    alpha = 0.05
-                    own_aggr_weight = 1 - (init_weight + (target_weight-init_weight) * (np.log(alpha*(round/total_rounds)+1)/np.log(alpha+1)))
-                else:
-                    raise ValueError(f"Aggregation weight strategy {aggr_weight_strategy} not implemented")
-
-            if self.node_id == 1 and idx == 1:
-                print(f"Collaborator {idx} weight: {own_aggr_weight}")
-            if idx == self.node_id:
-                collab_weights[idx] = own_aggr_weight
-            else:
-                collab_weights[idx] = (1 - own_aggr_weight) / (len(selected_ids) - 1)
-
-        return collab_weights
-
-    def get_representation(self):
-        return self.get_model_weights()
-
-    def mask_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-        key = [k for k in keys if "weight" in k][0]
-        weight = torch.zeros_like(wts[key])
-        weight[self.classes_of_interest] = wts[key][self.classes_of_interest]
-        self.model.load_state_dict({key: weight.to(self.device)}, strict=False)
+        selected_ids = [node_id] + [id for id in selected_ids if id != node_id][:num_clients_to_select]
 
-    def freeze_model_except_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-
-        for name, param in self.model.named_parameters():
-            if name not in keys:
-                param.requires_grad = False
-
-    def unfreeze_model(self):
-        for param in self.model.parameters():
-            param.requires_grad = True
-
-    def flatten_repr(self,repr):
-        params = []
-
-        for key in repr.keys():
-            params.append(repr[key].view(-1))
-
-        params = torch.cat(params)
-
-        return params
-
-    def compute_pseudo_grad_norm(self, prev_wts, new_wts):
-        return np.linalg.norm(self.flatten_repr(prev_wts) - self.flatten_repr(new_wts))
-
-    def run_protocol(self):
-        print(f"Client {self.node_id} ready to start training")
-        start_round = self.config.get("start_round", 0)
-        if start_round != 0:
-            raise NotImplementedError("Start round different from 0 not implemented yet")
-        total_rounds = self.config["rounds"]
-        epochs_per_round = self.config["epochs_per_round"]
-        for round in range(start_round, total_rounds):
-            stats = {}
-
-            # Wait on server to start the round
-            self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.ROUND_START)
-
-            if self.config.get("finetune_last_layer", False):
-                self.freeze_model_except_last_layer()
-
-            # Train locally and send the representation to the server
-            if not self.config.get("local_train_after_aggr", False):
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-
-            repr = self.get_representation()
-            self.comm_utils.send_signal(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT)
-
-            # Collect the representations from all other nodes from the server
-            reprs = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPRS_SHARE)
-
-            # In the future this dict might be generated by the server to send only requested models
-            reprs_dict = {k:v for k,v in enumerate(reprs, 1)}
-
-            # Aggregate the representations based on the collab weights
-            collab_weights_dict = self.get_collaborator_weights(reprs_dict, round)
-
-            # Since clients representations are also used to transmit knowledge
-            # There is no need to fetch the server for the selected clients' knowledge
-            models_wts = reprs_dict
-
-            layers_to_ignore = self.model_keys_to_ignore
-            active_collab = set([k for k,v in collab_weights_dict.items() if v > 0])
-            inter_commu_last_layer_to_aggr = self.config.get("inter_commu_layer", None)
-            # If partial merging is on and some client selected client is outside the community, ignore layers after specified layer
-            if inter_commu_last_layer_to_aggr is not None and len(set(self.communities[self.node_id]).intersection(active_collab)) != len(active_collab):
-                layer_idx = self.model_utils.models_layers_idx[self.config["model"]][inter_commu_last_layer_to_aggr]
-                layers_to_ignore = self.model_keys_to_ignore + list(list(models_wts.values())[0].keys())[layer_idx+1:]
-
-            avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore)
-
-            # Average whole model by default
-            self.set_model_weights(avg_wts, layers_to_ignore)
-
-            if self.config.get("train_only_fc", False):
-
-                self.mask_last_layer()
-                self.freeze_model_except_last_layer()
-                self.local_train(1)
-                self.unfreeze_model()
-
-            stats["test_acc_before_training"] = self.local_test()
-
-            # Train locally and send the representation to the server
-            if self.config.get("local_train_after_aggr", False):
-
-                prev_wts = self.get_model_weights()
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-                new_wts = self.get_model_weights()
-
-                stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts)
-
-            # Test updated model
-            stats["test_acc_after_training"] = self.local_test()
-
-            # Include collab weights in the stats
-            collab_weight = np.zeros(self.config["num_clients"])
-            for k,v in collab_weights_dict.items():
-                collab_weight[k-1] = v
-            stats["Collaborator weights"] = collab_weight
-
-            self.comm_utils.send_signal(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS)
-
-class FedRanServer(BaseFedAvgServer):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-        # self.set_parameters()
-        self.config = config
-        self.set_model_parameters(config)
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"],
-                                                                   self.node_id)
-
-    def test(self) -> float:
-        """
-        Test the model on the server
-        """
-        test_loss, acc = self.model_utils.test(self.model,
-                                               self._test_loader,
-                                               self.loss_fn,
-                                               self.device)
-        # TODO save the model if the accuracy is better than the best accuracy so far
-        if acc > self.best_acc:
-            self.best_acc = acc
-            self.model_utils.save_model(self.model, self.model_save_path)
-        return acc
-
-    def single_round(self):
-        """
-        Runs the whole training procedure
-        """
-
-        # Send signal to all clients to start local training
-        for client_node in self.clients:
-            self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.ROUND_START)
-        self.log_utils.log_console("Server waiting for all clients to finish local training")
-
-        # Collect models from all clients
-        models = self.comm_utils.wait_for_all_clients(self.clients, self.tag.REPR_ADVERT)
-        self.log_utils.log_console("Server received all clients models")
-
-        # Broadcast the models to all clients
-        self.send_representations(models)
-
-        # Collect round stats from all clients
-        clients_round_stats = self.comm_utils.wait_for_all_clients(self.clients, self.tag.ROUND_STATS)
-        self.log_utils.log_console("Server received all clients stats")
-
-        # Log the round stats on tensorboard except the collab weights
-        self.log_utils.log_tb_round_stats(clients_round_stats, ["Collaborator weights"], self.round)
-
-        self.log_utils.log_console(f"Round test acc before local training {[stats['test_acc_before_training'] for stats in clients_round_stats]}")
-        self.log_utils.log_console(f"Round test acc after local training {[stats['test_acc_after_training'] for stats in clients_round_stats]}")
-
-        return clients_round_stats
-
-    def run_protocol(self):
-        self.log_utils.log_console("Starting random P2P collaboration")
-        start_round = self.config.get("start_round", 0)
-        total_round = self.config["rounds"]
-
-        # List of list stats per round
-        stats = []
-        for round in range(start_round, total_round):
-            self.round = round
-            self.log_utils.log_console("Starting round {}".format(round))
-
-            round_stats = self.single_round()
-            stats.append(round_stats)
-
-        stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats)
-        stats_dict["round_step"] = 1
-        self.log_utils.log_experiments_stats(stats_dict)
-        self.plot_utils.plot_experiments_stats(stats_dict)
-
-
+        return selected_ids
diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py
index 999af11..4b363ac 100644
--- a/src/algos/fl_static.py
+++ b/src/algos/fl_static.py
@@ -10,8 +10,12 @@
 from collections import defaultdict
 from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays
 from fl_ring import RingTopology
+from fl_grid import GridTopology
+from fl_torus import TorusTopology
+from fl_random import RandomTopology
 
-class FedRingClient(BaseFedAvgClient):
+
+class FedStaticClient(BaseFedAvgClient):
     def __init__(self, config) -> None:
         super().__init__(config)
 
@@ -39,13 +43,22 @@ def get_collaborator_weights(self, reprs_dict, round):
             alpha = np.exp(1/within_community_sampling)-1
             within_community_sampling = within_community_sampling * np.log2(1 + alpha * round / total_rounds)
 
-        topology = self.config["topology"]
-        if topology == "ring":
-            ring_topology = RingTopology(self.config["num_clients"])
-            selected_ids = ring_topology.get_selected_ids(self.node_id, self.config)
-        else:
-            selected_ids = self.get_selected_ids()
-
+        topology_name = self.config["topology"]
+        if topology_name == "random":
+            topology = RandomTopology()
+            selected_ids = topology.get_selected_ids(self.node_id, self.config, round, reprs_dict, self.communities)
+        elif topology_name == "ring":
+            topology = RingTopology()
+            selected_ids = topology.get_selected_ids(self.node_id, self.config)
+
+        elif topology_name == "grid":
+            topology = GridTopology()
+            selected_ids = topology.get_selected_ids(self.node_id, self.config)
+
+        elif topology_name == "torus":
+            topology = TorusTopology()
+            selected_ids = topology.get_selected_ids(self.node_id, self.config)
+
         collab_weights = defaultdict(lambda: 0.0)
         for idx in selected_ids:
             own_aggr_weight = self.config.get("own_aggr_weight", 1/len(selected_ids))
@@ -186,7 +199,7 @@ def run_protocol(self):
         self.comm_utils.send_signal(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS)
 
 
-class FedRingServer(BaseFedAvgServer):
+class FedStaticServer(BaseFedAvgServer):
     def __init__(self, config) -> None:
         super().__init__(config)
         # self.set_parameters()
diff --git a/src/algos/fl_torus.py b/src/algos/fl_torus.py
index ce3892e..181ab64 100644
--- a/src/algos/fl_torus.py
+++ b/src/algos/fl_torus.py
@@ -1,303 +1,55 @@
-from collections import OrderedDict
-from typing import Any, Dict, List
-import torch
-import torch.nn as nn
-import random
 import numpy as np
 import math
 
-from algos.base_class import BaseFedAvgClient, BaseFedAvgServer
-from collections import defaultdict
-from utils.stats_utils import from_round_stats_per_round_per_client_to_dict_arrays
-
-class FedTorusClient(BaseFedAvgClient):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-
-    def get_collaborator_weights(self, reprs_dict, round):
-        """
-        Returns the weights of the collaborators for the current round
-        """
-        total_rounds = self.config["rounds"]
-        within_community_sampling = self.config.get("within_community_sampling",1)
-        p_within_decay = self.config.get("p_within_decay",None)
-        if p_within_decay is not None:
-            if p_within_decay == "linear_inc":
-                within_community_sampling = within_community_sampling * (round / total_rounds)
-            elif p_within_decay == "linear_dec":
-                within_community_sampling = within_community_sampling * (1 - round / total_rounds)
-            elif p_within_decay == "exp_inc":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log((1-within_community_sampling)/within_community_sampling)
-                within_community_sampling = within_community_sampling * np.exp(alpha * round / total_rounds)
-            elif p_within_decay == "exp_dec":
-                # Alpha scaled so that it goes from p to (1-p) in R rounds
-                alpha = np.log(within_community_sampling/(1-within_community_sampling))
-                within_community_sampling = within_community_sampling * np.exp(- alpha * round / total_rounds)
-            elif p_within_decay == "log_inc":
-                alpha = np.exp(1/within_community_sampling)-1
-                within_community_sampling = within_community_sampling * np.log2(1 + alpha * round / total_rounds)
-
-        self.grid_size = int(math.sqrt(self.config["num_clients"]))
-        self.num_clients = self.config["num_clients"]
+class TorusTopology:
+    def get_selected_ids(self, node_id, config):
+        grid_size = int(math.sqrt(config["num_clients"]))
+        num_clients = config["num_clients"]
 
         selected_ids = []
 
-        num_rows = math.ceil(self.num_clients / self.grid_size)
+        num_rows = math.ceil(num_clients / grid_size)
 
         # Left
-        if self.node_id % self.grid_size != 1:
-            selected_ids.append(self.node_id - 1)
-        elif math.ceil(self.node_id / self.grid_size) * self.grid_size <= self.num_clients:
-            selected_ids.append(self.node_id + self.grid_size - 1)
+        if node_id % grid_size != 1:
+            selected_ids.append(node_id - 1)
+        elif math.ceil(node_id / grid_size) * grid_size <= num_clients:
+            selected_ids.append(node_id + grid_size - 1)
 
         # Right
-        if self.node_id % self.grid_size != 0 and self.node_id < self.num_clients:
-            right_id = self.node_id + 1
+        if node_id % grid_size != 0 and node_id < num_clients:
+            right_id = node_id + 1
         else:
-            node_row = math.ceil(self.node_id / self.grid_size)
-            right_id = 1 + self.grid_size * (node_row - 1)
+            node_row = math.ceil(node_id / grid_size)
+            right_id = 1 + grid_size * (node_row - 1)
         selected_ids.append(right_id)
 
         # Top
-        if self.node_id > self.grid_size:
-            top_id = self.node_id - self.grid_size
+        if node_id > grid_size:
+            top_id = node_id - grid_size
         else:
-            top_id = self.node_id + self.grid_size * (num_rows - 1)
-            if top_id > self.num_clients:
-                top_id = top_id - self.grid_size
+            top_id = node_id + grid_size * (num_rows - 1)
+            if top_id > num_clients:
+                top_id = top_id - grid_size
         selected_ids.append(top_id)
 
         # Bottom
-        if self.node_id <= self.num_clients - self.grid_size:
-            bottom_id = self.node_id + self.grid_size
+        if node_id <= num_clients - grid_size:
+            bottom_id = node_id + grid_size
         else:
-            bottom_id = self.node_id % self.grid_size
+            bottom_id = node_id % grid_size
             if bottom_id == 0:
-                bottom_id = self.grid_size
+                bottom_id = grid_size
         selected_ids.append(bottom_id)
 
         # Force self node id to be selected, not removed before sampling to keep sampling identical across nodes (if same seed)
         selected_ids = list(set(selected_ids))
 
-        num_clients_to_select = self.config["num_clients_to_select"]
+        num_clients_to_select = config["num_clients_to_select"]
 
         selected_collabs = np.random.choice(selected_ids, size=min(num_clients_to_select, len(selected_ids)), replace=False)
-        selected_ids = list(selected_collabs) + [self.node_id]
-
-        print("Selected collabs: " + str(self.node_id) + str(selected_ids))
-
-        collab_weights = defaultdict(lambda: 0.0)
-        for idx in selected_ids:
-            own_aggr_weight = self.config.get("own_aggr_weight", 1/len(selected_ids))
-
-            aggr_weight_strategy = self.config.get("aggr_weight_strategy", None)
-            if aggr_weight_strategy is not None:
-                init_weight = 0.1
-                target_weight = 0.5
-                if aggr_weight_strategy == "linear":
-                    target_round = total_rounds // 2
-                    own_aggr_weight = 1 - (init_weight + (target_weight - init_weight) * (min(1,round / target_round)))
-                elif aggr_weight_strategy == "log":
-                    alpha = 0.05
-                    own_aggr_weight = 1 - (init_weight + (target_weight-init_weight) * (np.log(alpha*(round/total_rounds)+1)/np.log(alpha+1)))
-                else:
-                    raise ValueError(f"Aggregation weight strategy {aggr_weight_strategy} not implemented")
-
-            if self.node_id == 1 and idx == 1:
-                print(f"Collaborator {idx} weight: {own_aggr_weight}")
-            if idx == self.node_id:
-                collab_weights[idx] = own_aggr_weight
-            else:
-                collab_weights[idx] = (1 - own_aggr_weight) / (len(selected_ids) - 1)
-
-        return collab_weights
-
-    def get_representation(self):
-        return self.get_model_weights()
-
-    def mask_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-        key = [k for k in keys if "weight" in k][0]
-        weight = torch.zeros_like(wts[key])
-        weight[self.classes_of_interest] = wts[key][self.classes_of_interest]
-        self.model.load_state_dict({key: weight.to(self.device)}, strict=False)
-
-    def freeze_model_except_last_layer(self):
-        wts = self.get_model_weights()
-        keys = self.model_utils.get_last_layer_keys(wts)
-
-        for name, param in self.model.named_parameters():
-            if name not in keys:
-                param.requires_grad = False
-
-    def unfreeze_model(self):
-        for param in self.model.parameters():
-            param.requires_grad = True
-
-    def flatten_repr(self,repr):
-        params = []
-
-        for key in repr.keys():
-            params.append(repr[key].view(-1))
-
-        params = torch.cat(params)
-
-        return params
-
-    def compute_pseudo_grad_norm(self, prev_wts, new_wts):
-        return np.linalg.norm(self.flatten_repr(prev_wts) - self.flatten_repr(new_wts))
-
-    def run_protocol(self):
-        print(f"Client {self.node_id} ready to start training")
-        start_round = self.config.get("start_round", 0)
-        if start_round != 0:
-            raise NotImplementedError("Start round different from 0 not implemented yet")
-        total_rounds = self.config["rounds"]
-        epochs_per_round = self.config["epochs_per_round"]
-        for round in range(start_round, total_rounds):
-            stats = {}
-
-            # Wait on server to start the round
-            self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.ROUND_START)
-
-            if self.config.get("finetune_last_layer", False):
-                self.freeze_model_except_last_layer()
-
-            # Train locally and send the representation to the server
-            if not self.config.get("local_train_after_aggr", False):
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-
-            repr = self.get_representation()
-            self.comm_utils.send_signal(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT)
-
-            # Collect the representations from all other nodes from the server
-            reprs = self.comm_utils.wait_for_signal(src=self.server_node, tag=self.tag.REPRS_SHARE)
-
-            # In the future this dict might be generated by the server to send only requested models
-            reprs_dict = {k:v for k,v in enumerate(reprs, 1)}
-
-            # Aggregate the representations based on the collab weights
-            collab_weights_dict = self.get_collaborator_weights(reprs_dict, round)
-
-            # Since clients representations are also used to transmit knowledge
-            # There is no need to fetch the server for the selected clients' knowledge
-            models_wts = reprs_dict
-
-            layers_to_ignore = self.model_keys_to_ignore
-            active_collab = set([k for k,v in collab_weights_dict.items() if v > 0])
-            inter_commu_last_layer_to_aggr = self.config.get("inter_commu_layer", None)
-            # If partial merging is on and some client selected client is outside the community, ignore layers after specified layer
-            if inter_commu_last_layer_to_aggr is not None and len(set(self.communities[self.node_id]).intersection(active_collab)) != len(active_collab):
-                layer_idx = self.model_utils.models_layers_idx[self.config["model"]][inter_commu_last_layer_to_aggr]
-                layers_to_ignore = self.model_keys_to_ignore + list(list(models_wts.values())[0].keys())[layer_idx+1:]
-
-            avg_wts = self.weighted_aggregate(models_wts, collab_weights_dict, keys_to_ignore=layers_to_ignore)
-
-            # Average whole model by default
-            self.set_model_weights(avg_wts, layers_to_ignore)
-
-            if self.config.get("train_only_fc", False):
-
-                self.mask_last_layer()
-                self.freeze_model_except_last_layer()
-                self.local_train(1)
-                self.unfreeze_model()
-
-            stats["test_acc_before_training"] = self.local_test()
-
-            # Train locally and send the representation to the server
-            if self.config.get("local_train_after_aggr", False):
-
-                prev_wts = self.get_model_weights()
-                stats["train_loss"], stats["train_acc"] = self.local_train(epochs_per_round)
-                new_wts = self.get_model_weights()
-
-                stats["pseudo grad norm"] = self.compute_pseudo_grad_norm(prev_wts, new_wts)
-
-            # Test updated model
-            stats["test_acc_after_training"] = self.local_test()
-
-            # Include collab weights in the stats
-            collab_weight = np.zeros(self.config["num_clients"])
-            for k,v in collab_weights_dict.items():
-                collab_weight[k-1] = v
-            stats["Collaborator weights"] = collab_weight
-
-            self.comm_utils.send_signal(dest=self.server_node, data=stats, tag=self.tag.ROUND_STATS)
-
-class FedTorusServer(BaseFedAvgServer):
-    def __init__(self, config) -> None:
-        super().__init__(config)
-        # self.set_parameters()
-        self.config = config
-        self.set_model_parameters(config)
-        self.model_save_path = "{}/saved_models/node_{}.pt".format(self.config["results_path"],
-                                                                   self.node_id)
-
-    def test(self) -> float:
-        """
-        Test the model on the server
-        """
-        test_loss, acc = self.model_utils.test(self.model,
-                                               self._test_loader,
-                                               self.loss_fn,
-                                               self.device)
-        # TODO save the model if the accuracy is better than the best accuracy so far
-        if acc > self.best_acc:
-            self.best_acc = acc
-            self.model_utils.save_model(self.model, self.model_save_path)
-        return acc
-
-    def single_round(self):
-        """
-        Runs the whole training procedure
-        """
-
-        # Send signal to all clients to start local training
-        for client_node in self.clients:
-            self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.ROUND_START)
-        self.log_utils.log_console("Server waiting for all clients to finish local training")
-
-        # Collect models from all clients
-        models = self.comm_utils.wait_for_all_clients(self.clients, self.tag.REPR_ADVERT)
-        self.log_utils.log_console("Server received all clients models")
-
-        # Broadcast the models to all clients
-        self.send_representations(models)
-
-        # Collect round stats from all clients
-        clients_round_stats = self.comm_utils.wait_for_all_clients(self.clients, self.tag.ROUND_STATS)
-        self.log_utils.log_console("Server received all clients stats")
-
-        # Log the round stats on tensorboard except the collab weights
-        self.log_utils.log_tb_round_stats(clients_round_stats, ["Collaborator weights"], self.round)
-        for stats in clients_round_stats:
-            print(f"Collaborator weights: {stats['Collaborator weights']}")
-
-        self.log_utils.log_console(f"Round test acc before local training {[stats['test_acc_before_training'] for stats in clients_round_stats]}")
-        self.log_utils.log_console(f"Round test acc after local training {[stats['test_acc_after_training'] for stats in clients_round_stats]}")
-
-        return clients_round_stats
-
-    def run_protocol(self):
-        self.log_utils.log_console("Starting static torus P2P collaboration")
-        start_round = self.config.get("start_round", 0)
-        total_round = self.config["rounds"]
-
-        # List of list stats per round
-        stats = []
-        for round in range(start_round, total_round):
-            self.round = round
-            self.log_utils.log_console("Starting round {}".format(round))
-
-            round_stats = self.single_round()
-            stats.append(round_stats)
+        selected_ids = list(selected_collabs) + [node_id]
 
-        stats_dict = from_round_stats_per_round_per_client_to_dict_arrays(stats)
-        stats_dict["round_step"] = 1
-        self.log_utils.log_experiments_stats(stats_dict)
-        self.plot_utils.plot_experiments_stats(stats_dict)
+        print("Selected collabs: " + str(node_id) + " " + str(selected_ids))
 
-
+        return selected_ids
diff --git a/src/scheduler.py b/src/scheduler.py
index 4dcff0e..79fcb24 100644
--- a/src/scheduler.py
+++ b/src/scheduler.py
@@ -3,13 +3,10 @@
 from algos.base_class import BaseNode
 from algos.fl import FedAvgClient, FedAvgServer
 from algos.isolated import IsolatedServer
-from algos.fl_random import FedRanClient, FedRanServer
-from algos.fl_grid import FedGridClient, FedGridServer
-from algos.fl_torus import FedTorusClient, FedTorusServer
 from algos.fl_assigned import FedAssClient, FedAssServer
 from algos.fl_isolated import FedIsoClient, FedIsoServer
 from algos.fl_weight import FedWeightClient, FedWeightServer
-from algos.fl_ring import FedRingClient, FedRingServer
+from algos.fl_static import FedStaticClient, FedStaticServer
 from algos.swarm import SWARMClient, SWARMServer
 from algos.DisPFL import DisPFLClient, DisPFLServer
 from algos.def_kt import DefKTClient,DefKTServer
@@ -28,13 +25,10 @@
 algo_map = {
     "fedavg": [FedAvgServer, FedAvgClient],
     "isolated": [IsolatedServer],
-    "fedran": [FedRanServer,FedRanClient],
-    "fedgrid": [FedGridServer,FedGridClient],
-    "fedtorus": [FedTorusServer,FedTorusClient],
     "fedass": [FedAssServer, FedAssClient],
    "fediso": [FedIsoServer,FedIsoClient],
     "fedweight": [FedWeightServer,FedWeightClient],
-    "fedring": [FedRingServer,FedRingClient],
+    "fedstatic": [FedStaticServer,FedStaticClient],
    "swarm" : [SWARMServer, SWARMClient],
     "dispfl": [DisPFLServer, DisPFLClient],
     "defkt": [DefKTServer,DefKTClient],