Skip to content

Commit

Permalink
Fixed typos in util files and extended algos type hinting (#79)
Browse files Browse the repository at this point in the history
Co-authored-by: generic-account <Gordon Lichtstein>
Co-authored-by: Abhishek Singh <abhishek.s14@iiits.in>
  • Loading branch information
generic-account and tremblerz authored Sep 14, 2024
1 parent 927fd35 commit ae17575
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 96 deletions.
121 changes: 66 additions & 55 deletions src/algos/DisPFL.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
import random
from collections import OrderedDict
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
Expand All @@ -21,34 +21,42 @@ class CommProtocol:
Communication protocol tags for the server and clients.
"""

DONE = 0 # Used to signal the server that the client is done with local training
START = 1 # Used to signal by the server to start the current round
UPDATES = 2 # Used to send the updates from the server to the clients
LAST_ROUND = 3
SHARE_MASKS = 4
SHARE_WEIGHTS = 5
FINISH = 6 # Used to signal the server to finish the current round
DONE: int = 0 # Used to signal the server that the client is done with local training
START: int = 1 # Used to signal by the server to start the current round
UPDATES: int = 2 # Used to send the updates from the server to the clients
LAST_ROUND: int = 3
SHARE_MASKS: int = 4
SHARE_WEIGHTS: int = 5
FINISH: int = 6 # Used to signal the server to finish the current round


class DisPFLClient(BaseClient):
"""
Client class for DisPFL (Distributed Personalized Federated Learning).
"""
def __init__(self, config) -> None:

def __init__(self, config: Dict[str, Any]) -> None:
super().__init__(config)
self.params: Optional[Dict[str, Tensor]] = None
self.mask: Optional[OrderedDict[str, Tensor]] = None
self.index: Optional[int] = None
self.repr: Optional[OrderedDict[str, Tensor]] = None
self.config = config
self.tag = CommProtocol
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
self.dense_ratio = self.config["dense_ratio"]
self.anneal_factor = self.config["anneal_factor"]
self.dis_gradient_check = self.config["dis_gradient_check"]
self.server_node = 1 # leader node
self.num_users = config["num_users"]
self.neighbors = list(range(self.num_users))
self.model_save_path = (
f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
)
self.dense_ratio: float = self.config["dense_ratio"]
self.anneal_factor: float = self.config["anneal_factor"]
self.dis_gradient_check: bool = self.config["dis_gradient_check"]
self.server_node: int = 1 # leader node
self.num_users: int = config["num_users"]
self.neighbors: List[int] = list(range(self.num_users))

if self.node_id == 1:
self.clients = list(range(2, self.num_users + 1))

def local_train(self):
def local_train(self) -> None:
"""
Train the model locally.
"""
Expand All @@ -57,20 +65,16 @@ def local_train(self):
)
print(f"Node{self.node_id} train loss: {loss}, train acc: {acc}")

def local_test(self, **kwargs):
def local_test(self, **kwargs: Any) -> Tuple[float, float]:
"""
Test the model locally, not to be used in the traditional FedAvg.
"""
test_loss, acc = self.model_utils.test(
self.model, self._test_loader, self.loss_fn, self.device
)
# TODO save the model if the accuracy is better than the best accuracy so far
# if acc > self.best_acc:
# self.best_acc = acc
# self.model_utils.save_model(self.model, self.model_save_path)
return test_loss, acc

def get_trainable_params(self):
def get_trainable_params(self) -> Dict[str, Tensor]:
param_dict = {}
for name, param in self.model.named_parameters():
param_dict[name] = param
Expand All @@ -82,13 +86,13 @@ def get_representation(self) -> OrderedDict[str, Tensor]:
"""
return self.model.state_dict()

def set_representation(self, representation: OrderedDict[str, Tensor]):
def set_representation(self, representation: OrderedDict[str, Tensor]) -> None:
"""
Set the model weights.
"""
self.model.load_state_dict(representation)

def fire_mask(self, masks, round_num, total_round):
def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]:
"""
Fire mask method for model pruning.
"""
Expand All @@ -97,7 +101,7 @@ def fire_mask(self, masks, round_num, total_round):
self.anneal_factor / 2 * (1 + np.cos((round_num * np.pi) / total_round))
)
new_masks = copy.deepcopy(masks)
num_remove = {}
num_remove: Dict[str, int] = {}
for name in masks:
num_non_zeros = torch.sum(masks[name])
num_remove[name] = math.ceil(drop_ratio * num_non_zeros)
Expand All @@ -110,7 +114,7 @@ def fire_mask(self, masks, round_num, total_round):
new_masks[name].view(-1)[idx[: num_remove[name]]] = 0
return new_masks, num_remove

def regrow_mask(self, masks, num_remove, gradient=None):
def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int], gradient: Optional[Dict[str, Tensor]] = None) -> OrderedDict[str, Tensor]:
"""
Regrow mask method for model pruning.
"""
Expand Down Expand Up @@ -138,7 +142,7 @@ def regrow_mask(self, masks, num_remove, gradient=None):
new_masks[name].view(-1)[idx] = 1
return new_masks

