From d00c9e2f9a83b3432c7812dcfa4a4893bd7eae05 Mon Sep 17 00:00:00 2001 From: samvit-d Date: Mon, 9 Sep 2024 21:08:57 -0400 Subject: [PATCH] added type hints in comm_utils.py and community_utils.py (#76) * added type hints in comm_utils.py and community_utils.py * remove space --------- Co-authored-by: Samvit Das <[samvit.das@gmail.com]> Co-authored-by: Abhishek Singh --- src/utils/comm_utils.py | 2 +- src/utils/distrib_utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/utils/comm_utils.py b/src/utils/comm_utils.py index 832d948..2410b34 100644 --- a/src/utils/comm_utils.py +++ b/src/utils/comm_utils.py @@ -1,5 +1,5 @@ from mpi4py import MPI -from typing import List, Optional, Any +from typing import Any, Optional, List class CommUtils: diff --git a/src/utils/distrib_utils.py b/src/utils/distrib_utils.py index 93d21de..094abcc 100644 --- a/src/utils/distrib_utils.py +++ b/src/utils/distrib_utils.py @@ -9,6 +9,7 @@ from torch.utils.data import Subset, DataLoader from resnet import ResNet34, ResNet18, ResNet50 from utils.data_utils import extr_noniid +from typing import Dict, Any def load_weights(model_dir: str, model: nn.Module, client_num: int): """ @@ -31,7 +32,7 @@ class ServerObj: """ Server object for federated learning. """ - def __init__(self, config, obj, rank) -> None: + def __init__(self, config: Dict[str, Any], obj: Dict[str, Any], rank: int) -> None: self.num_users = config["num_users"] self.samples_per_user = config["samples_per_user"] self.device = obj["device"] @@ -41,7 +42,7 @@ def __init__(self, config, obj, rank) -> None: num_channels = obj["dset_obj"].num_channels self.test_loader = DataLoader(test_dataset, batch_size=batch_size) - model_dict = { + model_dict: Dict[str, Any] = { "ResNet18": ResNet18(num_channels), "ResNet34": ResNet34(num_channels), "ResNet50": ResNet50(num_channels) @@ -53,7 +54,7 @@ class ClientObj: """ Client object for federated learning. """ - def __init__(self, config, obj, rank) -> None: + def __init__(self, config: Dict[str, Any], obj: Dict[str, Any], rank: int) -> None: self.num_users = config["num_users"] self.samples_per_user = config["samples_per_user"] self.device = obj["device"]