Skip to content

Commit

Permalink
fix formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamjajoo committed Oct 8, 2024
1 parent bfd153c commit e4cf4e5
Show file tree
Hide file tree
Showing 20 changed files with 1,110 additions and 592 deletions.
51 changes: 38 additions & 13 deletions src/algos/DisPFL.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import copy
import math
import random
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -21,7 +20,9 @@ class CommProtocol:
Communication protocol tags for the server and clients.
"""

DONE: int = 0 # Used to signal the server that the client is done with local training
DONE: int = (
0 # Used to signal the server that the client is done with local training
)
START: int = 1 # Used by the server to signal the start of the current round
UPDATES: int = 2 # Used to send the updates from the server to the clients
LAST_ROUND: int = 3
Expand Down Expand Up @@ -92,7 +93,9 @@ def set_representation(self, representation: OrderedDict[str, Tensor]) -> None:
"""
self.model.load_state_dict(representation)

def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]:
def fire_mask(
self, masks: OrderedDict[str, Tensor], round_num: int, total_round: int
) -> Tuple[OrderedDict[str, Tensor], Dict[str, int]]:
"""
Fire mask method for model pruning.
"""
Expand All @@ -114,7 +117,12 @@ def fire_mask(self, masks: OrderedDict[str, Tensor], round_num: int, total_round
new_masks[name].view(-1)[idx[: num_remove[name]]] = 0
return new_masks, num_remove

def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int], gradient: Optional[Dict[str, Tensor]] = None) -> OrderedDict[str, Tensor]:
def regrow_mask(
self,
masks: OrderedDict[str, Tensor],
num_remove: Dict[str, int],
gradient: Optional[Dict[str, Tensor]] = None,
) -> OrderedDict[str, Tensor]:
"""
Regrow mask method for model pruning.
"""
Expand Down Expand Up @@ -142,7 +150,12 @@ def regrow_mask(self, masks: OrderedDict[str, Tensor], num_remove: Dict[str, int
new_masks[name].view(-1)[idx] = 1
return new_masks

def aggregate(self, nei_indexes: List[int], weights_lstrnd: List[OrderedDict[str, Tensor]], masks_lstrnd: List[OrderedDict[str, Tensor]]) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]:
def aggregate(
self,
nei_indexes: List[int],
weights_lstrnd: List[OrderedDict[str, Tensor]],
masks_lstrnd: List[OrderedDict[str, Tensor]],
) -> Tuple[OrderedDict[str, Tensor], OrderedDict[str, Tensor]]:
"""
Aggregate the model weights.
"""
Expand Down Expand Up @@ -183,7 +196,13 @@ def send_representations(self, representation: OrderedDict[str, Tensor]) -> None
self.comm_utils.send_signal(client_node, representation, self.tag.UPDATES)
print(f"Node 1 sent average weight to {len(self.clients)} nodes")

def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[str]] = None, distribution: str = "ERK", sparse: float = 0.5) -> Dict[str, float]:
def calculate_sparsities(
self,
params: Dict[str, Tensor],
tabu: Optional[List[str]] = None,
distribution: str = "ERK",
sparse: float = 0.5,
) -> Dict[str, float]:
"""
Calculate sparsities for model pruning.
"""
Expand Down Expand Up @@ -243,7 +262,9 @@ def calculate_sparsities(self, params: Dict[str, Tensor], tabu: Optional[List[st
sparsities[name] = 1 - epsilon * raw_probabilities[name]
return sparsities

def init_masks(self, params: Dict[str, Tensor], sparsities: Dict[str, float]) -> OrderedDict[str, Tensor]:
def init_masks(
self, params: Dict[str, Tensor], sparsities: Dict[str, float]
) -> OrderedDict[str, Tensor]:
"""
Initialize masks for model pruning.
"""
Expand Down Expand Up @@ -277,7 +298,9 @@ def screen_gradient(self) -> Dict[str, Tensor]:

return gradient

def hamming_distance(self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor]) -> Tuple[int, int]:
def hamming_distance(
self, mask_a: OrderedDict[str, Tensor], mask_b: OrderedDict[str, Tensor]
) -> Tuple[int, int]:
"""
Calculate the Hamming distance between two masks.
"""
Expand Down Expand Up @@ -330,7 +353,9 @@ def _benefit_choose(
)
return client_indexes

def model_difference(self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor]) -> Tensor:
def model_difference(
self, model_a: OrderedDict[str, Tensor], model_b: OrderedDict[str, Tensor]
) -> Tensor:
"""
Calculate the difference between two models.
"""
Expand Down Expand Up @@ -388,9 +413,7 @@ def run_protocol(self) -> None:
)
if self.num_users != self.config["neighbors"]:
nei_indexes = np.append(nei_indexes, self.index)
print(
f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}"
)
print(f"Node {self.index}'s neighbors index:{[i + 1 for i in nei_indexes]}")

