Skip to content

Commit

Permalink
Label Flipping Attack + Run on CPU (#139)
Browse files Browse the repository at this point in the history
* added support for label flipping attack, and caught a bug in fedstatic

* reverting back unecessary changes
  • Loading branch information
joyce-yuan authored Nov 5, 2024
1 parent 59b5bc6 commit 1bfe422
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 12 deletions.
9 changes: 7 additions & 2 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,17 @@ def setup_cuda(self, config: Dict[str, ConfigType]) -> None:
self.device_ids = device_ids_map[node_name]
gpu_id = self.device_ids[0]

if torch.cuda.is_available():
if isinstance(gpu_id, int) and torch.cuda.is_available():
self.device = torch.device(f"cuda:{gpu_id}")
print(f"Using GPU: cuda:{gpu_id}")
else:
elif gpu_id == "cpu":
self.device = torch.device("cpu")
print("Using CPU")
else:
# Fallback in case of no GPU availability
self.device = torch.device("cpu")
print("Using CPU (Fallback)")


def set_model_parameters(self, config: Dict[str, Any]) -> None:
# Model related parameters
Expand Down
6 changes: 0 additions & 6 deletions src/algos/fl_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ def __init__(
super().__init__(config, comm_utils)
self.topology = select_topology(config, self.node_id)
self.topology.initialize()

def get_representation(self, **kwargs: Any) -> OrderedDict[str, torch.Tensor]:
"""
Returns the model weights as representation.
"""
return self.get_model_weights()

def get_neighbors(self) -> List[int]:
"""
Expand Down
7 changes: 7 additions & 0 deletions src/configs/malicious_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@
"corrupt_severity": 1,
}

# Label Flip Attack
label_flip: ConfigType = {
"malicious_type": "label_flip",
"permute_labels": 10,
# "permutation": random.shuffle([i for i in range(10)]),
}

# List of Malicious node configurations
malicious_config_list: Dict[str, ConfigType] = {
"sign_flip": sign_flip,
Expand Down
11 changes: 7 additions & 4 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
}


def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]:
def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]:
"""
Get the GPU device IDs for the users.
"""
# TODO: Make it multi-host
device_ids: Dict[str, List[int]] = {}
device_ids: Dict[str, List[int | Literal["cpu"]]] = {}
for i in range(num_users + 1): # +1 for the super-node
index = i % len(gpus_available)
gpu_id = gpus_available[index]
Expand Down Expand Up @@ -327,6 +327,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"dropout_correlation": 0.0, # correlation between dropouts of successive rounds: [0,1]
}


dropout_dicts = {"node_0": {}}
for i in range(1, num_users + 1):
dropout_dicts[f"node_{i}"] = dropout_dict
Expand All @@ -347,12 +348,14 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
"samples_per_user": 50000 // num_users, # distributed equally
"train_label_distribution": "iid",
"train_label_distribution": "non_iid",
"test_label_distribution": "iid",
"alpha_data": 1.0,
"exp_keys": [],
"dropout_dicts": dropout_dicts,
"log_memory": True,
"test_samples_per_user": 200,
}


current_config = grpc_system_config
# current_config = mpi_system_config
10 changes: 10 additions & 0 deletions src/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,16 @@ def train_classification_malicious(

# Perform backpropagation with modified loss
loss.backward()
elif self.malicious_type == "label_flip":
# permutation = torch.tensor(self.config.get("permutation", [i for i in range(10)]))
permute_labels = self.config.get("permute_labels", 10)
permutation = torch.randperm(permute_labels)
permutation = permutation.to(target.device)

target = permutation[target] # flipped targets
loss = loss_fn(output, target)
loss.backward()

else:
loss = loss_fn(output, target)
loss.backward()
Expand Down

0 comments on commit 1bfe422

Please sign in to comment.