diff --git a/src/algos/def_kt.py b/src/algos/def_kt.py index 0dcfe47..d06433c 100644 --- a/src/algos/def_kt.py +++ b/src/algos/def_kt.py @@ -6,8 +6,9 @@ import copy import random from collections import OrderedDict -from typing import List +from typing import Any, Dict, List from torch import Tensor +from utils.communication.comm_utils import CommunicationManager import torch.nn as nn from algos.base_class import BaseClient, BaseServer @@ -28,8 +29,8 @@ class DefKTClient(BaseClient): """ Client class for DefKT (Deep Mutual Learning with Knowledge Transfer) """ - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt" @@ -112,7 +113,7 @@ def send_representations(self, representation): Send the model representations to the clients """ for client_node in self.clients: - self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES) + self.comm_utils.send(client_node, representation) #, self.tag.UPDATES) print(f"Node 1 sent average weight to {len(self.clients)} nodes") def single_round(self, self_repr): @@ -120,7 +121,7 @@ def single_round(self, self_repr): Runs a single training round """ print("Node 1 waiting for all clients to finish") - reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.DONE) + reprs = self.comm_utils.receive(self.clients) #, self.tag.DONE) reprs.append(self_repr) print(f"Node 1 received {len(reprs)} clients' weights") avg_wts = self.aggregate(reprs) @@ -151,26 +152,22 @@ def run_protocol(self): start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] for epoch in range(start_epochs, total_epochs): - status = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START) + status = self.comm_utils.receive(src=0) #, tag=self.tag.START) self.assign_own_status(status) if self.status == "teacher": self.local_train() self_repr = self.get_representation() - self.comm_utils.send_signal( - dest=self.pair_id, data=self_repr, tag=self.tag.DONE - ) + self.comm_utils.send(dest=self.pair_id, data=self_repr)#tag=self.tag.DONE print(f"Node {self.node_id} sent repr to student node {self.pair_id}") elif self.status == "student": - teacher_repr = self.comm_utils.wait_for_signal( - src=self.pair_id, tag=self.tag.DONE - ) + teacher_repr = self.comm_utils.receive(src=self.pair_id) #, tag=self.tag.DONE) print(f"Node {self.node_id} received repr from teacher node {self.pair_id}") self.deep_mutual_train(teacher_repr) else: print(f"Node {self.node_id} do nothing") acc = self.local_test() print(f"Node {self.node_id} test_acc:{acc:.4f}") - self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH) + self.comm_utils.send(dest=0, data=acc) #, tag=self.tag.FINISH) class DefKTServer(BaseServer): @@ -190,7 +187,7 @@ def send_representations(self, representations): Send the model representations to the clients """ for client_node in self.users: - self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES) + self.comm_utils.send(client_node, representations) #, self.tag.UPDATES) self.log_utils.log_console( f"Server sent {len(representations)} representations to node {client_node}" ) @@ -232,8 +229,11 @@ def single_round(self): self.log_utils.log_console( f"Server sending status from {self.node_id} to {client_node}" ) - self.comm_utils.send_signal( - dest=client_node, data=[teachers, students], tag=self.tag.START + # self.comm_utils.send_signal( + # dest=client_node, data=[teachers, students], tag=self.tag.START + # ) + self.comm_utils.send( + dest=client_node, data=[teachers, students] ) def run_protocol(self): @@ -246,5 +246,5 @@ def run_protocol(self): for epoch in range(start_epochs, total_epochs): self.log_utils.log_console(f"Starting round {epoch}") self.single_round() - accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH) + accs = self.comm_utils.receive(self.users) #, self.tag.FINISH) self.log_utils.log_console(f"Round {epoch} done; acc {accs}") diff --git a/src/algos/fl_isolated.py b/src/algos/fl_isolated.py index 059430e..a6441c8 100644 --- a/src/algos/fl_isolated.py +++ b/src/algos/fl_isolated.py @@ -1,5 +1,7 @@ from algos.base_class import BaseClient, BaseServer from utils.stats_utils import from_rounds_stats_per_client_per_round_to_dict_arrays +from typing import Any, Dict, List +from utils.communication.comm_utils import CommunicationManager class CommProtocol(object): @@ -12,8 +14,8 @@ class CommProtocol(object): class FedIsoClient(BaseClient): - def __init__(self, config) -> None: - super().__init__(config) + def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None: + super().__init__(config, comm_utils) self.config = config self.tag = CommProtocol self.model_save_path = "{}/saved_models/node_{}.pt".format(