diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 9094b7f..12fa9f8 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -174,12 +174,17 @@ def setup_cuda(self, config: Dict[str, ConfigType]) -> None: self.device_ids = device_ids_map[node_name] gpu_id = self.device_ids[0] - if torch.cuda.is_available(): + if isinstance(gpu_id, int) and torch.cuda.is_available(): self.device = torch.device(f"cuda:{gpu_id}") print(f"Using GPU: cuda:{gpu_id}") - else: + elif gpu_id == "cpu": self.device = torch.device("cpu") print("Using CPU") + else: + # Fallback in case of no GPU availability + self.device = torch.device("cpu") + print("Using CPU (Fallback)") + def set_model_parameters(self, config: Dict[str, Any]) -> None: # Model related parameters diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 4148e56..59d7baa 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -20,12 +20,6 @@ def __init__( super().__init__(config, comm_utils) self.topology = select_topology(config, self.node_id) self.topology.initialize() - - def get_representation(self, **kwargs: Any) -> OrderedDict[str, torch.Tensor]: - """ - Returns the model weights as representation. - """ - return self.get_model_weights() def get_neighbors(self) -> List[int]: """ diff --git a/src/configs/malicious_config.py b/src/configs/malicious_config.py index e411f58..15025d4 100644 --- a/src/configs/malicious_config.py +++ b/src/configs/malicious_config.py @@ -45,6 +45,13 @@ "corrupt_severity": 1, } +# Label Flip Attack +label_flip: ConfigType = { + "malicious_type": "label_flip", + "permute_labels": 10, + # "permutation": random.shuffle([i for i in range(10)]), +} + # List of Malicious node configurations malicious_config_list: Dict[str, ConfigType] = { "sign_flip": sign_flip, diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 7041813..dc2abd6 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -27,12 +27,12 @@ } -def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]: +def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]: """ Get the GPU device IDs for the users. """ # TODO: Make it multi-host - device_ids: Dict[str, List[int]] = {} + device_ids: Dict[str, List[int | Literal["cpu"]]] = {} for i in range(num_users + 1): # +1 for the super-node index = i % len(gpus_available) gpu_id = gpus_available[index] @@ -327,6 +327,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dropout_correlation": 0.0, # correlation between dropouts of successive rounds: [0,1] } + dropout_dicts = {"node_0": {}} for i in range(1, num_users + 1): dropout_dicts[f"node_{i}"] = dropout_dict @@ -347,12 +348,14 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore "samples_per_user": 50000 // num_users, # distributed equally - "train_label_distribution": "iid", + "train_label_distribution": "non_iid", "test_label_distribution": "iid", + "alpha_data": 1.0, "exp_keys": [], "dropout_dicts": dropout_dicts, - "log_memory": True, + "test_samples_per_user": 200, } + current_config = grpc_system_config # current_config = mpi_system_config diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index 19828b5..fbcabfd 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -258,6 +258,16 @@ def train_classification_malicious( # Perform backpropagation with modified loss loss.backward() + elif self.malicious_type == "label_flip": + # permutation = torch.tensor(self.config.get("permutation", [i for i in range(10)])) + permute_labels = self.config.get("permute_labels", 10) + permutation = torch.randperm(permute_labels) + permutation = permutation.to(target.device) + + target = permutation[target] # flipped targets + loss = loss_fn(output, target) + loss.backward() + else: loss = loss_fn(output, target) loss.backward()