Skip to content

Commit

Permalink
Efficient automated GitHub testing (#153)
Browse files Browse the repository at this point in the history
* added MPI Communication class

* added send thread, merged 2 classes

* improved comments

* testing mpi, model weights not acquired

* mpi works, occassional deadlock issue

* merged send and listener threads

* added super init to fl_static server

* logging dataset loading

* reduced test size during workflow testing

* workflow debugging

* workflow debugging

* workflow run on push to main only

* using requirements cpu

* using test_samples_per_user to reduce test set

* download data in server

* fedstatic works with testing

* change dump_dir

* code cleanup

* code cleanup
  • Loading branch information
kathrynle20 authored Dec 9, 2024
1 parent b76f33c commit 7962bce
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 23 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/train.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ jobs:
- name: Run test
run: |
cd src
# chmod +x ./configs/algo_config_test.py
echo "starting main grpc"
python main_grpc.py -n 4 -host localhost
python main_grpc.py -n 4 -host localhost -dev True
echo "starting main"
python main.py -super true -s "./configs/sys_config_test.py"
python main.py -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" -super true
echo "done"
# further checks:
Expand Down
12 changes: 10 additions & 2 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
)
else:
raise ValueError(f"Unknown community type: {community_type}.")
if self.node_id == 0:
self.log_utils.log_console(f"Communities: {self.communities}")
# if self.node_id == 0:
# self.log_utils.log_console(f"Communities: {self.communities}")

def local_round_done(self) -> None:
self.round += 1
Expand Down Expand Up @@ -686,6 +686,14 @@ def is_same_dest(dset):
if self.dset.startswith("domainnet"):
test_dset = CacheDataset(test_dset)

# reduce test_dset size
if config.get("test_samples_per_user", 0) != 0:
print(f"Reducing test size to {config.get('test_samples_per_user', 0)}")
reduced_test_size = config.get("test_samples_per_user", 0)
indices = np.random.choice(len(test_dset), reduced_test_size, replace=False)
test_dset = Subset(test_dset, indices)
print(f"test_dset size: {len(test_dset)}")

self._test_loader = DataLoader(test_dset, batch_size=batch_size)
# TODO: fix print_data_summary
# self.print_data_summary(train_dset, test_dset, val_dset=val_dset)
Expand Down
15 changes: 14 additions & 1 deletion src/algos/fl_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from algos.base_class import BaseFedAvgClient
from algos.topologies.collections import select_topology
from utils.data_utils import get_dataset

class FedStaticNode(BaseFedAvgClient):
"""
Expand Down Expand Up @@ -71,7 +72,19 @@ class FedStaticServer(BaseFedAvgClient):
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
pass
self.comm_utils = comm_utils
self.node_id = self.comm_utils.get_rank()
self.comm_utils.register_node(self)
self.is_working = True
if isinstance(config["dset"], dict):
if self.node_id != 0:
config["dset"].pop("0") # type: ignore
self.dset = str(config["dset"][str(self.node_id)]) # type: ignore
config["dpath"] = config["dpath"][self.dset]
else:
self.dset = config["dset"]
print(f"Node {self.node_id} getting dset at {self.dset}")
self.dset_obj = get_dataset(self.dset, dpath=config["dpath"])

def run_protocol(self) -> None:
pass
20 changes: 10 additions & 10 deletions src/configs/algo_config_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from utils.types import ConfigType

# fedstatic: ConfigType = {
# # Collaboration setup
# "algo": "fedstatic",
# "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
# "rounds": 1,
fedstatic: ConfigType = {
# Collaboration setup
"algo": "fedstatic",
"topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore
"rounds": 1,

# # Model parameters
# "model": "resnet10",
# "model_lr": 3e-4,
# "batch_size": 256,
# }
# Model parameters
"model": "resnet10",
"model_lr": 3e-4,
"batch_size": 256,
}

traditional_fl: ConfigType = {
# Collaboration setup
Expand Down
2 changes: 0 additions & 2 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
CIAR10_DPATH = "./datasets/imgs/cifar10/"

NUM_COLLABORATORS = 1
# DUMP_DIR = "../../../../../../../home/"
DUMP_DIR = "/tmp/"

num_users = 3
Expand Down Expand Up @@ -391,4 +390,3 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):

current_config = grpc_system_config
# current_config = mpi_system_config

6 changes: 3 additions & 3 deletions src/configs/sys_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from utils.types import ConfigType

from .algo_config_test import (
traditional_fl
traditional_fl,
fedstatic
)

def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]:
Expand Down Expand Up @@ -80,7 +81,6 @@ def get_algo_configs(
CIFAR10_DSET = "cifar10"
CIAR10_DPATH = "./datasets/imgs/cifar10/"

# DUMP_DIR = "../../../../../../../home/"
DUMP_DIR = "/tmp/"

NUM_COLLABORATORS = 1
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_algo_configs(
"seed": 2,
"device_ids": get_device_ids(num_users, gpu_ids),
# "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore
"algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore
# "samples_per_user": 50000 // num_users, # distributed equally
"samples_per_user": 100,
"train_label_distribution": "non_iid",
Expand Down
9 changes: 9 additions & 0 deletions src/main_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,19 @@
help=f"host address of the nodes",
)

parser.add_argument(
"-dev",
nargs="?",
type=bool,
help=f"whether or not development testing",
)

args : argparse.Namespace = parser.parse_args()

# Command for opening each process
command_list: List[str] = ["python", "main.py", "-host", args.host]
if args.dev == True:
command_list: List[str] = ["python", "main.py", "-b", "./configs/algo_config_test.py", "-s", "./configs/sys_config_test.py", "-host", args.host]

# Start process for each user
for i in range(args.n):
Expand Down
2 changes: 1 addition & 1 deletion src/utils/communication/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def initialize(self):
def send_quorum(self) -> Any:
# return super().send_quorum(node_ids)
pass

def register_self(self, obj: "BaseNode"):
self.base_node = obj

Expand Down

0 comments on commit 7962bce

Please sign in to comment.