From f0192d48df0019b6711f2cfe5b60e09e0d5d2466 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sat, 19 Oct 2024 10:15:11 -0400 Subject: [PATCH 01/19] added MPI Communication class --- src/utils/communication/mpi.py | 84 ++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 8f78f154..0bc7a9d6 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,6 +1,9 @@ from typing import Dict, Any, List from mpi4py import MPI from utils.communication.interface import CommunicationInterface +import threading +import time +from enum import Enum class MPICommUtils(CommunicationInterface): @@ -12,11 +15,11 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): def initialize(self): pass - def send(self, dest: str | int, data: Any): - self.comm.send(data, dest=int(dest)) + # def send(self, dest: str | int, data: Any): + # self.comm.send(data, dest=int(dest)) - def receive(self, node_ids: str | int) -> Any: - return self.comm.recv(source=int(node_ids)) + # def receive(self, node_ids: str | int) -> Any: + # return self.comm.recv(source=int(node_ids)) def broadcast(self, data: Any): for i in range(1, self.size): @@ -34,3 +37,76 @@ def all_gather(self): def finalize(self): pass + + +class MPICommunication(MPICommUtils): + def __init__(self, config: Dict[str, Dict[str, Any]]): + super().__init__(config) + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + self.send_event = threading.Event() + self.request_source: int | None = None + + def listener(self): + while True: + status = MPI.Status() + if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status): + source = status.Get_source() + tag = status.Get_tag() + count = status.Get_count(MPI.BYTE) # Get the number of bytes in the message + # If a message is available, receive it + data_to_recv = bytearray(count) + req = self.comm.irecv([data_to_recv, MPI.BYTE], source=source, tag=tag) + req.wait() + # Convert the byte array back to a string + received_message = data_to_recv.decode('utf-8') + + if received_message == "Requesting Information": + self.send_event.set() + + self.send_event.clear() + break + time.sleep(1) # Simulate waiting time + + def send(self, dest: str | int, data: Any, tag: int): + while True: + # Wait until the listener thread detects a request + self.send_event.wait() + req = self.comm.isend(data, dest=int(dest), tag=tag) + req.wait() + + def receive(self, node_ids: str | int, tag: int) -> Any: + node_ids = int(node_ids) + message = "Requesting Information" + message_bytes = bytearray(message, 'utf-8') + send_req = self.comm.isend([message_bytes, MPI.BYTE], dest=node_ids, tag=tag) + send_req.wait() + recv_req = self.comm.irecv(source=node_ids, tag=tag) + return recv_req.wait() + +# MPI Server +""" +initialization(): + node spins up listener thread, threading (an extra thread might not be needed since iprobe exists). + call listen? + +listen(): + listener thread starts listening for send requests (use iprobe and irecv for message) + when send request is received, call the send() function + +send(): + gather and send info to requesting node using comm.isend + comm.wait + +""" + +# MPI Client +""" +initialization(): + node is initialized + +receive(): + node sends request to sending node using isend() + node calls irecv and waits for response +""" + From 755fc073f76713f9a0ef6cecd62fa21b764444a7 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 21 Oct 2024 21:17:52 -0400 Subject: [PATCH 02/19] added send thread, merged 2 classes --- src/utils/communication/mpi.py | 144 +++++++++++++++------------------ 1 file changed, 65 insertions(+), 79 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 0bc7a9d6..89bd21cd 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -3,110 +3,96 @@ from utils.communication.interface import CommunicationInterface import threading import time -from enum import Enum - class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]]): + def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() - def initialize(self): - pass - - # def send(self, dest: str | int, data: Any): - # self.comm.send(data, dest=int(dest)) - - # def receive(self, node_ids: str | int) -> Any: - # return self.comm.recv(source=int(node_ids)) - - def broadcast(self, data: Any): - for i in range(1, self.size): - if i != self.rank: - self.send(i, data) - - def all_gather(self): - """ - This function is used to gather data from all the nodes. - """ - items: List[Any] = [] - for i in range(1, self.size): - items.append(self.receive(i)) - return items - - def finalize(self): - pass - - -class MPICommunication(MPICommUtils): - def __init__(self, config: Dict[str, Dict[str, Any]]): - super().__init__(config) + # Ensure that we are using thread safe threading level + self.required_threading_level = MPI.THREAD_MULTIPLE + self.threading_level = MPI.Query_thread() + # Make sure to check for MPI_THREAD_MULTIPLE threading level to support + # thread safe calls to send and recv + if self.required_threading_level > self.threading_level: + raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") + listener_thread = threading.Thread(target=self.listener, daemon=True) listener_thread.start() + send_thread = threading.Thread(target=self.send, args=(data)) + send_thread.start() + self.send_event = threading.Event() + # Ensures that the listener thread and send thread are not using self.request_source at the same time + self.source_node_lock = threading.Lock() self.request_source: int | None = None + def initialize(self): + pass + def listener(self): + """ + Runs on listener thread on each node to receive a send request + Once send request is received, the listener thread informs the main + thread to send the data to the requesting node. + """ while True: status = MPI.Status() - if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status): - source = status.Get_source() - tag = status.Get_tag() - count = status.Get_count(MPI.BYTE) # Get the number of bytes in the message - # If a message is available, receive it - data_to_recv = bytearray(count) - req = self.comm.irecv([data_to_recv, MPI.BYTE], source=source, tag=tag) - req.wait() - # Convert the byte array back to a string - received_message = data_to_recv.decode('utf-8') - - if received_message == "Requesting Information": - self.send_event.set() + # look for message with tag 1 (represents send request) + if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): + with self.source_node_lock: + self.request_source = status.Get_source() - self.send_event.clear() - break + self.comm.irecv(source=self.request_source, tag=1) + self.send_event.set() time.sleep(1) # Simulate waiting time - def send(self, dest: str | int, data: Any, tag: int): + def send(self, data: Any): + """ + Node will wait until request is received and then send + data to requesting node. + """ while True: # Wait until the listener thread detects a request self.send_event.wait() - req = self.comm.isend(data, dest=int(dest), tag=tag) - req.wait() + with self.source_node_lock: + dest = self.request_source - def receive(self, node_ids: str | int, tag: int) -> Any: + if dest is not None: + req = self.comm.isend(data, dest=int(dest)) + req.wait() + + with self.source_node_lock: + self.request_source = None + + self.send_event.clear() + + def receive(self, node_ids: str | int) -> Any: + """ + Node will send a request and wait to receive data. + """ node_ids = int(node_ids) - message = "Requesting Information" - message_bytes = bytearray(message, 'utf-8') - send_req = self.comm.isend([message_bytes, MPI.BYTE], dest=node_ids, tag=tag) + send_req = self.comm.isend("", dest=node_ids, tag=1) send_req.wait() - recv_req = self.comm.irecv(source=node_ids, tag=tag) + recv_req = self.comm.irecv(source=node_ids) return recv_req.wait() - -# MPI Server -""" -initialization(): - node spins up listener thread, threading (an extra thread might not be needed since iprobe exists). - call listen? - -listen(): - listener thread starts listening for send requests (use iprobe and irecv for message) - when send request is received, call the send() function - -send(): - gather and send info to requesting node using comm.isend - comm.wait -""" + # depreciated broadcast function + # def broadcast(self, data: Any): + # for i in range(1, self.size): + # if i != self.rank: + # self.send(i, data) -# MPI Client -""" -initialization(): - node is initialized + def all_gather(self): + """ + This function is used to gather data from all the nodes. + """ + items: List[Any] = [] + for i in range(1, self.size): + items.append(self.receive(i)) + return items -receive(): - node sends request to sending node using isend() - node calls irecv and waits for response -""" + def finalize(self): + pass From d37c35bd40c6817476a076265a9044e1a957131a Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 22 Oct 2024 10:30:21 -0400 Subject: [PATCH 03/19] improved comments --- src/utils/communication/mpi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 89bd21cd..3ea3e334 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -34,7 +34,7 @@ def initialize(self): def listener(self): """ Runs on listener thread on each node to receive a send request - Once send request is received, the listener thread informs the main + Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ while True: @@ -50,7 +50,7 @@ def listener(self): def send(self, data: Any): """ - Node will wait until request is received and then send + Node will wait for a request to send data and then send the data to requesting node. """ while True: @@ -70,7 +70,7 @@ def send(self, data: Any): def receive(self, node_ids: str | int) -> Any: """ - Node will send a request and wait to receive data. + Node will send a request for data and wait to receive data. """ node_ids = int(node_ids) send_req = self.comm.isend("", dest=node_ids, tag=1) @@ -78,7 +78,7 @@ def receive(self, node_ids: str | int) -> Any: recv_req = self.comm.irecv(source=node_ids) return recv_req.wait() - # depreciated broadcast function + # deprecated broadcast function # def broadcast(self, data: Any): # for i in range(1, self.size): # if i != self.rank: From 2f087e155129e1a868904f4df3f28ac6a2570e24 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 28 Oct 2024 17:16:10 -0400 Subject: [PATCH 04/19] testing mpi, model weights not acquired --- src/algos/fl.py | 2 +- src/configs/algo_config.py | 3 +- src/configs/sys_config.py | 16 ++--- src/main.py | 1 - src/scheduler.py | 2 - src/utils/communication/comm_utils.py | 4 +- src/utils/communication/mpi.py | 94 ++++++++++++++++++++------- 7 files changed, 85 insertions(+), 37 deletions(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index db805490..98a09a47 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -96,7 +96,7 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): num_users = len(model_wts) coeff = 1 / num_users avgd_wts: OrderedDict[str, Tensor] = OrderedDict() - + print(f"model weights: {model_wts}") for key in model_wts[0].keys(): avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 9a4aa764..60662d18 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -328,4 +328,5 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st ] -default_config_list: List[ConfigType] = [traditional_fl] +# default_config_list: List[ConfigType] = [traditional_fl] +default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic] diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 44ae73a0..fb88171f 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -145,11 +145,13 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIAR10_DPATH = "./datasets/imgs/cifar10/" NUM_COLLABORATORS = 1 -DUMP_DIR = "/mas/camera/Experiments/SONAR/abhi/" +DUMP_DIR = "/Users/kathryn/MIT/UROP/Media Lab/sonar_experiments/" +num_users = 4 mpi_system_config: ConfigType = { "exp_id": "", "comm": {"type": "MPI"}, + "num_users": num_users, "num_collaborators": NUM_COLLABORATORS, "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, @@ -159,14 +161,12 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( - num_users=3, - algo_configs=malicious_algo_config_list, - assignment_method="distribution", - distribution={0: 1, 1: 1, 2: 1}, + num_users=4, + algo_configs=default_config_list ), # type: ignore "samples_per_user": 1000, # TODO: To model scenarios where different users have different number of samples # we need to make this a dictionary with user_id as key and number of samples as value @@ -342,5 +342,5 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dropout_dicts": dropout_dicts, } -current_config = grpc_system_config -# current_config = mpi_system_config +# current_config = grpc_system_config +current_config = mpi_system_config diff --git a/src/main.py b/src/main.py index 655ac65f..d3a7c11d 100644 --- a/src/main.py +++ b/src/main.py @@ -66,6 +66,5 @@ scheduler.install_config() scheduler.initialize() - # Run the job scheduler.run_job() diff --git a/src/scheduler.py b/src/scheduler.py index 0aec0945..b1d0f7d6 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -104,7 +104,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: random.seed(seed) numpy.random.seed(seed) self.merge_configs() - if self.communication.get_rank() == 0: if copy_souce_code: copy_source_code(self.config) @@ -120,7 +119,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: # from a different machine print("Waiting for 10 seconds for the super node to create directories") time.sleep(10) - self.node = get_node( self.config, rank=self.communication.get_rank(), diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index fc4bc5df..788b0709 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -1,7 +1,7 @@ from enum import Enum from utils.communication.grpc.main import GRPCCommunication from typing import Any, Dict, List, TYPE_CHECKING -# from utils.communication.mpi import MPICommUtils +from utils.communication.mpi import MPICommUtils if TYPE_CHECKING: from algos.base_class import BaseNode @@ -20,7 +20,7 @@ def create_communication( ): comm_type = comm_type if comm_type == CommunicationType.MPI: - raise NotImplementedError("MPI's new version not yet implemented. Please use GRPC. See https://github.com/aidecentralized/sonar/issues/96 for more details.") + return MPICommUtils(config) elif comm_type == CommunicationType.GRPC: return GRPCCommunication(config) elif comm_type == CommunicationType.HTTP: diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 3ea3e334..70a96bf0 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,11 +1,16 @@ -from typing import Dict, Any, List +from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI from utils.communication.interface import CommunicationInterface import threading import time +from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model +import random + +if TYPE_CHECKING: + from algos.base_class import BaseNode class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): + def __init__(self, config: Dict[str, Dict[str, Any]]): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() @@ -18,19 +23,32 @@ def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): if self.required_threading_level > self.threading_level: raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() - send_thread = threading.Thread(target=self.send, args=(data)) - send_thread.start() - self.send_event = threading.Event() # Ensures that the listener thread and send thread are not using self.request_source at the same time - self.source_node_lock = threading.Lock() + self.lock = threading.Lock() self.request_source: int | None = None + self.is_working = True + self.communication_cost_received: int = 0 + self.communication_cost_sent: int = 0 + + self.base_node: BaseNode | None = None + + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + def initialize(self): pass + def register_self(self, obj: "BaseNode"): + self.base_node = obj + send_thread = threading.Thread(target=self.send) + send_thread.start() + + def get_comm_cost(self): + with self.lock: + return self.communication_cost_received, self.communication_cost_sent + def listener(self): """ Runs on listener thread on each node to receive a send request @@ -41,14 +59,28 @@ def listener(self): status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): - with self.source_node_lock: + with self.lock: self.request_source = status.Get_source() self.comm.irecv(source=self.request_source, tag=1) self.send_event.set() time.sleep(1) # Simulate waiting time - def send(self, data: Any): + def get_model(self) -> bytes | None: + print(f"getting model from {self.rank}, {self.base_node}") + if not self.base_node: + raise Exception("Base node not registered") + with self.lock: + if self.is_working: + print("model is working") + model = serialize_model(self.base_node.get_model_weights()) + print(f"model data to be sent: {model}") + else: + assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." + model = None + return model + + def send(self): """ Node will wait for a request to send data and then send the data to requesting node. @@ -56,33 +88,46 @@ def send(self, data: Any): while True: # Wait until the listener thread detects a request self.send_event.wait() - with self.source_node_lock: + with self.lock: dest = self.request_source if dest is not None: + data = self.get_model() req = self.comm.isend(data, dest=int(dest)) req.wait() - with self.source_node_lock: + with self.lock: self.request_source = None self.send_event.clear() - def receive(self, node_ids: str | int) -> Any: + def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ - node_ids = int(node_ids) - send_req = self.comm.isend("", dest=node_ids, tag=1) - send_req.wait() - recv_req = self.comm.irecv(source=node_ids) - return recv_req.wait() + max_tries = 10 + for node in node_ids: + while max_tries > 0: + try: + self.comm.send("", dest=node, tag=1) + recv_req = self.comm.irecv(source=node) + received_data = recv_req.wait() + print(f"received data: {received_data}") + return deserialize_model(received_data) + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(traceback.print_exc()) + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 # deprecated broadcast function - # def broadcast(self, data: Any): - # for i in range(1, self.size): - # if i != self.rank: - # self.send(i, data) + def broadcast(self, data: Any): + for i in range(1, self.size): + if i != self.rank: + self.comm.send(data, dest=i) def all_gather(self): """ @@ -90,9 +135,14 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): + print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items def finalize(self): pass + def set_is_working(self, is_working: bool): + with self.lock: + self.is_working = is_working + From 464a6748c28e8bd0e2fca6b7a66dea55bfc3b47e Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 3 Nov 2024 14:12:09 -0500 Subject: [PATCH 05/19] mpi works, occassional deadlock issue --- src/utils/communication/comm_utils.py | 1 + src/utils/communication/mpi.py | 144 ++++++++++++++++++++------ 2 files changed, 113 insertions(+), 32 deletions(-) diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index 788b0709..622da42e 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -2,6 +2,7 @@ from utils.communication.grpc.main import GRPCCommunication from typing import Any, Dict, List, TYPE_CHECKING from utils.communication.mpi import MPICommUtils +from mpi4py import MPI if TYPE_CHECKING: from algos.base_class import BaseNode diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 70a96bf0..ec2ec4a4 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,10 +1,12 @@ +from collections import OrderedDict from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI +from torch import Tensor from utils.communication.interface import CommunicationInterface import threading import time -from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model import random +import numpy as np if TYPE_CHECKING: from algos.base_class import BaseNode @@ -15,6 +17,9 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() + self.num_users: int = int(config["num_users"]) # type: ignore + self.finished = False + # Ensure that we are using thread safe threading level self.required_threading_level = MPI.THREAD_MULTIPLE self.threading_level = MPI.Query_thread() @@ -34,16 +39,17 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.base_node: BaseNode | None = None - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() + self.listener_thread = threading.Thread(target=self.listener) + self.listener_thread.start() + + self.send_thread = threading.Thread(target=self.send) def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - send_thread = threading.Thread(target=self.send) - send_thread.start() + self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -55,26 +61,30 @@ def listener(self): Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ - while True: + while not self.finished: status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: self.request_source = status.Get_source() - self.comm.irecv(source=self.request_source, tag=1) + print(f"Node {self.rank} received request from {self.request_source}") + # receive_request = self.comm.irecv(source=self.request_source, tag=1) + # receive_request.wait() + self.comm.recv(source=self.request_source, tag=1) self.send_event.set() - time.sleep(1) # Simulate waiting time + # time.sleep(1) + print(f"Node {self.rank} listener thread ended") - def get_model(self) -> bytes | None: + def get_model(self) -> List[OrderedDict[str, Tensor]] | None: print(f"getting model from {self.rank}, {self.base_node}") if not self.base_node: raise Exception("Base node not registered") with self.lock: if self.is_working: - print("model is working") - model = serialize_model(self.base_node.get_model_weights()) - print(f"model data to be sent: {model}") + model = self.base_node.get_model_weights() + model = [model] + print(f"Model from {self.rank} acquired") else: assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." model = None @@ -85,43 +95,62 @@ def send(self): Node will wait for a request to send data and then send the data to requesting node. """ - while True: + while not self.finished: # Wait until the listener thread detects a request self.send_event.wait() + if self.finished: + break with self.lock: dest = self.request_source if dest is not None: data = self.get_model() - req = self.comm.isend(data, dest=int(dest)) - req.wait() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) with self.lock: self.request_source = None self.send_event.clear() + print(f"Node {self.rank} send thread ended") def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ max_tries = 10 - for node in node_ids: - while max_tries > 0: - try: - self.comm.send("", dest=node, tag=1) - recv_req = self.comm.irecv(source=node) - received_data = recv_req.wait() - print(f"received data: {received_data}") - return deserialize_model(received_data) - except Exception as e: - print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") - import traceback - print(traceback.print_exc()) - # sleep for a random time between 1 and 10 seconds - random_time = random.randint(1, 10) - time.sleep(random_time) - max_tries -= 1 + assert len(node_ids) == 1, "Too many node_ids to unpack" + node = node_ids[0] + while max_tries > 0: + try: + print(f"Node {self.rank} receiving from {node}") + self.comm.send("", dest=node, tag=1) + # recv_req = self.comm.Irecv([], source=node) + # received_data = recv_req.wait() + received_data = self.comm.recv(source=node) + print(f"Node {self.rank} received data from {node}: {bool(received_data)}") + if not received_data: + raise Exception("Received empty data") + return received_data + except MPI.Exception as e: + print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + print(f"Node {self.rank} received") # deprecated broadcast function def broadcast(self, data: Any): @@ -138,9 +167,60 @@ def all_gather(self): print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items + + def send_finished(self): + self.comm.send("Finished", dest=0, tag=2) def finalize(self): - pass + # 1. All nodes send finished to the super node + # 2. super node will wait for all nodes to send finished + # 3. super node will then send bye to all nodes + # 4. all nodes will wait for the bye and then exit + # this is to ensure that all nodes have finished + # and no one leaves early + if self.rank == 0: + quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished + num_finished: set[int] = set() + status = MPI.Status() + while len(num_finished) < quorum_threshold: + # sleep for 5 seconds + print( + f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" + ) + # time.sleep(5) + # get finished nodes + self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) + print(f"received finish message from {status.Get_source()}") + num_finished.add(status.Get_source()) + + else: + # send finished to the super node + print(f"Node {self.rank} sent finish message") + self.send_finished() + + # problem: do the other nodes wait for super node to receive finish messages? + message = self.comm.bcast("Done", root=0) + self.finished = True + self.send_event.set() + print(f"Node {self.rank} received {message}, finished") + self.comm.Barrier() + self.listener_thread.join() + print(f"Node {self.rank} listener thread done") + if self.send_thread.is_alive(): + self.send_thread.join() + print(f"Node {self.rank} send thread done") + print(f"Node {self.rank} active threads: {threading.active_count()}") + print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") + print(f"Node {self.rank} {threading.enumerate()}") + # for thread in threading.enumerate(): + # if thread != threading.main_thread(): + # thread.join() + print(f"Node {self.rank} send thread is {self.send_thread.is_alive()}") + self.comm.Barrier() + print(f"Node {self.rank}: all nodes synchronized") + MPI.Finalize() + + print("Finalized") def set_is_working(self, is_working: bool): with self.lock: From 71dd9e86e3f87d0dcff40eddb62c08a9880ad613 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 5 Nov 2024 22:46:26 -0500 Subject: [PATCH 06/19] merged send and listener threads --- src/utils/communication/mpi.py | 54 ++++++++-------------------------- 1 file changed, 13 insertions(+), 41 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index ec2ec4a4..026d6451 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -42,14 +42,11 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.listener_thread = threading.Thread(target=self.listener) self.listener_thread.start() - self.send_thread = threading.Thread(target=self.send) - def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -66,14 +63,14 @@ def listener(self): # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: - self.request_source = status.Get_source() + # self.request_source = status.Get_source() + dest = status.Get_source() print(f"Node {self.rank} received request from {self.request_source}") # receive_request = self.comm.irecv(source=self.request_source, tag=1) # receive_request.wait() - self.comm.recv(source=self.request_source, tag=1) - self.send_event.set() - # time.sleep(1) + self.comm.recv(source=dest, tag=1) + self.send(dest) print(f"Node {self.rank} listener thread ended") def get_model(self) -> List[OrderedDict[str, Tensor]] | None: @@ -90,31 +87,19 @@ def get_model(self) -> List[OrderedDict[str, Tensor]] | None: model = None return model - def send(self): + def send(self, dest: int): """ Node will wait for a request to send data and then send the data to requesting node. """ - while not self.finished: - # Wait until the listener thread detects a request - self.send_event.wait() - if self.finished: - break - with self.lock: - dest = self.request_source - - if dest is not None: - data = self.get_model() - print(f"Node {self.rank} is sending data to {dest}") - # req = self.comm.Isend(data, dest=int(dest)) - # req.wait() - self.comm.send(data, dest=int(dest)) - - with self.lock: - self.request_source = None - - self.send_event.clear() - print(f"Node {self.rank} send thread ended") + if self.finished: + return + + data = self.get_model() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) def receive(self, node_ids: List[int]) -> Any: """ @@ -183,11 +168,9 @@ def finalize(self): num_finished: set[int] = set() status = MPI.Status() while len(num_finished) < quorum_threshold: - # sleep for 5 seconds print( f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" ) - # time.sleep(5) # get finished nodes self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) print(f"received finish message from {status.Get_source()}") @@ -198,7 +181,6 @@ def finalize(self): print(f"Node {self.rank} sent finish message") self.send_finished() - # problem: do the other nodes wait for super node to receive finish messages? message = self.comm.bcast("Done", root=0) self.finished = True self.send_event.set() @@ -206,22 +188,12 @@ def finalize(self): self.comm.Barrier() self.listener_thread.join() print(f"Node {self.rank} listener thread done") - if self.send_thread.is_alive(): - self.send_thread.join() - print(f"Node {self.rank} send thread done") - print(f"Node {self.rank} active threads: {threading.active_count()}") print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") print(f"Node {self.rank} {threading.enumerate()}") - # for thread in threading.enumerate(): - # if thread != threading.main_thread(): - # thread.join() - print(f"Node {self.rank} send thread is {self.send_thread.is_alive()}") self.comm.Barrier() print(f"Node {self.rank}: all nodes synchronized") MPI.Finalize() - print("Finalized") - def set_is_working(self, is_working: bool): with self.lock: self.is_working = is_working From dc1fb851850e67b97802ed98f418bdaf837a2924 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Fri, 15 Nov 2024 13:12:34 -0500 Subject: [PATCH 07/19] first draft of test --- .github/workflows/train.yml | 48 +++++++++++++++++++++++++++++++++++++ src/configs/algo_config.py | 2 +- src/configs/sys_config.py | 16 +++++++------ 3 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/train.yml diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml new file mode 100644 index 00000000..b4ec3fc4 --- /dev/null +++ b/.github/workflows/train.yml @@ -0,0 +1,48 @@ +name: Test Training Code with gRPC + +on: + push: + branches: + # - main + - "*" + pull_request: + branches: + - main + +jobs: + train-check: + runs-on: ubuntu-latest + + steps: + # Step 1: Checkout the code + - name: Checkout repository + uses: actions/checkout@v3 + + # Step 2: Set up Python + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" # Specify the Python version you're using + + # Step 3: Install dependencies + - name: Install dependencies + run: pip install -r requirements.txt + + # Step 4: Run gRPC server and client + - name: Run test + run: | + # Start the gRPC server in the background + python /src/main.py -super true & + SERVER_PID=$! + + # Run the gRPC client to test communication + python main_grpc.py -n 5 -host localhost + + # Clean up the server process + kill $SERVER_PID + + # further checks: + # only 5 rounds + # gRPC only? or also MPI? + # num of samples + # num users and nodes diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index b31928fb..557e8186 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -192,7 +192,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st # Collaboration setup "algo": "fedstatic", "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore - "rounds": 20, + "rounds": 5, # Model parameters "model": "resnet10", diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index b69c4e19..2824cdd3 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -160,7 +160,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): NUM_COLLABORATORS = 1 DUMP_DIR = "/Users/kathryn/MIT/UROP/Media Lab/sonar_experiments/" -num_users = 4 +num_users = 3 mpi_system_config: ConfigType = { "exp_id": "", "comm": {"type": "MPI"}, @@ -169,19 +169,20 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, - "seed": 32, + # "seed": 32, + "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( - num_users=4, + num_users=3, algo_configs=default_config_list ), # type: ignore - "samples_per_user": 1000, # TODO: To model scenarios where different users have different number of samples + "samples_per_user": 5555, # TODO: To model scenarios where different users have different number of samples # we need to make this a dictionary with user_id as key and number of samples as value "train_label_distribution": "iid", # Either "iid", "non_iid" "support" "test_label_distribution": "iid", # Either "iid", "non_iid" "support" @@ -316,7 +317,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "exp_keys": [], } -num_users = 9 +num_users = 4 dropout_dict = { "distribution_dict": { # leave dict empty to disable dropout @@ -347,7 +348,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "device_ids": get_device_ids(num_users, gpu_ids), # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore - "samples_per_user": 50000 // num_users, # distributed equally + # "samples_per_user": 50000 // num_users, # distributed equally + "samples_per_user": 100, "train_label_distribution": "non_iid", "test_label_distribution": "iid", "alpha_data": 1.0, From 9c7f1b70394c9c6a659509841aa5419207a31bdd Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Fri, 15 Nov 2024 13:14:12 -0500 Subject: [PATCH 08/19] using python3.10 --- .github/workflows/train.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index b4ec3fc4..93e5ff3f 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.8" # Specify the Python version you're using + python-version: "3.10" # Specify the Python version you're using # Step 3: Install dependencies - name: Install dependencies From a818cbef96ff14fdd3e1d42edd3faf372f79f686 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 18 Nov 2024 11:19:45 -0500 Subject: [PATCH 09/19] made testing sys and algo configs --- .github/workflows/train.yml | 10 ++- src/configs/algo_config_test.py | 15 ++++ src/configs/sys_config_test.py | 125 ++++++++++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 src/configs/algo_config_test.py create mode 100644 src/configs/sys_config_test.py diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 93e5ff3f..19a08bdf 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -31,12 +31,18 @@ jobs: # Step 4: Run gRPC server and client - name: Run test run: | + mkdir ./sonar_experiments/ + # Start the gRPC server in the background - python /src/main.py -super true & + python /src/main.py -super true -b & "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" SERVER_PID=$! # Run the gRPC client to test communication - python main_grpc.py -n 5 -host localhost + python main_grpc.py -n 5 -host localhost -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" + + - name: Clean up + run: | + rm -rf ./sonar_experiments/ # Clean up the server process kill $SERVER_PID diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py new file mode 100644 index 00000000..e6693bc0 --- /dev/null +++ b/src/configs/algo_config_test.py @@ -0,0 +1,15 @@ +from utils.types import ConfigType + +fedstatic: ConfigType = { + # Collaboration setup + "algo": "fedstatic", + "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore + "rounds": 5, + + # Model parameters + "model": "resnet10", + "model_lr": 3e-4, + "batch_size": 256, +} + +# default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic] \ No newline at end of file diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py new file mode 100644 index 00000000..fc0e9461 --- /dev/null +++ b/src/configs/sys_config_test.py @@ -0,0 +1,125 @@ +from typing import Dict, List, Literal, Optional +import random +from utils.types import ConfigType + +from .algo_config_test import ( + fedstatic +) + +def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]: + """ + Get the GPU device IDs for the users. + """ + # TODO: Make it multi-host + device_ids: Dict[str, List[int | Literal["cpu"]]] = {} + for i in range(num_users + 1): # +1 for the super-node + index = i % len(gpus_available) + gpu_id = gpus_available[index] + device_ids[f"node_{i}"] = [gpu_id] + return device_ids + + +def get_algo_configs( + num_users: int, + algo_configs: List[ConfigType], + assignment_method: Literal[ + "sequential", "random", "mapping", "distribution" + ] = "sequential", + seed: Optional[int] = 1, + mapping: Optional[List[int]] = None, + distribution: Optional[Dict[int, int]] = None, +) -> Dict[str, ConfigType]: + """ + Assign an algorithm configuration to each node, allowing for repetition. + sequential: Assigns the algo_configs sequentially to the nodes + random: Assigns the algo_configs randomly to the nodes + mapping: Assigns the algo_configs based on the mapping of node index to algo index provided + distribution: Assigns the algo_configs based on the distribution of algo index to number of nodes provided + """ + algo_config_map: Dict[str, ConfigType] = {} + algo_config_map["node_0"] = algo_configs[0] # Super-node + if assignment_method == "sequential": + for i in range(1, num_users + 1): + algo_config_map[f"node_{i}"] = algo_configs[i % len(algo_configs)] + elif assignment_method == "random": + for i in range(1, num_users + 1): + algo_config_map[f"node_{i}"] = random.choice(algo_configs) + elif assignment_method == "mapping": + if not mapping: + raise ValueError("Mapping must be provided for assignment method 'mapping'") + assert len(mapping) == num_users + for i in range(1, num_users + 1): + algo_config_map[f"node_{i}"] = algo_configs[mapping[i - 1]] + elif assignment_method == "distribution": + if not distribution: + raise ValueError( + "Distribution must be provided for assignment method 'distribution'" + ) + total_users = sum(distribution.values()) + assert total_users == num_users + + # List of node indices to assign + node_indices = list(range(1, total_users + 1)) + # Seed for reproducibility + random.seed(seed) + # Shuffle the node indices based on the seed + random.shuffle(node_indices) + + # Assign nodes based on the shuffled indices + current_index = 0 + for algo_index, num_nodes in distribution.items(): + for i in range(num_nodes): + node_id = node_indices[current_index] + algo_config_map[f"node_{node_id}"] = algo_configs[algo_index] + current_index += 1 + else: + raise ValueError(f"Invalid assignment method: {assignment_method}") + # print("algo config mapping is: ", algo_config_map) + return algo_config_map + +CIFAR10_DSET = "cifar10" +CIAR10_DPATH = "./datasets/imgs/cifar10/" + +DUMP_DIR = "./sonar_experiments/" + +NUM_COLLABORATORS = 1 +num_users = 4 + +dropout_dict = { + "distribution_dict": { # leave dict empty to disable dropout + "method": "uniform", # "uniform", "normal" + "parameters": {} # "mean": 0.5, "std": 0.1 in case of normal distribution + }, + "dropout_rate": 0.0, # cutoff for dropout: [0,1] + "dropout_correlation": 0.0, # correlation between dropouts of successive rounds: [0,1] +} + +dropout_dicts = {"node_0": {}} +for i in range(1, num_users + 1): + dropout_dicts[f"node_{i}"] = dropout_dict + +gpu_ids = [2, 3, 5, 6] + +grpc_system_config: ConfigType = { + "exp_id": "static", + "num_users": num_users, + "num_collaborators": NUM_COLLABORATORS, + "comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]}, # The super-node + "dset": CIFAR10_DSET, + "dump_dir": DUMP_DIR, + "dpath": CIAR10_DPATH, + "seed": 2, + "device_ids": get_device_ids(num_users, gpu_ids), + # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore + "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore + # "samples_per_user": 50000 // num_users, # distributed equally + "samples_per_user": 100, + "train_label_distribution": "non_iid", + "test_label_distribution": "iid", + "alpha_data": 1.0, + "exp_keys": [], + "dropout_dicts": dropout_dicts, + "test_samples_per_user": 200, +} + +current_config = grpc_system_config \ No newline at end of file From 3c140863ce36703889c9631e58ace2f7eeda5cc5 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 18 Nov 2024 17:44:25 -0500 Subject: [PATCH 10/19] testing workflow --- .github/workflows/train.yml | 34 +++++++++++++++++++++++---------- src/configs/algo_config_test.py | 2 ++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 19a08bdf..0c3139f4 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -1,6 +1,7 @@ name: Test Training Code with gRPC on: + workflow_dispatch: push: branches: # - main @@ -9,6 +10,9 @@ on: branches: - main +env: + ACTIONS_STEP_DEBUG: true + jobs: train-check: runs-on: ubuntu-latest @@ -26,27 +30,37 @@ jobs: # Step 3: Install dependencies - name: Install dependencies - run: pip install -r requirements.txt + run: | + sudo apt update + sudo apt install -y libopenmpi-dev openmpi-bin + sudo apt-get install -y libgl1 libglib2.0-0 + + pip install -r requirements.txt # Step 4: Run gRPC server and client - name: Run test run: | - mkdir ./sonar_experiments/ + cd src + if [ -d "sonar_experiments " ]; then + echo "Directory exists. Removing..." + rm -rf sonar_experiments + fi + mkdir -p sonar_experiments - # Start the gRPC server in the background - python /src/main.py -super true -b & "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" - SERVER_PID=$! + mkdir -p ./sonar_experiments/ + chmod +x ./configs/algo_config_test.py - # Run the gRPC client to test communication - python main_grpc.py -n 5 -host localhost -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" + echo "starting main grpc" + python main_grpc.py -n 4 -host localhost + echo "starting main" + # python main.py -super true -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" + python main.py -super true -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" + echo "done" - name: Clean up run: | rm -rf ./sonar_experiments/ - # Clean up the server process - kill $SERVER_PID - # further checks: # only 5 rounds # gRPC only? or also MPI? diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index e6693bc0..eeae5d66 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + from utils.types import ConfigType fedstatic: ConfigType = { From 33e668a2430b20520da9fd1528db5efc5266c1af Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 24 Nov 2024 20:55:06 -0500 Subject: [PATCH 11/19] predict next move ish --- .vscode/settings.json | 11 ++++- src/scheduler.py | 2 + src/utils/communication/comm_utils.py | 3 ++ src/utils/communication/grpc/comm.proto | 5 +++ src/utils/communication/grpc/comm_pb2.py | 44 ++++++++++--------- src/utils/communication/grpc/comm_pb2_grpc.py | 43 ++++++++++++++++++ src/utils/communication/grpc/main.py | 28 +++++++++++- src/utils/communication/interface.py | 4 ++ 8 files changed, 116 insertions(+), 24 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 9e6483a3..745eb685 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,10 @@ { - "python.analysis.typeCheckingMode": "strict" -} \ No newline at end of file + "python.analysis.typeCheckingMode": "strict", + "sshfs.configs": [ + { + "name": "matlaber", + "host": "matlaber7.media.mit.edu", + "username": "kle" + } + ] +} diff --git a/src/scheduler.py b/src/scheduler.py index 23cc3271..8072419b 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -129,6 +129,8 @@ def initialize(self, copy_souce_code: bool = True) -> None: rank=self.communication.get_rank(), comm_utils=self.communication, ) + print("sending quorum now") + self.communication.send_quorum() def run_job(self) -> None: self.node.run_protocol() diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index e43a8b4b..fff02d86 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -71,6 +71,9 @@ def receive(self, node_ids: List[int]) -> Any: def broadcast(self, data: Any, tag: int = 0): self.comm.broadcast(data) + def send_quorum(self): + self.comm.send_quorum() + def all_gather(self, tag: int = 0): return self.comm.all_gather() diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto index 8f689c36..c69ade51 100644 --- a/src/utils/communication/grpc/comm.proto +++ b/src/utils/communication/grpc/comm.proto @@ -3,6 +3,7 @@ syntax = "proto3"; service CommunicationServer { + rpc send_status(Empty) returns (Status) {} rpc send_data (Data) returns (Empty) {} rpc send_model (Model) returns (Empty) {} rpc get_rank (Empty) returns (Rank) {} @@ -16,6 +17,10 @@ service CommunicationServer { message Empty {} +message Status{ + string message = 1; +} + message Model { bytes buffer = 1; } diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py index a9b03cea..43938394 100644 --- a/src/utils/communication/grpc/comm_pb2.py +++ b/src/utils/communication/grpc/comm_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xc1\x02\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x19\n\x06Status\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xe3\x02\n\x13\x43ommunicationServer\x12 \n\x0bsend_status\x12\x06.Empty\x1a\x07.Status\"\x00\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,24 +25,26 @@ _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' _globals['_EMPTY']._serialized_start=14 _globals['_EMPTY']._serialized_end=21 - _globals['_MODEL']._serialized_start=23 - _globals['_MODEL']._serialized_end=46 - _globals['_DATA']._serialized_start=48 - _globals['_DATA']._serialized_end=89 - _globals['_RANK']._serialized_start=91 - _globals['_RANK']._serialized_end=111 - _globals['_ROUND']._serialized_start=113 - _globals['_ROUND']._serialized_end=135 - _globals['_PORT']._serialized_start=137 - _globals['_PORT']._serialized_end=157 - _globals['_PEERID']._serialized_start=159 - _globals['_PEERID']._serialized_end=221 - _globals['_PEERIDS']._serialized_start=223 - _globals['_PEERIDS']._serialized_end=330 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=275 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=330 - _globals['_QUORUM']._serialized_start=332 - _globals['_QUORUM']._serialized_end=356 - _globals['_COMMUNICATIONSERVER']._serialized_start=359 - _globals['_COMMUNICATIONSERVER']._serialized_end=680 + _globals['_STATUS']._serialized_start=23 + _globals['_STATUS']._serialized_end=48 + _globals['_MODEL']._serialized_start=50 + _globals['_MODEL']._serialized_end=73 + _globals['_DATA']._serialized_start=75 + _globals['_DATA']._serialized_end=116 + _globals['_RANK']._serialized_start=118 + _globals['_RANK']._serialized_end=138 + _globals['_ROUND']._serialized_start=140 + _globals['_ROUND']._serialized_end=162 + _globals['_PORT']._serialized_start=164 + _globals['_PORT']._serialized_end=184 + _globals['_PEERID']._serialized_start=186 + _globals['_PEERID']._serialized_end=248 + _globals['_PEERIDS']._serialized_start=250 + _globals['_PEERIDS']._serialized_end=357 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=302 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=357 + _globals['_QUORUM']._serialized_start=359 + _globals['_QUORUM']._serialized_end=383 + _globals['_COMMUNICATIONSERVER']._serialized_start=386 + _globals['_COMMUNICATIONSERVER']._serialized_end=741 # @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py index e0258f72..ea45534d 100644 --- a/src/utils/communication/grpc/comm_pb2_grpc.py +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -39,6 +39,11 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.send_status = channel.unary_unary( + '/CommunicationServer/send_status', + request_serializer=comm__pb2.Empty.SerializeToString, + response_deserializer=comm__pb2.Status.FromString, + _registered_method=True) self.send_data = channel.unary_unary( '/CommunicationServer/send_data', request_serializer=comm__pb2.Data.SerializeToString, @@ -89,6 +94,12 @@ def __init__(self, channel): class CommunicationServerServicer(object): """Missing associated documentation comment in .proto file.""" + def send_status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def send_data(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -146,6 +157,11 @@ def send_finished(self, request, context): def add_CommunicationServerServicer_to_server(servicer, server): rpc_method_handlers = { + 'send_status': grpc.unary_unary_rpc_method_handler( + servicer.send_status, + request_deserializer=comm__pb2.Empty.FromString, + response_serializer=comm__pb2.Status.SerializeToString, + ), 'send_data': grpc.unary_unary_rpc_method_handler( servicer.send_data, request_deserializer=comm__pb2.Data.FromString, @@ -202,6 +218,33 @@ def add_CommunicationServerServicer_to_server(servicer, server): class CommunicationServer(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def send_status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_status', + comm__pb2.Empty.SerializeToString, + comm__pb2.Status.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def send_data(request, target, diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index 7fb687ac..1d4f6c5a 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -175,6 +175,9 @@ def update_port( self.peer_ids[request.rank.rank]["ip"] = request.ip # type: ignore self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore return comm_pb2.Empty() # type: ignore + + def send_status(self, request, context) -> comm_pb2.Status: + return comm_pb2.Status(message="Ready") # type: ignore def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # type: ignore """ @@ -343,7 +346,31 @@ def initialize(self): peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids) ) stub.send_peer_ids(proto_msg) # type: ignore + # stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + + def send_quorum(self): + if self.rank == 0: + for peer_id in self.servicer.peer_ids: + host_ip = self.servicer.peer_ids[peer_id].get("ip") + if peer_id != self.rank: + port = self.servicer.peer_ids[peer_id].get("port") + address = f"{host_ip}:{port}" + print(f"Sending peer_ids to {address}") + with grpc.insecure_channel(address) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore + stub.send_status(comm_pb2.Empty()) # type: ignore + print("status message sent") + else: + for peer_id in self.servicer.peer_ids: + host_ip = self.servicer.peer_ids[peer_id].get("ip") + if peer_id != self.rank: + port = self.servicer.peer_ids[peer_id].get("port") + address = f"{host_ip}:{port}" + with grpc.insecure_channel(address) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore + # status = stub.send_status(comm_pb2.Empty()) # type: ignore stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + print("quorum sent") def get_host_from_rank(self, rank: int) -> str: for peer_id in self.servicer.peer_ids: @@ -370,7 +397,6 @@ def send_with_retries(self, dest_host: str, buffer: Any) -> Any: raise Exception("Failed to send data. Receiver unreachable.") - def send(self, dest: str | int, data: OrderedDict[str, Any]): """ data should be a python dictionary diff --git a/src/utils/communication/interface.py b/src/utils/communication/interface.py index b1daac27..8df3c067 100644 --- a/src/utils/communication/interface.py +++ b/src/utils/communication/interface.py @@ -18,6 +18,10 @@ def send(self, dest: str | int, data: Any): def receive(self, node_ids: List[int]) -> Any: pass + @abstractmethod + def send_quorum(self) -> Any: + pass + @abstractmethod def broadcast(self, data: Any): pass From 6cea0a9e47d06196bb85b514e1bb4b03ae6fd9e6 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 26 Nov 2024 14:15:53 -0500 Subject: [PATCH 12/19] moved quorum send --- .github/workflows/train.yml | 41 ++++++++++++++++------------ src/configs/algo_config_test.py | 4 +-- src/configs/sys_config_test.py | 2 +- src/scheduler.py | 3 +- src/utils/communication/grpc/main.py | 28 ++++--------------- src/utils/communication/mpi.py | 4 +++ 6 files changed, 37 insertions(+), 45 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 0c3139f4..c6b0ce29 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -22,6 +22,20 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 + # - name: check directories + # run: | + # cd src + # DIR="../../../../../../../home" + # if [ -d "$DIR" ]; then + # ### Take action if $DIR exists ### + # echo "Installing config files in ${DIR}" + # exit 1 + # else + # ### Control will jump here if $DIR does NOT exists ### + # echo "Error: ${DIR} not found. Can not continue." + # exit 1 + # fi + # Step 2: Set up Python - name: Set up Python uses: actions/setup-python@v4 @@ -41,28 +55,19 @@ jobs: - name: Run test run: | cd src - if [ -d "sonar_experiments " ]; then - echo "Directory exists. Removing..." - rm -rf sonar_experiments - fi - mkdir -p sonar_experiments - - mkdir -p ./sonar_experiments/ - chmod +x ./configs/algo_config_test.py + # chmod +x ./configs/algo_config_test.py echo "starting main grpc" python main_grpc.py -n 4 -host localhost echo "starting main" - # python main.py -super true -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" - python main.py -super true -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" + python main.py -super true -s "./configs/sys_config_test.py" echo "done" - - name: Clean up - run: | - rm -rf ./sonar_experiments/ + # - name: Clean up + # run: | - # further checks: - # only 5 rounds - # gRPC only? or also MPI? - # num of samples - # num users and nodes + # further checks: + # only 5 rounds + # gRPC only? or also MPI? + # num of samples + # num users and nodes diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index eeae5d66..a64c012d 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,12 +1,10 @@ -#!/usr/bin/env python3 - from utils.types import ConfigType fedstatic: ConfigType = { # Collaboration setup "algo": "fedstatic", "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore - "rounds": 5, + "rounds": 1, # Model parameters "model": "resnet10", diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index fc0e9461..ba1ace73 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -80,7 +80,7 @@ def get_algo_configs( CIFAR10_DSET = "cifar10" CIAR10_DPATH = "./datasets/imgs/cifar10/" -DUMP_DIR = "./sonar_experiments/" +DUMP_DIR = "../../../../../../../home/" NUM_COLLABORATORS = 1 num_users = 4 diff --git a/src/scheduler.py b/src/scheduler.py index 8072419b..14f488bd 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -109,6 +109,7 @@ def initialize(self, copy_souce_code: bool = True) -> None: numpy.random.seed(seed) self.merge_configs() if self.communication.get_rank() == 0: + print("initializing super node") if copy_souce_code: copy_source_code(self.config) else: @@ -129,7 +130,7 @@ def initialize(self, copy_souce_code: bool = True) -> None: rank=self.communication.get_rank(), comm_utils=self.communication, ) - print("sending quorum now") + print(f"Node {self.communication.get_rank()} finished get ndoe") self.communication.send_quorum() def run_job(self) -> None: diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index 1d4f6c5a..b850f0e3 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -175,9 +175,6 @@ def update_port( self.peer_ids[request.rank.rank]["ip"] = request.ip # type: ignore self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore return comm_pb2.Empty() # type: ignore - - def send_status(self, request, context) -> comm_pb2.Status: - return comm_pb2.Status(message="Ready") # type: ignore def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # type: ignore """ @@ -349,28 +346,15 @@ def initialize(self): # stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore def send_quorum(self): + """ Send the quorum status to all nodes after peer IDs are sent. """ if self.rank == 0: for peer_id in self.servicer.peer_ids: - host_ip = self.servicer.peer_ids[peer_id].get("ip") - if peer_id != self.rank: - port = self.servicer.peer_ids[peer_id].get("port") - address = f"{host_ip}:{port}" - print(f"Sending peer_ids to {address}") - with grpc.insecure_channel(address) as channel: # type: ignore - stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore - stub.send_status(comm_pb2.Empty()) # type: ignore - print("status message sent") - else: - for peer_id in self.servicer.peer_ids: - host_ip = self.servicer.peer_ids[peer_id].get("ip") - if peer_id != self.rank: - port = self.servicer.peer_ids[peer_id].get("port") - address = f"{host_ip}:{port}" - with grpc.insecure_channel(address) as channel: # type: ignore - stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore - # status = stub.send_status(comm_pb2.Empty()) # type: ignore + if not self.is_own_id(peer_id): + host = self.get_host_from_rank(peer_id) + with grpc.insecure_channel(host) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore - print("quorum sent") + print(f"Quorum status sent to all nodes.") def get_host_from_rank(self, rank: int) -> str: for peer_id in self.servicer.peer_ids: diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 026d6451..e9b20042 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -45,6 +45,10 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): def initialize(self): pass + def send_quorum(self) -> Any: + # return super().send_quorum(node_ids) + pass + def register_self(self, obj: "BaseNode"): self.base_node = obj From 23c3252e5d4bb63fb021aecdad7fec8e5c936b41 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 26 Nov 2024 14:18:26 -0500 Subject: [PATCH 13/19] moved quorum send --- src/configs/sys_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 2824cdd3..aa9c6854 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -158,7 +158,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIAR10_DPATH = "./datasets/imgs/cifar10/" NUM_COLLABORATORS = 1 -DUMP_DIR = "/Users/kathryn/MIT/UROP/Media Lab/sonar_experiments/" +DUMP_DIR = "../../../../../../../home/" num_users = 3 mpi_system_config: ConfigType = { From 3e93738babe61e50fe5d2d169b3a9849c16baa31 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:23:03 -0500 Subject: [PATCH 14/19] using traditional fl algo --- src/algos/fl.py | 1 - src/configs/algo_config.py | 2 +- src/configs/algo_config_test.py | 17 ++++++++++++++--- src/configs/sys_config_test.py | 4 ++-- src/scheduler.py | 2 -- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 8a115f6e..cecdb2a8 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -104,7 +104,6 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): num_users = len(model_wts) coeff = 1 / num_users avgd_wts: OrderedDict[str, Tensor] = OrderedDict() - print(f"model weights: {model_wts}") for key in model_wts[0].keys(): avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 557e8186..1b859b38 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -31,7 +31,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st traditional_fl: ConfigType = { # Collaboration setup "algo": "fedavg", - "rounds": 2, + "rounds": 1, # Model parameters "model": "resnet10", diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index a64c012d..2f2c7fcf 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,9 +1,20 @@ from utils.types import ConfigType -fedstatic: ConfigType = { +# fedstatic: ConfigType = { +# # Collaboration setup +# "algo": "fedstatic", +# "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore +# "rounds": 1, + +# # Model parameters +# "model": "resnet10", +# "model_lr": 3e-4, +# "batch_size": 256, +# } + +traditional_fl: ConfigType = { # Collaboration setup - "algo": "fedstatic", - "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore + "algo": "fedavg", "rounds": 1, # Model parameters diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index ba1ace73..6616328b 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -3,7 +3,7 @@ from utils.types import ConfigType from .algo_config_test import ( - fedstatic + traditional_fl ) def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]: @@ -111,7 +111,7 @@ def get_algo_configs( "seed": 2, "device_ids": get_device_ids(num_users, gpu_ids), # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore - "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore + "algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore # "samples_per_user": 50000 // num_users, # distributed equally "samples_per_user": 100, "train_label_distribution": "non_iid", diff --git a/src/scheduler.py b/src/scheduler.py index 14f488bd..55da449d 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -109,7 +109,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: numpy.random.seed(seed) self.merge_configs() if self.communication.get_rank() == 0: - print("initializing super node") if copy_souce_code: copy_source_code(self.config) else: @@ -130,7 +129,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: rank=self.communication.get_rank(), comm_utils=self.communication, ) - print(f"Node {self.communication.get_rank()} finished get ndoe") self.communication.send_quorum() def run_job(self) -> None: From 66f3c440e634a9b7d689c4fd26e93716b0dacd4b Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:25:41 -0500 Subject: [PATCH 15/19] run test only during push to main --- .github/workflows/train.yml | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index c6b0ce29..a9e67ecb 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -4,8 +4,8 @@ on: workflow_dispatch: push: branches: - # - main - - "*" + - main + # - "*" pull_request: branches: - main @@ -22,20 +22,6 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - # - name: check directories - # run: | - # cd src - # DIR="../../../../../../../home" - # if [ -d "$DIR" ]; then - # ### Take action if $DIR exists ### - # echo "Installing config files in ${DIR}" - # exit 1 - # else - # ### Control will jump here if $DIR does NOT exists ### - # echo "Error: ${DIR} not found. Can not continue." - # exit 1 - # fi - # Step 2: Set up Python - name: Set up Python uses: actions/setup-python@v4 @@ -63,9 +49,6 @@ jobs: python main.py -super true -s "./configs/sys_config_test.py" echo "done" - # - name: Clean up - # run: | - # further checks: # only 5 rounds # gRPC only? or also MPI? From c765013b7fa381789d08bb9f7afd2a11f939ce13 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:30:31 -0500 Subject: [PATCH 16/19] new dump_dir --- .github/workflows/train.yml | 4 ++-- src/configs/sys_config.py | 3 ++- src/configs/sys_config_test.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index a9e67ecb..3d7e86fc 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -4,8 +4,8 @@ on: workflow_dispatch: push: branches: - - main - # - "*" + # - main + - "*" pull_request: branches: - main diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index aa9c6854..571b0970 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -158,7 +158,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIAR10_DPATH = "./datasets/imgs/cifar10/" NUM_COLLABORATORS = 1 -DUMP_DIR = "../../../../../../../home/" +# DUMP_DIR = "../../../../../../../home/" +DUMP_DIR = "./" num_users = 3 mpi_system_config: ConfigType = { diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index 6616328b..9917e68f 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -80,7 +80,8 @@ def get_algo_configs( CIFAR10_DSET = "cifar10" CIAR10_DPATH = "./datasets/imgs/cifar10/" -DUMP_DIR = "../../../../../../../home/" +# DUMP_DIR = "../../../../../../../home/" +DUMP_DIR = "./" NUM_COLLABORATORS = 1 num_users = 4 From 9e3ea96ddd87f0f7e8cfa44c9633db45c71d6460 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Fri, 29 Nov 2024 20:15:49 -0500 Subject: [PATCH 17/19] remove send_status from proto --- src/utils/communication/grpc/comm.proto | 5 --- src/utils/communication/grpc/comm_pb2.py | 44 +++++++++---------- src/utils/communication/grpc/comm_pb2_grpc.py | 43 ------------------ 3 files changed, 21 insertions(+), 71 deletions(-) diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto index c69ade51..8f689c36 100644 --- a/src/utils/communication/grpc/comm.proto +++ b/src/utils/communication/grpc/comm.proto @@ -3,7 +3,6 @@ syntax = "proto3"; service CommunicationServer { - rpc send_status(Empty) returns (Status) {} rpc send_data (Data) returns (Empty) {} rpc send_model (Model) returns (Empty) {} rpc get_rank (Empty) returns (Rank) {} @@ -17,10 +16,6 @@ service CommunicationServer { message Empty {} -message Status{ - string message = 1; -} - message Model { bytes buffer = 1; } diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py index 43938394..a9b03cea 100644 --- a/src/utils/communication/grpc/comm_pb2.py +++ b/src/utils/communication/grpc/comm_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x19\n\x06Status\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xe3\x02\n\x13\x43ommunicationServer\x12 \n\x0bsend_status\x12\x06.Empty\x1a\x07.Status\"\x00\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xc1\x02\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,26 +25,24 @@ _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' _globals['_EMPTY']._serialized_start=14 _globals['_EMPTY']._serialized_end=21 - _globals['_STATUS']._serialized_start=23 - _globals['_STATUS']._serialized_end=48 - _globals['_MODEL']._serialized_start=50 - _globals['_MODEL']._serialized_end=73 - _globals['_DATA']._serialized_start=75 - _globals['_DATA']._serialized_end=116 - _globals['_RANK']._serialized_start=118 - _globals['_RANK']._serialized_end=138 - _globals['_ROUND']._serialized_start=140 - _globals['_ROUND']._serialized_end=162 - _globals['_PORT']._serialized_start=164 - _globals['_PORT']._serialized_end=184 - _globals['_PEERID']._serialized_start=186 - _globals['_PEERID']._serialized_end=248 - _globals['_PEERIDS']._serialized_start=250 - _globals['_PEERIDS']._serialized_end=357 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=302 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=357 - _globals['_QUORUM']._serialized_start=359 - _globals['_QUORUM']._serialized_end=383 - _globals['_COMMUNICATIONSERVER']._serialized_start=386 - _globals['_COMMUNICATIONSERVER']._serialized_end=741 + _globals['_MODEL']._serialized_start=23 + _globals['_MODEL']._serialized_end=46 + _globals['_DATA']._serialized_start=48 + _globals['_DATA']._serialized_end=89 + _globals['_RANK']._serialized_start=91 + _globals['_RANK']._serialized_end=111 + _globals['_ROUND']._serialized_start=113 + _globals['_ROUND']._serialized_end=135 + _globals['_PORT']._serialized_start=137 + _globals['_PORT']._serialized_end=157 + _globals['_PEERID']._serialized_start=159 + _globals['_PEERID']._serialized_end=221 + _globals['_PEERIDS']._serialized_start=223 + _globals['_PEERIDS']._serialized_end=330 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=275 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=330 + _globals['_QUORUM']._serialized_start=332 + _globals['_QUORUM']._serialized_end=356 + _globals['_COMMUNICATIONSERVER']._serialized_start=359 + _globals['_COMMUNICATIONSERVER']._serialized_end=680 # @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py index ea45534d..e0258f72 100644 --- a/src/utils/communication/grpc/comm_pb2_grpc.py +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -39,11 +39,6 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.send_status = channel.unary_unary( - '/CommunicationServer/send_status', - request_serializer=comm__pb2.Empty.SerializeToString, - response_deserializer=comm__pb2.Status.FromString, - _registered_method=True) self.send_data = channel.unary_unary( '/CommunicationServer/send_data', request_serializer=comm__pb2.Data.SerializeToString, @@ -94,12 +89,6 @@ def __init__(self, channel): class CommunicationServerServicer(object): """Missing associated documentation comment in .proto file.""" - def send_status(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def send_data(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -157,11 +146,6 @@ def send_finished(self, request, context): def add_CommunicationServerServicer_to_server(servicer, server): rpc_method_handlers = { - 'send_status': grpc.unary_unary_rpc_method_handler( - servicer.send_status, - request_deserializer=comm__pb2.Empty.FromString, - response_serializer=comm__pb2.Status.SerializeToString, - ), 'send_data': grpc.unary_unary_rpc_method_handler( servicer.send_data, request_deserializer=comm__pb2.Data.FromString, @@ -218,33 +202,6 @@ def add_CommunicationServerServicer_to_server(servicer, server): class CommunicationServer(object): """Missing associated documentation comment in .proto file.""" - @staticmethod - def send_status(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/CommunicationServer/send_status', - comm__pb2.Empty.SerializeToString, - comm__pb2.Status.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - @staticmethod def send_data(request, target, From 3d28db047e17312742de12a2ff9b2aa0fe2e27b9 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 11:03:33 -0500 Subject: [PATCH 18/19] changed dump_dir --- .vscode/settings.json | 9 +-------- src/configs/sys_config.py | 4 ++-- src/configs/sys_config_test.py | 2 +- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 745eb685..d6e26387 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,10 +1,3 @@ { - "python.analysis.typeCheckingMode": "strict", - "sshfs.configs": [ - { - "name": "matlaber", - "host": "matlaber7.media.mit.edu", - "username": "kle" - } - ] + "python.analysis.typeCheckingMode": "strict" } diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 71c952b5..c0b40314 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -159,7 +159,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): NUM_COLLABORATORS = 1 # DUMP_DIR = "../../../../../../../home/" -DUMP_DIR = "./" +DUMP_DIR = "/tmp/" num_users = 3 mpi_system_config: ConfigType = { @@ -170,7 +170,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, - # "seed": 32, + "seed": 32, "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index 9917e68f..f3575419 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -81,7 +81,7 @@ def get_algo_configs( CIAR10_DPATH = "./datasets/imgs/cifar10/" # DUMP_DIR = "../../../../../../../home/" -DUMP_DIR = "./" +DUMP_DIR = "/tmp/" NUM_COLLABORATORS = 1 num_users = 4 From ddbdbb1a3c19829426a330cdf151c395d144f79f Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 11:06:40 -0500 Subject: [PATCH 19/19] small changes --- src/configs/sys_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index c0b40314..2e7e0438 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -171,7 +171,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, "seed": 32, - "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing):