Skip to content

Commit

Permalink
Merge pull request #151 from photonshi/fix_fl
Browse files Browse the repository at this point in the history
fixed attack invocation in fl
  • Loading branch information
photonshi authored Dec 2, 2024
2 parents 5a78e1c + 43a663a commit 906cf84
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,9 @@ def test(self, **kwargs: Any) -> Tuple[float, float, float]:
self.stats["test_loss"], self.stats["test_acc"], self.stats["test_time"] = test_loss, test_acc, time_taken
return test_loss, test_acc, time_taken

def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""):
def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int):
reprs = self.comm_utils.all_gather()

with open(dump_file_name, "wb") as f:
pickle.dump(reprs, f)

# Handle GIA-specific logic
if "gia" in self.config:
print("Server Running GIA attack")
Expand Down Expand Up @@ -229,12 +226,13 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round
attack_end_round (int): The last round for the attack to be performed.
"""

# Normal training when outside the attack range

if round < attack_start_round or round > attack_end_round:
self.receive_and_aggregate()
# Determine if the attack should be performed
attack_in_progress = self.gia_attacker and attack_start_round <= round <= attack_end_round

if attack_in_progress:
self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round)
else:
self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name)
self.receive_and_aggregate()


def run_protocol(self):
Expand Down

0 comments on commit 906cf84

Please sign in to comment.