for tmp_idx in nei_indexes:
if tmp_idx != self.index:
Expand Down Expand Up @@ -472,7 +495,9 @@ def get_representation(self) -> OrderedDict[str, Tensor]:
"""
return self.model.state_dict()

def send_representations(self, representations: Dict[int, OrderedDict[str, Tensor]]) -> None:
def send_representations(
self, representations: Dict[int, OrderedDict[str, Tensor]]
) -> None:
"""
Set the model.
"""
Expand Down
93 changes: 68 additions & 25 deletions src/algos/L2C.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict, defaultdict
from typing import Any, Optional, Union
from collections import defaultdict
from typing import Any
from utils.communication.comm_utils import CommunicationManager
import torch
import numpy as np
Expand All @@ -12,7 +12,9 @@


class L2CClient(BaseFedAvgClient):
def __init__(self, config: dict[str, Any], comm_utils: CommunicationManager) -> None:
def __init__(
self, config: dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
self.init_collab_weights()
self.sharing_mode: str = self.config["sharing"]
Expand Down Expand Up @@ -79,7 +81,9 @@ def filter_out_worse_neighbors(self, num_neighbors_to_keep: int) -> None:
weight_decay=self.config["alpha_weight_decay"],
)

def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]]) -> tuple[float, float]:
def learn_collab_weights(
self, models_update_wts: dict[int, dict[str, Tensor]]
) -> tuple[float, float]:
self.model.eval()
alpha_loss: float = 0
correct: int = 0
Expand All @@ -92,17 +96,23 @@ def learn_collab_weights(self, models_update_wts: dict[int, dict[str, Tensor]])
loss = self.loss_fn(output, target)
loss.backward()

grad_dict: dict[str, Tensor] = {k: v.grad for k, v in self.model.named_parameters()}
grad_dict: dict[str, Tensor] = {
k: v.grad for k, v in self.model.named_parameters()
}

collab_weights_grads: list[Tensor] = []
for id in self.neighbors_id_to_idx.keys():
cw_grad: Tensor = tensor(0.0)
for key in grad_dict.keys():
if key not in self.model_keys_to_ignore:
if self.sharing_mode == "updates":
cw_grad -= (models_update_wts[id][key] * grad_dict[key].cpu()).sum()
cw_grad -= (
models_update_wts[id][key] * grad_dict[key].cpu()
).sum()
elif self.sharing_mode == "weights":
cw_grad += (models_update_wts[id][key] * grad_dict[key].cpu()).sum()
cw_grad += (
models_update_wts[id][key] * grad_dict[key].cpu()
).sum()
else:
raise ValueError("Unknown sharing mode")
collab_weights_grads.append(cw_grad)
Expand All @@ -125,9 +135,13 @@ def print_GPU_memory(self) -> None:
r: int = torch.cuda.memory_reserved(0)
a: int = torch.cuda.memory_allocated(0)
f: int = r - a # free inside reserved
print(f"Client {self.node_id} :GPU memory: reserved {r}, allocated {a}, free {f}")
print(
f"Client {self.node_id} :GPU memory: reserved {r}, allocated {a}, free {f}"
)

def get_collaborator_weights(self, reprs_dict: dict[int, Tensor]) -> dict[int, float]:
def get_collaborator_weights(
self, reprs_dict: dict[int, Tensor]
) -> dict[int, float]:
"""
Returns the weights of the collaborators for the current round.
"""
Expand Down Expand Up @@ -170,46 +184,68 @@ def run_protocol(self) -> None:
round_stats: dict[str, Any] = {"collab_weights": cw}

self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.ROUND_START)
round_stats["train_loss"], round_stats["train_acc"] = self.local_train(epochs_per_round)
round_stats["train_loss"], round_stats["train_acc"] = self.local_train(
epochs_per_round
)
repr: dict[str, Tensor] = self.get_representation()
self.comm_utils.send(dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT)
self.comm_utils.send(
dest=self.server_node, data=repr, tag=self.tag.REPR_ADVERT
)

reprs: list[dict[str, Tensor]] = self.comm_utils.receive(node_ids=self.server_node, tag=self.tag.REPRS_SHARE)
reprs_dict: dict[int, dict[str, Tensor]] = {k: v for k, v in enumerate(reprs, 1)}
reprs: list[dict[str, Tensor]] = self.comm_utils.receive(
node_ids=self.server_node, tag=self.tag.REPRS_SHARE
)
reprs_dict: dict[int, dict[str, Tensor]] = {
k: v for k, v in enumerate(reprs, 1)
}

collab_weights_dict: dict[int, float] = self.get_collaborator_weights(reprs_dict)
collab_weights_dict: dict[int, float] = self.get_collaborator_weights(
reprs_dict
)
models_update_wts: dict[int, dict[str, Tensor]] = reprs_dict

new_wts: dict[str, Tensor] = self.weighted_aggregate(
models_update_wts, collab_weights_dict, self.model_keys_to_ignore
)

if self.sharing_mode == "updates":
new_wts = self.model_utils.substract_model_weights(self.prev_model, new_wts)
new_wts = self.model_utils.substract_model_weights(
self.prev_model, new_wts
)

self.set_model_weights(new_wts, self.model_keys_to_ignore)
round_stats["test_acc"] = self.local_test()
round_stats["validation_loss"], round_stats["validation_acc"] = self.learn_collab_weights(models_update_wts)
round_stats["validation_loss"], round_stats["validation_acc"] = (
self.learn_collab_weights(models_update_wts)
)

print(f"node {self.node_id} weight: {self.collab_weights}")
if round == self.config["T_0"]:
self.filter_out_worse_neighbors(self.config["target_users_after_T_0"])

self.comm_utils.send(dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS)
self.comm_utils.send(
dest=self.server_node, data=round_stats, tag=self.tag.ROUND_STATS
)


class L2CServer(BaseFedAvgServer):
def __init__(self, config: dict[str, Any], comm_utils: CommunicationManager) -> None:
def __init__(
self, config: dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
self.config = config
self.set_model_parameters(config)
self.model_save_path: str = f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
self.model_save_path: str = (
f"{self.config['results_path']}/saved_models/node_{self.node_id}.pt"
)

def test(self) -> float:
"""
Test the model on the server
"""
test_loss, acc = self.model_utils.test(self.model, self._test_loader, self.loss_fn, self.device)
test_loss, acc = self.model_utils.test(
self.model, self._test_loader, self.loss_fn, self.device
)
return acc

def single_round(self) -> list[dict[str, Any]]:
Expand All @@ -218,16 +254,24 @@ def single_round(self) -> list[dict[str, Any]]:
"""
for client_node in self.users:
self.comm_utils.send(dest=client_node, data=None, tag=self.tag.ROUND_START)
self.log_utils.log_console("Server waiting for all clients to finish local training")
self.log_utils.log_console(
"Server waiting for all clients to finish local training"
)

