add communicator group
Signed-off-by: sagewe <wbwmat@gmail.com>
sagewe committed Nov 13, 2023
1 parent 34a1aca commit 3bf2b1e
Showing 7 changed files with 89 additions and 80 deletions.
1 change: 0 additions & 1 deletion python/fate/arch/context/_context.py
@@ -17,7 +17,6 @@
import typing
from typing import Iterable, Literal, Optional, Tuple, TypeVar, overload

from fate.arch.abc import CSessionABC
from ._cipher import CipherKit
from ._federation import Parties, Party
from ._metrics import InMemoryMetricsHandler, MetricsWrap
3 changes: 3 additions & 0 deletions python/fate/arch/protocol/mpc/communicator/__init__.py
@@ -1,8 +1,11 @@
from .communicator import Communicator
import contextlib


def get():
if not Communicator.is_initialized():
raise RuntimeError("Crypten not initialized. Please call crypten.init() first.")

return Communicator.get()
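
The guard above is a plain initialize-before-use singleton check. As a standalone sketch of the same pattern (only the module-level `get()` mirrors the diff; the class body is illustrative):

```python
# Standalone sketch of the initialize-before-use guard added above.
# Only the module-level get() mirrors the diff; the class body is illustrative.
class Communicator:
    instance = None

    @classmethod
    def is_initialized(cls):
        return cls.instance is not None

    @classmethod
    def initialize(cls):
        cls.instance = cls()

    @classmethod
    def get(cls):
        return cls.instance


def get():
    if not Communicator.is_initialized():
        raise RuntimeError("Crypten not initialized. Please call crypten.init() first.")
    return Communicator.get()


try:
    get()  # raises RuntimeError: nothing has called initialize() yet
except RuntimeError as exc:
    print(exc)
Communicator.initialize()
print(get())  # now returns the singleton instance
```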


82 changes: 54 additions & 28 deletions python/fate/arch/protocol/mpc/communicator/communicator.py
@@ -14,6 +14,27 @@
logger = logging.getLogger(__name__)


class CommunicateGroup:
def __init__(self, ranks: List[int], namespace_tensor: NS, namespace_obj: NS):
self.ranks = ranks
self.namespace_tensor = namespace_tensor
self.namespace_obj = namespace_obj

self._prev_group = None

def __str__(self):
return f"CommunicateGroup(ranks={self.ranks})"

def __enter__(self):
# replace the communicator's main group with this group
self._prev_group = Communicator.get().main_group
Communicator.get().main_group = self
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Communicator.get().main_group = self._prev_group


class Communicator:
"""
Communicator is a wrapper around the FATE communicator.
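
The `__enter__`/`__exit__` pair in `CommunicateGroup` above uses the standard save-and-restore idiom for scoping a global: stash the current `main_group`, install the new one, and put the old one back on exit, even if the body raised. A standalone sketch of the idiom with simplified stand-in names:

```python
# Standalone sketch of the save-and-restore idiom used by CommunicateGroup
# above; Comm and Group are simplified stand-ins, not the FATE classes.
class Group:
    def __init__(self, name):
        self.name = name
        self._prev_group = None

    def __enter__(self):
        # replace the communicator's main group with this group
        self._prev_group = Comm.main_group
        Comm.main_group = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # restore the previous group even when the body raised
        Comm.main_group = self._prev_group


class Comm:
    main_group = None


Comm.main_group = Group("world")
with Group("pair"):
    print(Comm.main_group.name)  # pair
print(Comm.main_group.name)      # world
```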
@@ -24,25 +45,25 @@ class Communicator:
def __init__(
self,
ctx: Context,
namespace_tensor: NS,
namespace_obj: NS,
main_group: CommunicateGroup,
rank_to_party,
rank,
world_size,
):
assert rank in rank_to_party, f"rank {rank} not in rank_to_party: {rank_to_party}"
assert len(rank_to_party) == world_size, f"rank_to_party size {len(rank_to_party)} != world_size {world_size}"
for i in range(world_size):
assert i in rank_to_party, f"rank {i} not in rank_to_party: {rank_to_party}"
self.ctx = ctx
self.namespace_tensor = namespace_tensor
self.namespace_obj = namespace_obj
self.rank_to_party = rank_to_party
self.rank = rank
self.rank_to_party = rank_to_party
self.world_size = world_size
self.main_group = None
self._tensor_send_index = -1
self._tensor_recv_index = -1
self._object_send_index = -1
self._object_recv_index = -1

self._pool = ThreadPoolExecutor(max_workers=2)
self._pool = ThreadPoolExecutor(max_workers=world_size)
self.main_group = main_group

@classmethod
def is_initialized(cls):
@@ -52,6 +73,14 @@ def is_initialized(cls):
def get(cls) -> "Communicator":
return cls.instance

def new_group(self, ranks: List[int], name: str):
assert len(ranks) > 1, f"new group must have more than 1 rank: {ranks}"
assert all([0 <= rank < self.world_size for rank in ranks]), f"invalid ranks: {ranks}"
assert len(set(ranks)) == len(ranks), f"duplicate ranks: {ranks}"
namespace_tensor = self.ctx.namespace.sub_ns(f"mpc_tensor_{name}")
namespace_obj = self.ctx.namespace.sub_ns(f"mpc_obj_{name}")
return CommunicateGroup(ranks, namespace_tensor, namespace_obj)

def _assert_initialized(self):
assert self.is_initialized(), "initialize the communicator first"
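
`new_group` validates the rank list and derives per-group namespaces from the group name, which is what keeps a subgroup's federation tags from colliding with main-group traffic. A standalone sketch of that validation and derivation, with plain strings standing in for the `NS` namespace objects:

```python
# Standalone sketch of new_group's checks and namespace derivation;
# plain strings stand in for the NS objects used in the diff.
def new_group(ranks, name, world_size):
    assert len(ranks) > 1, f"new group must have more than 1 rank: {ranks}"
    assert all(0 <= rank < world_size for rank in ranks), f"invalid ranks: {ranks}"
    assert len(set(ranks)) == len(ranks), f"duplicate ranks: {ranks}"
    # a fresh pair of namespaces per named group isolates its traffic
    return {
        "ranks": ranks,
        "namespace_tensor": f"mpc_tensor_{name}",
        "namespace_obj": f"mpc_obj_{name}",
    }


print(new_group([0, 1], "matmul", world_size=3))
```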

@@ -65,15 +94,17 @@ def get_world_size(self):

@classmethod
def initialize(cls, ctx: Context, init_ttp):
world_size = ctx.world_size
rank = ctx.local.rank
namespace_tensor = NS(ctx.namespace.sub_ns("mpc_tensor"), 0)
namespace_obj = NS(ctx.namespace.sub_ns("mpc_obj"), 0)
rank_to_party = {p.rank: p.party for p in ctx.parties}
world_size = len(rank_to_party)
namespace_tensor = ctx.namespace.sub_ns("mpc_tensor")
namespace_obj = ctx.namespace.sub_ns("mpc_obj")
main_group = CommunicateGroup(
ranks=list(range(world_size)), namespace_tensor=namespace_tensor, namespace_obj=namespace_obj
)
cls.instance = Communicator(
ctx,
namespace_tensor,
namespace_obj,
main_group,
rank_to_party,
rank,
world_size,
@@ -251,11 +282,6 @@ def get_communication_stats(self):
"time": self.comm_time,
}

def _log_communication(self, nelement):
"""Updates log of communication statistics."""
self.comm_rounds += 1
self.comm_bytes += nelement * self.BYTES_PER_ELEMENT

def _log_communication_time(self, comm_time):
self.comm_time += comm_time
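
The retained `_log_communication_time` is the accumulate-a-counter half of the stats API, fed by timing each blocking transfer. A minimal illustration of that accounting, assuming the caller wraps the call with a monotonic clock:

```python
import time

# Illustration only: accumulate wall-clock time around a blocking transfer,
# the way send/recv wrappers would feed _log_communication_time.
class Stats:
    comm_time = 0.0

    def _log_communication_time(self, comm_time):
        self.comm_time += comm_time


stats = Stats()
start = time.monotonic()
time.sleep(0.01)  # stands in for a blocking send or recv
stats._log_communication_time(time.monotonic() - start)
print(f"total communication time: {stats.comm_time:.3f}s")
```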

@@ -274,35 +300,35 @@ def _get_parties_by_ranks(self, ranks: List[int], namespace: NS):
return self._get_parties([self.rank_to_party[rank] for rank in ranks], namespace)

def _send(self, index, tensor, dst):
parties = self._get_parties_by_rank(dst, self.namespace_tensor)
parties = self._get_parties_by_rank(dst, self.main_group.namespace_tensor)
logger.debug(f"[{self.ctx.local}]sending, index={index}, dst={dst}, parties={parties}")
parties.put(self.namespace_tensor.indexed_ns(index).federation_tag, tensor)
parties.put(self.main_group.namespace_tensor.indexed_ns(index).federation_tag, tensor)

def _send_obj(self, index, obj, dst):
parties = self._get_parties_by_rank(dst, self.namespace_obj)
parties = self._get_parties_by_rank(dst, self.main_group.namespace_obj)
logger.debug(f"[{self.ctx.local}]sending obj, index={index}, dst={dst}, parties={parties}")
parties.put(self.namespace_obj.indexed_ns(index).federation_tag, obj)
parties.put(self.main_group.namespace_obj.indexed_ns(index).federation_tag, obj)

def _recv(self, index, tensor, src):
parties = self._get_parties_by_rank(src, self.namespace_tensor)
parties = self._get_parties_by_rank(src, self.main_group.namespace_tensor)
logger.debug(f"[{self.ctx.local}]receiving, index={index}, src={src}, parties={parties}")
got_tensor = parties.get(self.namespace_tensor.indexed_ns(index).federation_tag)[0]
got_tensor = parties.get(self.main_group.namespace_tensor.indexed_ns(index).federation_tag)[0]
if tensor is None:
return got_tensor
else:
tensor.copy_(got_tensor)
return tensor

def _recv_obj(self, index, src):
parties = self._get_parties_by_rank(src, self.namespace_obj)
parties = self._get_parties_by_rank(src, self.main_group.namespace_obj)
logger.debug(f"[{self.ctx.local}]receiving, index={index}, src={src}, parties={parties}")
got_obj = parties.get(self.namespace_obj.indexed_ns(index).federation_tag)[0]
got_obj = parties.get(self.main_group.namespace_obj.indexed_ns(index).federation_tag)[0]
return got_obj

def _send_many(self, index, tensor, dst_list):
parties = self._get_parties_by_ranks(dst_list, self.namespace_tensor)
parties = self._get_parties_by_ranks(dst_list, self.main_group.namespace_tensor)
logger.debug(f"[{self.ctx.local}]sending, index={index}, dst={dst_list}, parties={parties}")
parties.put(self.namespace_tensor.indexed_ns(index).federation_tag, tensor)
parties.put(self.main_group.namespace_tensor.indexed_ns(index).federation_tag, tensor)
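
Every `_send*`/`_recv*` variant above follows the same discipline: resolve the peer under the active group's namespace, then `put`/`get` on a tag derived from a per-direction counter. A standalone sketch of that tag discipline, with an in-memory mailbox standing in for the FATE federation layer:

```python
# Standalone sketch of index-tagged point-to-point messaging as used by
# _send/_recv above; a dict mailbox stands in for the federation layer.
mailbox = {}


def federation_tag(namespace, index):
    # mirrors namespace.indexed_ns(index).federation_tag in spirit
    return f"{namespace}.{index}"


def send(namespace, index, dst, payload):
    mailbox[(federation_tag(namespace, index), dst)] = payload


def recv(namespace, index, rank):
    return mailbox.pop((federation_tag(namespace, index), rank))


# rank 0 sends under the group's tensor namespace; rank 1 receives.
# Matching (namespace, index) on both sides is what pairs the transfer.
send("mpc_tensor_matmul", index=0, dst=1, payload=[1.0, 2.0])
print(recv("mpc_tensor_matmul", index=0, rank=1))  # [1.0, 2.0]
```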


class WaitableFuture:
3 changes: 2 additions & 1 deletion python/fate/arch/protocol/mpc/sshe.py
@@ -71,7 +71,8 @@ def cross_smm(
else:
raise ValueError(f"invalid rank: {ctx.rank}")

z.encoder = FixedPointEncoder(z.encoder._precision_bits + encoder._precision_bits)
with IgnoreEncodings([z]):
z.div_(encoder.scale)
return z

@classmethod
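The `cross_smm` change above swaps re-tagging `z` with a doubled-precision encoder for dividing one scale factor out under `IgnoreEncodings`: multiplying two values encoded at scale B yields a product at scale B², so a single division by B restores single precision. The arithmetic, with an assumed 16-bit scale:

```python
# Standalone arithmetic behind the cross_smm change: a product of two
# fixed-point values carries scale B*B, so one factor of B is divided out.
# B = 2**16 is an assumed precision, not taken from the diff.
B = 2 ** 16
x, y = 1.5, 2.25
xq, yq = round(x * B), round(y * B)  # encode at scale B
prod = xq * yq                       # scale is now B * B
prod //= B                           # what z.div_(encoder.scale) accomplishes
print(prod / B)                      # decode: 3.375
```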
8 changes: 7 additions & 1 deletion python/fate/arch/tensor/distributed/_op_broadcast.py
@@ -8,5 +8,11 @@

@implements(torch.broadcast_tensors)
def broadcast_tensors(*input: DTensor):
logger.warning("broadcast_tensors is not implemented")
for t in input:
if not isinstance(t, DTensor):
raise TypeError("broadcast_tensors expects all inputs to be DTensors")
shapes = input[0].shardings.shapes
for t in input[1:]:
if t.shardings.shapes != shapes:
raise RuntimeError("broadcast_tensors expects all inputs to be of the same shape")
return input
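
The `@implements(torch.broadcast_tensors)` decorator in this file is the usual `__torch_function__` override-table pattern. A simplified, self-contained sketch of the mechanism, where a thin `Wrapped` class stands in for `DTensor` and sharding is ignored:

```python
import torch

HANDLED = {}


def implements(torch_fn):
    # register a handler for a torch function, as the diff's decorator does
    def register(fn):
        HANDLED[torch_fn] = fn
        return fn
    return register


class Wrapped:
    def __init__(self, t):
        self.t = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func in HANDLED:
            return HANDLED[func](*args, **(kwargs or {}))
        return NotImplemented


@implements(torch.broadcast_tensors)
def broadcast_tensors(*input):
    # mirror the diff: require identical shapes, return the inputs untouched
    shape = input[0].t.shape
    if any(t.t.shape != shape for t in input[1:]):
        raise RuntimeError("broadcast_tensors expects all inputs to be of the same shape")
    return input


a, b = Wrapped(torch.zeros(2, 3)), Wrapped(torch.ones(2, 3))
print([o.t.shape for o in torch.broadcast_tensors(a, b)])
```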
2 changes: 2 additions & 0 deletions python/fate/ml/mpc/sshe_nn.py
@@ -51,6 +51,8 @@ def validate_run(self, ctx: Context):
ctx.mpc.info(f"ha.grad={h.grad}", dst=[0])
ctx.mpc.info(f"hb.grad={h.grad}", dst=[1])

import time
time.sleep(3)
ctx.mpc.info(f"==================ground truth==================")
ha = torch.rand(num_samples, in_features_a, requires_grad=True, generator=torch.Generator().manual_seed(0))
hb = torch.rand(num_samples, in_features_b, requires_grad=True, generator=torch.Generator().manual_seed(1))
70 changes: 21 additions & 49 deletions python/fate/ml/mpc/toy.py
@@ -18,37 +18,7 @@ def __init__(
...

def fit(self, ctx: Context) -> None:
h = ctx.mpc.cond_call(
lambda: torch.rand(10, 4, requires_grad=True, generator=torch.Generator().manual_seed(0)),
lambda: torch.rand(10, 5, requires_grad=True, generator=torch.Generator().manual_seed(1)),
dst=0,
)
ctx.mpc.info(h, dst=[0, 1])
generator = torch.Generator().manual_seed(0)
lr = 0.05
layer = SSHEAggregatorLayer(
ctx, in_features_a=4, in_features_b=5, out_features=2, rank_a=0, rank_b=1, lr=lr, generator=generator
)
ctx.mpc.info(f"wa={layer.get_wa()}")
ctx.mpc.info(f"wb={layer.get_wb()}")
z = layer(h)
ctx.mpc.info(z)
z.sum().backward()
ctx.mpc.info(f"wa={layer.get_wa()}")
ctx.mpc.info(f"wb={layer.get_wb()}")

# validate
h1 = torch.rand(10, 4, requires_grad=True, generator=torch.Generator().manual_seed(0))
h2 = torch.rand(10, 5, requires_grad=True, generator=torch.Generator().manual_seed(1))
w1 = torch.rand(4, 2, requires_grad=True, generator=torch.Generator().manual_seed(0))
w2 = torch.rand(5, 2, requires_grad=True, generator=torch.Generator().manual_seed(0))
z = torch.matmul(h1, w1) + torch.matmul(h2, w2)
z.sum().backward()
w1 = w1 - lr * w1.grad
w2 = w2 - lr * w2.grad
ctx.mpc.info((h1, h2, w1, w2), dst=0)

# self.fit_mul(ctx)
self.fit_matmul(ctx)
# x = _get_left_tensor(ctx, 0)
# alice = ctx.mpc.cond_call(lambda: x, lambda: _get_left_tensor(ctx, is_zero=True), dst=0)
# logger.info(torch.to_local_f(x))
@@ -84,25 +54,27 @@ def fit_mul(self, ctx: Context):
ctx.mpc.info(f"mul={torch.to_local_f(out)}")

def fit_matmul(self, ctx: Context):
x = _get_left_tensor(ctx, 0)
y = _get_right_tensor(ctx, 1)
expect = torch.matmul(x, y)
logger.info(f"expect={torch.to_local_f(expect)}")

x_alice = ctx.mpc.cond_call(
lambda: _get_left_tensor(ctx, 0), lambda: _get_left_tensor(ctx, is_zero=True), dst=0
)
ctx.mpc.info(torch.to_local_f(x_alice), dst=0)
x_alice_enc = ctx.mpc.cryptensor(x_alice, src=0)

x_bob = ctx.mpc.cond_call(
lambda: _get_right_tensor(ctx, 1), lambda: _get_right_tensor(ctx, is_zero=True), dst=1
)
ctx.mpc.info(torch.to_local_f(x_bob), dst=1)
x_bob_enc = ctx.mpc.cryptensor(x_bob, src=1)

out = x_alice_enc.matmul(x_bob_enc).get_plain_text()
ctx.mpc.info(f"matmul={torch.to_local_f(out)}")
with ctx.mpc.communicator.new_group(ranks=[0, 1], name="matmul"):
x = _get_left_tensor(ctx, 0)
y = _get_right_tensor(ctx, 1)
expect = torch.matmul(x, y)
logger.info(f"expect={torch.to_local_f(expect)}")

x_alice = ctx.mpc.cond_call(
lambda: _get_left_tensor(ctx, 0), lambda: _get_left_tensor(ctx, is_zero=True), dst=0
)
ctx.mpc.info(torch.to_local_f(x_alice), dst=0)
x_alice_enc = ctx.mpc.encrypt(x_alice, src=0)

x_bob = ctx.mpc.cond_call(
lambda: _get_right_tensor(ctx, 1), lambda: _get_right_tensor(ctx, is_zero=True), dst=1
)
ctx.mpc.info(torch.to_local_f(x_bob), dst=1)
x_bob_enc = ctx.mpc.encrypt(x_bob, src=1)

out = x_alice_enc.matmul(x_bob_enc).get_plain_text()
ctx.mpc.info(f"matmul={torch.to_local_f(out)}")


def _get_left_tensor(ctx, seed=None, is_zero=False):
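`cond_call(fn, other_fn, dst=r)` as used throughout `fit_matmul` runs the first thunk on rank `r` and the second everywhere else, so each party materializes only its own plaintext while the zero-filled stand-in keeps shapes consistent for secret sharing. A standalone sketch of that dispatch, with the rank argument made explicit for illustration:

```python
import torch

# Standalone sketch of the cond_call dispatch used in fit_matmul above;
# rank is passed explicitly here, whereas FATE reads it from the context.
def cond_call(rank, fn, other_fn, dst):
    return fn() if rank == dst else other_fn()


def left_tensor(seed=0, is_zero=False):
    t = torch.rand(10, 4, generator=torch.Generator().manual_seed(seed))
    return torch.zeros_like(t) if is_zero else t


# rank 0 holds the real left matrix; other ranks hold zeros of the same
# shape, so downstream secret sharing sees a consistent layout everywhere
for rank in (0, 1):
    x = cond_call(rank, lambda: left_tensor(0), lambda: left_tensor(is_zero=True), dst=0)
    print(rank, float(x.abs().sum()))
```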
