diff --git a/src/algos/fl.py b/src/algos/fl.py index 569213e..721912f 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -169,12 +169,9 @@ def test(self, **kwargs: Any) -> Tuple[float, float, float]: self.stats["test_loss"], self.stats["test_acc"], self.stats["test_time"] = test_loss, test_acc, time_taken return test_loss, test_acc, time_taken - def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""): + def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int): reprs = self.comm_utils.all_gather() - with open(dump_file_name, "wb") as f: - pickle.dump(reprs, f) - # Handle GIA-specific logic if "gia" in self.config: print("Server Running GIA attack") @@ -229,12 +226,13 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round attack_end_round (int): The last round for the attack to be performed. """ - # Normal training when outside the attack range - - if round < attack_start_round or round > attack_end_round: - self.receive_and_aggregate() + # Determine if the attack should be performed + attack_in_progress = self.gia_attacker and attack_start_round <= round <= attack_end_round + + if attack_in_progress: + self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round) else: - self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name) + self.receive_and_aggregate() def run_protocol(self):