diff --git a/python/fate/arch/protocol/_dh.py b/python/fate/arch/protocol/_dh.py index b834e1c907..51391ea69c 100644 --- a/python/fate/arch/protocol/_dh.py +++ b/python/fate/arch/protocol/_dh.py @@ -54,11 +54,14 @@ def __init__(self, ranks, prefix: typing.Optional[str] = None): self.prefix = prefix self.ranks = ranks - def secure_aggregate(self, ctx: Context): + def secure_aggregate(self, ctx: Context, ranks: typing.Optional[int] = None): mix_aggregator = MixAggregate() aggregated_weight = 0.0 has_weight = False - for rank in self.ranks: + + if ranks is None: + ranks = self.ranks + for rank in ranks: mix_arrays, weight = ctx.parties[rank].get(self._get_name(self._send_name)) mix_aggregator.aggregate(mix_arrays) if weight is not None: @@ -67,5 +70,5 @@ def secure_aggregate(self, ctx: Context): if not has_weight: aggregated_weight = None aggregated = mix_aggregator.finalize(aggregated_weight) - for rank in self.ranks: + for rank in ranks: ctx.parties[rank].put(self._get_name(self._recv_name), aggregated)