Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log communication cost #120

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def run_protocol(self):
self.local_round_done()

self.receive_and_aggregate()

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()

self.log_metrics(stats=stats, iteration=round)


Expand Down Expand Up @@ -156,5 +159,6 @@ def run_protocol(self):
for round in range(start_rounds, total_rounds):
self.local_round_done()
self.single_round()
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
stats["test_loss"], stats["test_acc"], stats["test_time"] = self.test()
self.log_metrics(stats=stats, iteration=round)
self.log_metrics(stats=stats, iteration=round)
4 changes: 3 additions & 1 deletion src/algos/fl_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def run_protocol(self) -> None:

self.receive_and_aggregate(neighbors)

stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()

# evaluate the model on the test data
# Inside FedStaticNode.run_protocol()
stats["test_loss"], stats["test_acc"] = self.local_test()
Expand All @@ -73,4 +75,4 @@ def __init__(
pass

def run_protocol(self) -> None:
pass
pass
3 changes: 3 additions & 0 deletions src/utils/communication/comm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ def finalize(self):

def set_is_working(self, is_working: bool):
    """Forward the working flag to the underlying communication backend."""
    backend = self.comm
    backend.set_is_working(is_working)

def get_comm_cost(self):
    """Return the (bytes_received, bytes_sent) pair reported by the backend."""
    cost = self.comm.get_comm_cost()
    return cost
31 changes: 31 additions & 0 deletions src/utils/communication/grpc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import time
import socket
import functools
from typing import Any, Callable, Dict, List, OrderedDict, Union, TYPE_CHECKING
from urllib.parse import unquote
import grpc # type: ignore
Expand Down Expand Up @@ -95,18 +96,37 @@ def __init__(self, super_node_host: str):
{0: {"rank": 0, "port": port, "ip": ip}}
)
self.is_working = True
self.communication_cost_received: int = 0
self.communication_cost_sent: int = 0

def get_comm_cost(self):
    """Read the (received, sent) byte counters atomically under the lock."""
    with self.lock:
        received = self.communication_cost_received
        sent = self.communication_cost_sent
    return received, sent

def set_is_working(self, is_working: bool):
    """Update the working flag while holding the servicer lock."""
    self.lock.acquire()
    try:
        self.is_working = is_working
    finally:
        self.lock.release()

def update_communication_cost(func):
    """Decorator for gRPC servicer handlers that accounts message sizes.

    Adds the serialized size of the incoming ``request`` to
    ``self.communication_cost_received`` and the size of the returned
    response to ``self.communication_cost_sent``, both under ``self.lock``.
    If the handler raises (e.g. via ``context.abort``), nothing is counted.
    """
    @functools.wraps(func)  # preserve handler __name__/__doc__ for debugging/logs
    def wrapper(self, request, context):
        down_cost = request.ByteSize()
        return_data = func(self, request, context)
        # Some handlers are annotated as returning `... | None`; guard so the
        # accounting never raises AttributeError on a None response.
        up_cost = return_data.ByteSize() if return_data is not None else 0
        with self.lock:
            self.communication_cost_received += down_cost
            self.communication_cost_sent += up_cost
        return return_data
    return wrapper

def register_self(self, obj: "BaseNode"):
    """Attach the owning node; RPC handlers read ``self.base_node`` to serve state."""
    self.base_node = obj

@update_communication_cost
def send_data(self, request, context) -> comm_pb2.Empty:  # type: ignore
    """Deserialize the pushed model buffer and enqueue it for the local node."""
    payload = deserialize_model(request.model.buffer)  # type: ignore
    self.received_data.put(payload)
    return comm_pb2.Empty()  # type: ignore

@update_communication_cost
def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank | None:
try:
with self.lock:
Expand All @@ -121,6 +141,7 @@ def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> co
except Exception as e:
context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore

@update_communication_cost
def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Model | None:
if not self.base_node:
context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore
Expand All @@ -133,6 +154,7 @@ def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> c
model = comm_pb2.Model(buffer=EMPTY_MODEL_TAG)
return model

@update_communication_cost
def get_current_round(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Round | None:
if not self.base_node:
context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore
Expand Down Expand Up @@ -239,6 +261,8 @@ def register(self):
"""
def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int:
rank_data = stub.get_rank(comm_pb2.Empty()) # type: ignore
with self.servicer.lock:
self.servicer.communication_cost_received += rank_data.ByteSize()
return rank_data.rank # type: ignore

self.rank = self.recv_with_retries(self.super_node_host, callback_fn)
Expand Down Expand Up @@ -352,6 +376,8 @@ def wait_until_rounds_match(self, id: int):
self_round = self.servicer.base_node.get_local_rounds()
def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int:
round = stub.get_current_round(comm_pb2.Empty()) # type: ignore
with self.servicer.lock:
self.servicer.communication_cost_received += round.ByteSize()
return round.round # type: ignore

while True:
Expand Down Expand Up @@ -386,6 +412,8 @@ def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> OrderedDict[str,
model = stub.get_model(comm_pb2.Empty()) # type: ignore
if model.buffer == EMPTY_MODEL_TAG:
return OrderedDict()
with self.servicer.lock:
self.servicer.communication_cost_received += model.ByteSize()
return deserialize_model(model.buffer) # type: ignore

for id in node_ids:
Expand Down Expand Up @@ -424,6 +452,9 @@ def get_num_finished(self) -> int:
def set_is_working(self, is_working: bool):
    """Delegate the working flag to this node's gRPC servicer."""
    target = self.servicer
    target.set_is_working(is_working)

def get_comm_cost(self):
    """Return the servicer's (bytes_received, bytes_sent) counters."""
    counters = self.servicer.get_comm_cost()
    return counters

def finalize(self):
# 1. All nodes send finished to the super node
# 2. super node will wait for all nodes to send finished
Expand Down
Loading