reprs: list[dict[str, Tensor]] = self.comm_utils.all_gather(self.tag.REPR_ADVERT)
reprs: list[dict[str, Tensor]] = self.comm_utils.all_gather(
self.tag.REPR_ADVERT
)
self.log_utils.log_console("Server received all clients models")
self.send_representations(reprs)

round_stats: list[dict[str, Any]] = self.comm_utils.all_gather(self.tag.ROUND_STATS)
round_stats: list[dict[str, Any]] = self.comm_utils.all_gather(
self.tag.ROUND_STATS
)
self.log_utils.log_console("Server received all clients stats")
self.log_utils.log_tb_round_stats(round_stats, ["collab_weights"], self.round)
self.log_utils.log_console(f"Round test acc {[stats['test_acc'] for stats in round_stats]}")
self.log_utils.log_console(
f"Round test acc {[stats['test_acc'] for stats in round_stats]}"
)

return round_stats

Expand All @@ -247,4 +291,3 @@ def run_protocol(self) -> None:
stats_dict["round_step"] = 1
self.log_utils.log_experiments_stats(stats_dict)
self.plot_utils.plot_experiments_stats(stats_dict)

14 changes: 8 additions & 6 deletions src/algos/MetaL2C.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict
from utils.communication.comm_utils import CommunicationManager
import math
import torch
Expand Down Expand Up @@ -77,7 +77,9 @@ def forward(self, model_dict):

class MetaL2CClient(BaseFedAvgClient):

def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)

self.encoder = ModelEncoder(self.get_model_weights())
Expand Down Expand Up @@ -242,9 +244,7 @@ def run_protocol(self):
)

# Collect the representations from all other nodes from the server
reprs = self.comm_utils.receive(
self.server_node, tag=self.tag.REPRS_SHARE
)
reprs = self.comm_utils.receive(self.server_node, tag=self.tag.REPRS_SHARE)
reprs_dict = {k: rep for k, (rep, _) in enumerate(reprs, 1)}

# Aggregate the representations based on the collab weights
Expand Down Expand Up @@ -296,7 +296,9 @@ def run_protocol(self):

class MetaL2CServer(BaseFedAvgServer):

def __init__(self, config: Dict[str, Any], comm_utils: CommunicationManager) -> None:
def __init__(
self, config: Dict[str, Any], comm_utils: CommunicationManager
) -> None:
super().__init__(config, comm_utils)
# self.set_parameters()
self.config = config
Expand Down
Loading

0 comments on commit e4cf4e5

Please sign in to comment.