def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd):
def aggregate(self, nei_indexes: List[int], weights_lstrnd: List[OrderedDict[str, Tensor]], masks_lstrnd: List[OrderedDict[str, Tensor]]) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]:
"""
Aggregate the model weights.
"""
Expand Down Expand Up @@ -171,21 +175,21 @@ def aggregate(self, nei_indexes, weights_lstrnd, masks_lstrnd):
w_tmp[name] = w_tmp[name] * self.mask[name].to(self.device)
return w_tmp, w_p_g

def send_representations(self, representation):
def send_representations(self, representation: OrderedDict[str, Tensor]) -> None:
"""
Set the model.
"""
for client_node in self.clients:
self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
print(f"Node 1 sent average weight to {len(self.clients)} nodes")

def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5):
def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[str]] = None, distribution: str = "ERK", sparse: float = 0.5) -> Dict[str, float]:
"""
Calculate sparsities for model pruning.
"""
if tabu is None:
tabu = []
sparsities = {}
sparsities: Dict[str, float] = {}
if distribution == "uniform":
for name in params:
if name not in tabu:
Expand Down Expand Up @@ -239,22 +243,22 @@ def calculate_sparsities(self, params, tabu=None, distribution="ERK", sparse=0.5
sparsities[name] = 1 - epsilon * raw_probabilities[name]
return sparsities

def init_masks(self, params, sparsities):
def init_masks(self, params: Dict[str, Tensor], sparsities: Dict[str, float]) -> OrderedDict[str, Tensor]:
"""
Initialize masks for model pruning.
"""
masks = OrderedDict()
for name in params:
masks[name] = zeros_like(params[name])
dense_numel = int((1 - sparsities[name]) * numel(masks[name]))
dense_numel = int((1 - sparsities[name]) * masks[name].numel())
if dense_numel > 0:
temp = masks[name].view(-1)
perm = randperm(len(temp))
perm = perm[:dense_numel]
temp[perm] = 1
return masks

def screen_gradient(self):
def screen_gradient(self) -> Dict[str, Tensor]:
"""
Screen gradient method for model pruning.
"""
Expand All @@ -273,7 +277,7 @@ def screen_gradient(self):

return gradient

def hamming_distance(self, mask_a, mask_b):
def hamming_distance(self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor]) -> Tuple[int, int]:
"""
Calculate the Hamming distance between two masks.
"""
Expand All @@ -288,20 +292,20 @@ def hamming_distance(self, mask_a, mask_b):

def _benefit_choose(
self,
round_idx,
cur_clnt,
client_num_in_total,
client_num_per_round,
dist_local,
total_dist,
cs=False,
active_ths_rnd=None,
):
round_idx: int, # pylint: disable=unused-argument
cur_clnt: int,
client_num_in_total: int,
client_num_per_round: int,
dist_local: Optional[np.ndarray], # pylint: disable=unused-argument
total_dist: Optional[np.ndarray], # pylint: disable=unused-argument
cs: bool = False,
active_ths_rnd: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Benefit choose method for client selection.
"""
if client_num_in_total == client_num_per_round:
client_indexes = list(range(client_num_in_total))
client_indexes = np.array(list(range(client_num_in_total)))
return client_indexes

if cs == "random":
Expand All @@ -326,7 +330,7 @@ def _benefit_choose(
)
return client_indexes

def model_difference(self, model_a, model_b):
def model_difference(self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor]) -> Tensor:
"""
Calculate the difference between two models.
"""
Expand All @@ -335,7 +339,7 @@ def model_difference(self, model_a, model_b):
)
return diff

def run_protocol(self):
def run_protocol(self) -> None:
"""
Runs the entire training protocol.
"""
Expand Down Expand Up @@ -446,22 +450,29 @@ class DisPFLServer(BaseServer):
"""
Server class for DisPFL (Distributed Personalized Federated Learning).
"""
def __init__(self, config) -> None:

def __init__(self, config: Dict[str, Any]) -> None:
super().__init__(config)
self.best_acc: float = 0
self.round: int = 0
self.masks: Any = 0
self.reprs: Any = 0
self.config = config
self.set_model_parameters(config)
self.tag = CommProtocol
self.model_save_path = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
self.dense_ratio = self.config["dense_ratio"]
self.num_users = self.config["num_users"]
self.model_save_path = (
f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
)
self.dense_ratio: float = self.config["dense_ratio"]
self.num_users: int = self.config["num_users"]

def get_representation(self) -> OrderedDict[str, Tensor]:
"""
Share the model weights.
"""
return self.model.state_dict()

def send_representations(self, representations):
def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None:
"""
Set the model.
"""
Expand All @@ -484,7 +495,7 @@ def test(self) -> float:
self.model_utils.save_model(self.model, self.model_save_path)
return acc

def single_round(self, epoch, active_ths_rnd):
def single_round(self, epoch: int, active_ths_rnd: np.ndarray) -> None:
"""
Runs the whole training procedure.
"""
Expand All @@ -509,13 +520,13 @@ def single_round(self, epoch, active_ths_rnd):
self.users, self.tag.SHARE_WEIGHTS
)

def get_trainable_params(self):
def get_trainable_params(self) -> Dict[str, Tensor]:
param_dict = {}
for name, param in self.model.named_parameters():
param_dict[name] = param
return param_dict

def run_protocol(self):
def run_protocol(self) -> None:
"""
Runs the entire training protocol.
"""
Expand Down
Loading

0 comments on commit ae17575

Please sign in to comment.