Skip to content

Commit

Permalink
fixing lingering import errors (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
photonshi authored Dec 2, 2024
1 parent 906cf84 commit 02a9d55
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 469 deletions.
4 changes: 2 additions & 2 deletions src/algos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
)
from utils.types import ConfigType
from utils.dropout_utils import NodeDropout
from utils.gias import gia_main

import torchvision.transforms as T # type: ignore
import os
Expand Down Expand Up @@ -761,7 +760,8 @@ def receive_attack_and_aggregate(self, neighbors: List[int], round: int, num_nei
"""
Receives updates, launches GIA attack when second update is seen from a neighbor
"""
print("CLIENT RECEIVING ATTACK AND AGGREGATING")
from utils.gias import gia_main

if self.is_working:
# Receive the model updates from the neighbors
model_updates = self.comm_utils.receive(node_ids=neighbors)
Expand Down
9 changes: 4 additions & 5 deletions src/algos/fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from algos.attack_bad_weights import BadWeightsAttack
from algos.attack_sign_flip import SignFlipAttack

from utils.gias import gia_main

import pickle

class FedAvgClient(BaseClient):
Expand Down Expand Up @@ -174,9 +172,10 @@ def receive_attack_and_aggregate(self, round: int, attack_start_round: int, atta

# Handle GIA-specific logic
if "gia" in self.config:
from utils.gias import gia_main

print("Server Running GIA attack")
base_params = [key for key, _ in self.model.named_parameters()]
print(base_params)

for rep in reprs:
client_id = rep["sender"]
Expand Down Expand Up @@ -216,7 +215,7 @@ def receive_and_aggregate(self):
avg_wts = self.aggregate(reprs)
self.set_representation(avg_wts)

def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1):
def single_round(self, round: int = 0, attack_start_round: int = 0, attack_end_round: int = 1):
"""
Runs the whole training procedure.
Expand All @@ -227,7 +226,7 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round
"""

# Determine if the attack should be performed
attack_in_progress = self.gia_attacker and attack_start_round <= round <= attack_end_round
attack_in_progress = hasattr(self, 'gia_attacker') and self.gia_attacker and attack_start_round <= round <= attack_end_round

if attack_in_progress:
self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round)
Expand Down
Loading

0 comments on commit 02a9d55

Please sign in to comment.