Skip to content

Commit

Permalink
add optional ranks for dh aggregator
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jul 21, 2023
1 parent 9eeb6c5 commit aa8fd8a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/fate/arch/protocol/_dh.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ def __init__(self, ranks, prefix: typing.Optional[str] = None):
self.prefix = prefix
self.ranks = ranks

def secure_aggregate(self, ctx: Context):
def secure_aggregate(self, ctx: Context, ranks: typing.Optional[int] = None):
mix_aggregator = MixAggregate()
aggregated_weight = 0.0
has_weight = False
for rank in self.ranks:

if ranks is None:
ranks = self.ranks
for rank in ranks:
mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name))
mix_aggregator.aggregate(mix_arrays)
if weight is not None:
Expand All @@ -67,5 +70,5 @@ def secure_aggregate(self, ctx: Context):
if not has_weight:
aggregated_weight = None
aggregated = mix_aggregator.finalize(aggregated_weight)
for rank in self.ranks:
for rank in ranks:
ctx.parties[rank].put(self._get_name(self._recv_name), aggregated)

0 comments on commit aa8fd8a

Please sign in to comment.