Skip to content

Commit

Permalink
add different algo for each client
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamjajoo committed Sep 16, 2024
1 parent 76d7767 commit 582e509
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 38 deletions.
18 changes: 18 additions & 0 deletions src/configs/algo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,21 @@ def assign_colab(clients):

# Configuration selected for the current run.
current_config: ConfigType = feddatarepr

# Pool of algorithm configurations that can be handed out to individual nodes.
# Keep the order stable: a seeded random pick over this list depends on the
# position of each entry.
algo_config_list: List[ConfigType] = [
    iid_dispfl_clients_new,
    traditional_fl,
    fedweight,
    defkt,
    fedavg_object_detect,
    fediso,
    L2C,
    fedcentral,
    fedval,
    swarm,
    fedstatic,
    metaL2C_cifar10,
    fedass,
    feddatarepr,
]
30 changes: 30 additions & 0 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# is to simulate different real-world scenarios without changing the algorithm configuration.
from typing import TypeAlias, Dict, List, Union, Tuple, Optional
# from utils.config_utils import get_sliding_window_support, get_device_ids
from .algo_config import algo_config_list
import random

ConfigType: TypeAlias = Dict[str, Union[
str,
Expand Down Expand Up @@ -39,6 +41,30 @@ def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[
device_ids[f"node_{i}"] = [gpu_id]
return device_ids

def get_malicious_types(num_users: int, malicious_types: List[str], num_malicious: int) -> Dict[str, str]:
    """
    Assign whether a node is malicious and, if so, the type of malicious node it is.

    Nodes are keyed ``node_0`` .. ``node_{num_users}``. ``node_0`` (the super-node)
    is never selected as malicious. ``num_malicious`` distinct nodes are drawn from
    ``node_1`` .. ``node_{num_users}``, and each one receives a type drawn with
    replacement from ``malicious_types``. All remaining nodes are labelled
    ``"normal"``.

    Args:
        num_users: Number of user nodes (the super-node is not counted).
        malicious_types: Candidate malicious behaviours to draw from.
        num_malicious: How many user nodes should be marked malicious.

    Returns:
        Mapping from node name to either ``"normal"`` or one of ``malicious_types``.

    Raises:
        ValueError: If ``num_malicious`` is negative or larger than the number of
            user nodes available.
        IndexError: If ``num_malicious`` is positive but ``malicious_types`` is empty
            (raised by ``random.choices``).
    """
    candidates = range(1, num_users + 1)
    # Same condition random.sample enforces, but with an actionable message.
    if not 0 <= num_malicious <= len(candidates):
        raise ValueError(
            f"num_malicious must be between 0 and {len(candidates)}, got {num_malicious}"
        )

    # Draw order (sample first, then choices) is kept so seeded runs are unchanged.
    malicious_nodes = random.sample(candidates, num_malicious)
    malicious_type_assignments = random.choices(malicious_types, k=num_malicious)

    # One index -> type mapping gives O(1) lookups instead of scanning the
    # sampled list (membership test + .index) for every node.
    type_by_node = dict(zip(malicious_nodes, malicious_type_assignments))

    # +1 for the super-node
    return {
        f"node_{i}": type_by_node.get(i, "normal")
        for i in range(num_users + 1)
    }

def get_algo_configs(num_users: int, algo_configs: "List[ConfigType]") -> "Dict[str, ConfigType]":
    """
    Randomly assign an algorithm configuration to each node, allowing for repetition.

    One entry is produced for every node ``node_0`` .. ``node_{num_users}``; each
    value is an independent ``random.choice`` over ``algo_configs``, so the same
    configuration object may be assigned to several nodes.

    Args:
        num_users: Number of user nodes (the super-node is not counted).
        algo_configs: Candidate algorithm configurations to draw from.

    Returns:
        Mapping from node name to the configuration chosen for that node.

    Raises:
        IndexError: If ``algo_configs`` is empty (raised by ``random.choice``).
    """
    # Annotations are quoted so they describe the config objects actually passed
    # in without being evaluated at definition time.
    # +1 for the super-node
    return {
        f"node_{i}": random.choice(algo_configs)
        for i in range(num_users + 1)
    }

def get_domain_support(num_users: int, base: str, domains: List[int]|List[str]) -> Dict[str, str]:
assert num_users % len(domains) == 0

Expand Down Expand Up @@ -77,6 +103,8 @@ def get_camelyon17_support(num_users: int, domains: List[int]=CAMELYON17_DMN):
def get_digit_five_support(num_users:int, domains:List[str]=DIGIT_FIVE):
    """
    Build the domain support mapping for the Digit-Five domains.

    Delegates to ``get_domain_support`` with an empty ``base`` string and the
    given ``domains`` (defaulting to ``DIGIT_FIVE``), returning its result as-is.
    """
    support = get_domain_support(num_users=num_users, base="", domains=domains)
    return support

# Malicious behaviours that can be assigned to a node; this list is passed to
# get_malicious_types as the pool of types to draw from.
malicious_types = ["label_flipping", "zero_weights"]

digit_five_dpath = {
"mnist": "./imgs/mnist",
"usps": "./imgs/usps",
Expand All @@ -99,6 +127,8 @@ def get_digit_five_support(num_users:int, domains:List[str]=DIGIT_FIVE):
# The device_ids dictionary depicts the GPUs on which the nodes reside.
# For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing):
"device_ids": {"node_0": [0], "node_1": [0],"node_2": [0], "node_3": [0]},
"algo": get_algo_configs(num_users=3, algo_configs=algo_config_list),
"malicious_types": get_malicious_types(num_users=3, malicious_types = malicious_types, num_malicious=2),
"samples_per_user": 1000, #TODO: To model scenarios where different users have different number of samples
# we need to make this a dictionary with user_id as key and number of samples as value
"train_label_distribution": "iid", # Either "iid", "non_iid" "support"
Expand Down
36 changes: 8 additions & 28 deletions src/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from utils.communication.comm_utils import CommunicationManager
from utils.config_utils import load_config, process_config
from utils.log_utils import copy_source_code, check_and_create_path
from utils.node_map import NodeMap

# Mapping of algorithm names to their corresponding client and server classes so that they can be consumed by the scheduler later on.
algo_map = {
Expand Down Expand Up @@ -62,17 +61,11 @@ def get_node(config: Dict[str, Any], rank: int, comm_utils: CommunicationManager
algo_name = config["algo"]
node = algo_map[algo_name][rank > 0](config, comm_utils)

node_map = NodeMap()
malicious_type = node_map.get_malicious_type(node.node_id)
node_map.add_node(node.node_id, malicious_type)

return node

class Scheduler():
""" Manages the overall orchestration of experiments
"""
node_map = NodeMap()

def __init__(self) -> None:
pass

Expand All @@ -86,40 +79,28 @@ def assign_config_by_path(self, sys_config_path: str, algo_config_path: str, is_
else:
self.sys_config["comm"]["host"] = host
self.sys_config["comm"]["rank"] = None
self.algo_config = load_config(algo_config_path)
self.merge_configs()

def merge_configs(self) -> None:
self.config = {}
self.config.update(self.sys_config)

def merge_configs(self) -> None:
self.config.update(self.sys_config)
self.config.update(self.algo_config)

def malicious_simulation(self):
num_clients = self.config.get("num_users", 0)
num_malicious_clients = self.config.get("num_malicious_clients", 0)

if num_malicious_clients > num_clients:
raise ValueError("Number of malicious clients cannot exceed the total number of clients.")

possible_node_ids = list(range(1, num_clients))
malicious_clients = random.sample(possible_node_ids, num_malicious_clients)

node_map = NodeMap()
for node_id in range(num_clients):
malicious_type = 1 if node_id in malicious_clients else 0
node_map.add_node(node_id, malicious_type)

def initialize(self, copy_souce_code: bool=True) -> None:
assert self.config is not None, "Config should be set when initializing"
self.communication = CommunicationManager(self.config)

self.config["comm"]["rank"] = self.communication.get_rank()
# Base clients modify the seed later on
seed = self.config["seed"]
torch.manual_seed(seed) # type: ignore
random.seed(seed)
numpy.random.seed(seed)

node_name = "node_{}".format(self.communication.get_rank())

self.algo_config = self.sys_config["algo"][node_name]
self.merge_configs()

if self.communication.get_rank() == 0:
if copy_souce_code:
copy_source_code(self.config)
Expand All @@ -129,7 +110,6 @@ def initialize(self, copy_souce_code: bool=True) -> None:
os.mkdir(self.config["saved_models"])
os.mkdir(self.config["log_path"])

self.malicious_simulation()
self.node = get_node(self.config, rank=self.communication.get_rank(), comm_utils=self.communication)

def run_job(self) -> None:
Expand Down
18 changes: 8 additions & 10 deletions src/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,19 @@ def process_config(config: Dict[str, Any]) -> Dict[str, Any]:
else:
dset = config["dset"]

experiment_name = "{}_{}users_{}spc_{}_{}".format(
experiment_name = "{}_{}users_".format(
dset,
config["num_users"],
config["samples_per_user"],
config["algo"],
config["exp_id"],
)

for exp_key in config["exp_keys"]:
item = jmespath.search(exp_key, config)
assert item is not None
key = exp_key.split(".")[-1]
assert key is not None
# experiment_name += "_{}_{}".format(key, item)
experiment_name += "_{}".format(item)
# for exp_key in config["exp_keys"]:
# item = jmespath.search(exp_key, config)
# assert item is not None
# key = exp_key.split(".")[-1]
# assert key is not None
# # experiment_name += "_{}_{}".format(key, item)
# experiment_name += "_{}".format(item)

experiments_folder = config["dump_dir"]
results_path = experiments_folder + experiment_name + f"_seed{config['seed']}"
Expand Down

0 comments on commit 582e509

Please sign in to comment.