modified algos but comms isn't working
joyce-yuan committed Aug 24, 2024
1 parent 4cb03dc commit 3350527
Showing 2 changed files with 21 additions and 19 deletions.
34 changes: 17 additions & 17 deletions src/algos/def_kt.py
@@ -6,8 +6,9 @@
import copy
import random
from collections import OrderedDict
from typing import List
from typing import Any, Dict, List
from torch import Tensor
from utils.communication.comm_utils import CommunicationManager
import torch.nn as nn

from algos.base_class import BaseClient, BaseServer
@@ -28,8 +29,8 @@ class DefKTClient(BaseClient):
"""
Client class for DefKT (Deep Mutual Learning with Knowledge Transfer)
"""
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)
self.config = config
self.tag = CommProtocol
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
@@ -112,15 +113,15 @@ def send_representations(self, representation):
Send the model representations to the clients
"""
for client_node in self.clients:
self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
self.comm_utils.send(client_node, representation) #, self.tag.UPDATES)
print(f"Node 1 sent average weight to {len(self.clients)} nodes")

def single_round(self, self_repr):
"""
Runs a single training round
"""
print("Node 1 waiting for all clients to finish")
reprs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.DONE)
reprs = self.comm_utils.receive(self.clients) #, self.tag.DONE)
reprs.append(self_repr)
print(f"Node 1 received {len(reprs)} clients' weights")
avg_wts = self.aggregate(reprs)
@@ -151,26 +152,22 @@ def run_protocol(self):
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
for epoch in range(start_epochs, total_epochs):
status = self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
status = self.comm_utils.receive(src=0) #, tag=self.tag.START)
self.assign_own_status(status)
if self.status == "teacher":
self.local_train()
self_repr = self.get_representation()
self.comm_utils.send_signal(
dest=self.pair_id, data=self_repr, tag=self.tag.DONE
)
self.comm_utils.send(dest=self.pair_id, data=self_repr)#tag=self.tag.DONE
print(f"Node {self.node_id} sent repr to student node {self.pair_id}")
elif self.status == "student":
teacher_repr = self.comm_utils.wait_for_signal(
src=self.pair_id, tag=self.tag.DONE
)
teacher_repr = self.comm_utils.receive(src=self.pair_id) #, tag=self.tag.DONE)
print(f"Node {self.node_id} received repr from teacher node {self.pair_id}")
self.deep_mutual_train(teacher_repr)
else:
print(f"Node {self.node_id} do nothing")
acc = self.local_test()
print(f"Node {self.node_id} test_acc:{acc:.4f}")
self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)
self.comm_utils.send(dest=0, data=acc) #, tag=self.tag.FINISH)


class DefKTServer(BaseServer):
@@ -190,7 +187,7 @@ def send_representations(self, representations):
Send the model representations to the clients
"""
for client_node in self.users:
self.comm_utils.send_signal(client_node, representations, self.tag.UPDATES)
self.comm_utils.send(client_node, representations) #, self.tag.UPDATES)
self.log_utils.log_console(
f"Server sent {len(representations)} representations to node {client_node}"
)
@@ -232,8 +229,11 @@ def single_round(self):
self.log_utils.log_console(
f"Server sending status from {self.node_id} to {client_node}"
)
self.comm_utils.send_signal(
dest=client_node, data=[teachers, students], tag=self.tag.START
# self.comm_utils.send_signal(
# dest=client_node, data=[teachers, students], tag=self.tag.START
# )
self.comm_utils.send(
dest=client_node, data=[teachers, students]
)

def run_protocol(self):
@@ -246,5 +246,5 @@ def run_protocol(self):
for epoch in range(start_epochs, total_epochs):
self.log_utils.log_console(f"Starting round {epoch}")
self.single_round()
accs = self.comm_utils.wait_for_all_clients(self.users, self.tag.FINISH)
accs = self.comm_utils.receive(self.users) #, self.tag.FINISH)
self.log_utils.log_console(f"Round {epoch} done; acc {accs}")
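
Editor's note on the interface assumed above: the updated call sites drop the protocol tags and use a bare send/receive pair, calling receive both as receive(src=0) / receive(src=self.pair_id) and as receive(self.clients) / receive(self.users). The sketch below is a hypothetical stand-in, inferred only from those call sites and not taken from utils/communication/comm_utils.py; a mismatch between this expected shape and the real CommunicationManager would be consistent with the commit message that comms isn't working yet.

# Minimal sketch (assumption, not the real comm_utils implementation): the
# tag-free interface the updated call sites in this diff appear to expect.
from typing import Any, List, Union

class CommunicationManagerSketch:
    """Hypothetical stand-in for utils.communication.comm_utils.CommunicationManager."""

    def send(self, dest: int, data: Any) -> None:
        # Deliver `data` to node `dest`; the old tag argument is no longer passed.
        raise NotImplementedError

    def receive(self, src: Union[int, List[int]]) -> Union[Any, List[Any]]:
        # The diff calls this both as receive(src=0) and as receive(self.clients),
        # so one parameter has to accept an int or a list of node ids and return
        # a single payload or a list of payloads accordingly.
        raise NotImplementedError
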
6 changes: 4 additions & 2 deletions src/algos/fl_isolated.py
@@ -1,5 +1,7 @@
from algos.base_class import BaseClient, BaseServer
from utils.stats_utils import from_rounds_stats_per_client_per_round_to_dict_arrays
from typing import Any, Dict, List
from utils.communication.comm_utils import CommunicationManager


class CommProtocol(object):
@@ -12,8 +14,8 @@ class CommProtocol(object):


class FedIsoClient(BaseClient):
def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
super().__init__(config, comm_utils)
self.config = config
self.tag = CommProtocol
self.model_save_path = "{}/saved_models/node_{}.pt".format(
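
Editor's note on the constructor change above: clients are now built with both a config and a communication manager. A hedged usage example follows; the config keys and the CommunicationManager constructor arguments are placeholders for illustration, not values taken from the repository.

# Illustrative only: constructing FedIsoClient with the new two-argument signature.
config = {"results_path": "./results", "epochs": 10}   # assumed keys
comm_utils = CommunicationManager(config)               # assumed constructor
client = FedIsoClient(config, comm_utils)               # was FedIsoClient(config) before this commit
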
