Skip to content

Commit

Permalink
added type hints in comm_utils.py and community_utils.py (#76)
Browse files Browse the repository at this point in the history
* added type hints in comm_utils.py and community_utils.py

* remove space

---------

Co-authored-by: Samvit Das <[samvit.das@gmail.com]>
Co-authored-by: Abhishek Singh <abhishek.s14@iiits.in>
  • Loading branch information
3 people authored Sep 10, 2024
1 parent 7140928 commit d00c9e2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/utils/comm_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from mpi4py import MPI
from typing import List, Optional, Any
from typing import Any, Optional, List


class CommUtils:
Expand Down
7 changes: 4 additions & 3 deletions src/utils/distrib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.utils.data import Subset, DataLoader
from resnet import ResNet34, ResNet18, ResNet50
from utils.data_utils import extr_noniid
from typing import Dict, Any

def load_weights(model_dir: str, model: nn.Module, client_num: int):
"""
Expand All @@ -31,7 +32,7 @@ class ServerObj:
"""
Server object for federated learning.
"""
def __init__(self, config, obj, rank) -> None:
def __init__(self, config: Dict[str, Any], obj: Dict[str, Any], rank: int) -> None:
self.num_users = config["num_users"]
self.samples_per_user = config["samples_per_user"]
self.device = obj["device"]
Expand All @@ -41,7 +42,7 @@ def __init__(self, config, obj, rank) -> None:
num_channels = obj["dset_obj"].num_channels

self.test_loader = DataLoader(test_dataset, batch_size=batch_size)
model_dict = {
model_dict: Dict[str, Any] = {
"ResNet18": ResNet18(num_channels),
"ResNet34": ResNet34(num_channels),
"ResNet50": ResNet50(num_channels)
Expand All @@ -53,7 +54,7 @@ class ClientObj:
"""
Client object for federated learning.
"""
def __init__(self, config, obj, rank) -> None:
def __init__(self, config: Dict[str, Any], obj: Dict[str, Any], rank: int) -> None:
self.num_users = config["num_users"]
self.samples_per_user = config["samples_per_user"]
self.device = obj["device"]
Expand Down

0 comments on commit d00c9e2

Please sign in to comment.