Skip to content

Commit

Permalink
added bad weights and sign flipping
Browse files Browse the repository at this point in the history
  • Loading branch information
joyce-yuan committed Sep 20, 2024
1 parent 40e6a04 commit 71a282c
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import OrderedDict
import sys
from typing import Any, Dict, List
from torch import Tensor
from torch import Tensor, zeros_like
from utils.communication.comm_utils import CommunicationManager
from utils.log_utils import LogUtils
from algos.base_class import BaseClient, BaseServer
Expand Down Expand Up @@ -55,7 +55,20 @@ def get_representation(self, **kwargs: Any) -> OrderedDict[str, Tensor]:
"""
Share the model weights
"""
return self.model.state_dict() # type: ignore
malicious_type = self.config.get("malicious_type", "normal")

if malicious_type == "normal":
return self.model.state_dict() # type: ignore
elif malicious_type == "bad_weights":
# Set the weights to zero
# TODO: set it to the weight specified in the config
return OrderedDict({key: zeros_like(val) for key, val in self.model.state_dict().items()})
elif malicious_type == "sign_flip":
# Flip the sign of the weights
return OrderedDict({key: -1 * val for key, val in self.model.state_dict().items()})
else:
raise ValueError("Invalid malicious type")


def set_representation(self, representation: OrderedDict[str, Tensor]):
"""
Expand All @@ -70,6 +83,7 @@ def run_protocol(self):
for round in range(start_epochs, total_epochs):
self.local_train(round)
self.local_test()

repr = self.get_representation()

self.client_log_utils.log_summary("Client {} sending done signal to {}".format(self.node_id, self.server_node))
Expand Down

0 comments on commit 71a282c

Please sign in to comment.