Refactor and Extend Logging, Error Handling, and Context Passing
This commit makes a set of changes aimed primarily at improving logging, error handling, and context passing across various modules. Key changes include:

Introduced logging in the MPC and Toy classes to make code execution easier to trace and debug.
Extended error handling to produce more informative error messages and improve robustness.
Refactored context passing in arithmetic operations for consistency and clarity.
Removed DistributedArithmeticSharedTensor, as it was no longer needed.
Added new utility functions for more efficient tensor operations and random number generation, improving code modularity and readability.
Improved handling of tensor operations in DTensor to better support in-place operations and tensor metadata extraction.
Together these changes improve maintainability, ease debugging, and strengthen error handling across the codebase, particularly in the MPC module and in arithmetic operations.

Signed-off-by: sagewe <wbwmat@gmail.com>
sagewe committed Nov 1, 2023
1 parent badc710 commit ea0ca00
Showing 15 changed files with 195 additions and 721 deletions.
41 changes: 41 additions & 0 deletions python/fate/arch/context/_mpc.py
@@ -7,9 +7,11 @@
from fate.arch.protocol.mpc.communicator import Communicator
from fate.arch.tensor import mpc
from fate.arch.tensor.mpc.cryptensor import CrypTensor
+import logging

if typing.TYPE_CHECKING:
from fate.arch.context import Context
+logger = logging.getLogger(__name__)


class MPC:
@@ -64,6 +66,45 @@ def is_encrypted_tensor(cls, obj):
"""
return isinstance(obj, CrypTensor)

+    def print(self, message, dst=[0], print_func=None):
+        if print_func is None:
+            print_func = print
+        if self.rank in dst:
+            print_func(message)
+
+    def info(self, message, dst=[0]):
+        if isinstance(dst, int):
+            dst = [dst]
+        if self.rank in dst:
+            logger.info(msg=message, stacklevel=2)
+
+    def debug(self, message, dst=[0]):
+        if isinstance(dst, int):
+            dst = [dst]
+        if self.rank in dst:
+            logger.debug(msg=message, stacklevel=2)
+
+    def warning(self, message, dst=[0]):
+        if isinstance(dst, int):
+            dst = [dst]
+        if self.rank in dst:
+            logger.warning(msg=message, stacklevel=2)
+
+    def error(self, message, dst=[0]):
+        if isinstance(dst, int):
+            dst = [dst]
+        if self.rank in dst:
+            logger.error(msg=message, stacklevel=2)
+
+    def cond_call(self, func1, func2=None, dst=0):
+        """
+        Calls func1 if rank == dst, otherwise calls func2.
+        """
+        if self.rank == dst:
+            return func1()
+        else:
+            return func2() if func2 is not None else None


def ttp_required():
from fate.arch.tensor.mpc.config import cfg
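Aside (not part of the commit): a minimal, self-contained sketch of the rank-gating behavior the new helpers implement. `ToyMPC` is a hypothetical stand-in for the `MPC` class; only the `rank` attribute matters here.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ToyMPC:
    """Hypothetical stand-in for the MPC helper above."""

    def __init__(self, rank):
        self.rank = rank

    def info(self, message, dst=0):
        # Mirrors the diff: normalize an int dst to a list, then gate on rank.
        if isinstance(dst, int):
            dst = [dst]
        if self.rank in dst:
            logger.info(message)

    def cond_call(self, func1, func2=None, dst=0):
        # Mirrors the diff: func1 runs on the designated rank, func2 elsewhere.
        if self.rank == dst:
            return func1()
        return func2() if func2 is not None else None


# Simulate two parties: only rank 0 emits the log record.
for rank in (0, 1):
    party = ToyMPC(rank)
    party.info(f"hello from rank {rank}")  # logged once, by rank 0
    role = party.cond_call(lambda: "leader", lambda: "helper")
    print(rank, role)  # -> 0 leader, then 1 helper
```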
1 change: 1 addition & 0 deletions python/fate/arch/protocol/mpc/__init__.py
@@ -26,6 +26,7 @@
"primitives",
"provider",
"ptype",
"generators"
]

# the different private type attributes of an mpc encrypted tensor
7 changes: 2 additions & 5 deletions python/fate/arch/protocol/mpc/mpc.py
@@ -56,12 +56,9 @@ def __init__(self, ctx, tensor, ptype=Ptype.arithmetic, device=None, *args, **kw
# create the MPCTensor:
if tensor is []:
self._tensor = torch.tensor([], device=device)
-        elif isinstance(tensor, DTensor):
-            tensor_type = ptype.to_tensor(distributed=True)
-            self._tensor = tensor_type(ctx=ctx, tensor=tensor, device=device, *args, **kwargs)
else:
tensor_type = ptype.to_tensor()
-            self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs)
+            self._tensor = tensor_type(ctx=ctx, tensor=tensor, device=device, *args, **kwargs)
self.ptype = ptype
self.ctx = ctx

@@ -83,7 +80,7 @@ def from_shares(share, precision=None, ptype=Ptype.arithmetic):
def clone(self):
"""Create a deep copy of the input tensor."""
# TODO: Rename this to __deepcopy__()?
-        result = MPCTensor([])
+        result = MPCTensor(self.ctx, [])
result._tensor = self._tensor.clone()
result.ptype = self.ptype
return result
3 changes: 1 addition & 2 deletions python/fate/arch/protocol/mpc/primitives/__init__.py
@@ -7,7 +7,6 @@

from .arithmetic import ArithmeticSharedTensor
from .binary import BinarySharedTensor
-from .distributed_arithmetic import DistributedArithmeticSharedTensor


__all__ = ["ArithmeticSharedTensor", "BinarySharedTensor", "DistributedArithmeticSharedTensor"]
__all__ = ["ArithmeticSharedTensor", "BinarySharedTensor"]
26 changes: 16 additions & 10 deletions python/fate/arch/protocol/mpc/primitives/arithmetic.py
@@ -18,6 +18,7 @@
from fate.arch.tensor.mpc.encoder import FixedPointEncoder
from fate.arch.tensor.mpc.functions import regular
from . import beaver, replicated # noqa: F401
+from fate.arch.context import Context

SENTINEL = -1

@@ -35,6 +36,7 @@ class ArithmeticSharedTensor(object):
# constructors:
def __init__(
self,
+        ctx,
tensor=None,
size=None,
broadcast_size=False,
@@ -60,6 +62,8 @@ def __init__(
the tensor. If `device` is unspecified, it is set to `tensor.device`.
"""

+        assert isinstance(ctx, Context), "ctx must be a Context object"
+        self._ctx = ctx
# do nothing if source is sentinel:
if src == SENTINEL:
return
@@ -93,10 +97,11 @@ def __init__(
size = comm.get().broadcast_obj(size, src)

# generate pseudo-random zero sharing (PRZS) and add source's tensor:
-        self.share = ArithmeticSharedTensor.PRZS(size, device=device).share
+        self.share = ArithmeticSharedTensor.PRZS(ctx, size, device=device).share
if self.rank == src:
self.share += tensor


@staticmethod
def new(*args, **kwargs):
"""
@@ -149,7 +154,7 @@ def from_shares(share, precision=None, device=None):
return result

@staticmethod
-    def PRZS(*size, device=None):
+    def PRZS(ctx, *size, device=None):
"""
Generate a Pseudo-random Sharing of Zero (using arithmetic shares)
@@ -159,24 +164,24 @@
"""
from fate.arch.protocol.mpc import generators

-        tensor = ArithmeticSharedTensor(src=SENTINEL)
+        tensor = ArithmeticSharedTensor(ctx, src=SENTINEL)
if device is None:
device = torch.device("cpu")
elif isinstance(device, str):
device = torch.device(device)
g0 = generators["prev"][device]
g1 = generators["next"][device]
-        current_share = generate_random_ring_element(*size, generator=g0, device=device)
-        next_share = generate_random_ring_element(*size, generator=g1, device=device)
+        current_share = generate_random_ring_element(ctx, *size, generator=g0, device=device)
+        next_share = generate_random_ring_element(ctx, *size, generator=g1, device=device)
tensor.share = current_share - next_share
return tensor

@staticmethod
-    def PRSS(*size, device=None):
+    def PRSS(ctx, *size, device=None):
"""
Generates a Pseudo-random Secret Share from a set of random arithmetic shares
"""
-        share = generate_random_ring_element(*size, device=device)
+        share = generate_random_ring_element(ctx, *size, device=device)
tensor = ArithmeticSharedTensor.from_shares(share=share)
return tensor

@@ -186,13 +191,13 @@ def rank(self):

def shallow_copy(self):
"""Create a shallow copy"""
-        result = ArithmeticSharedTensor(src=SENTINEL)
+        result = ArithmeticSharedTensor(ctx=self._ctx, src=SENTINEL)
result.encoder = self.encoder
result._tensor = self._tensor
return result

def clone(self):
-        result = ArithmeticSharedTensor(src=SENTINEL)
+        result = ArithmeticSharedTensor(ctx=self._ctx, src=SENTINEL)
result.encoder = self.encoder
result._tensor = self._tensor.clone()
return result
@@ -356,7 +361,8 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):  # noqa:C
result.share = getattr(result.share, op)(y.share)
else: # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
protocol = globals()[cfg.mpc.protocol]
-            result.share.set_(getattr(protocol, op)(result, y, *args, **kwargs).share.data)
+            tmp = getattr(protocol, op)(self._ctx, result, y, *args, **kwargs)
+            result.share = tmp.share
else:
raise TypeError("Cannot %s %s with %s" % (op, type(y), type(self)))

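Aside (not part of the commit): the PRZS construction above works because each pair of ring-adjacent parties shares a seeded generator, so party i's `g0` ("prev") sample equals party i-1's `g1` ("next") sample, and the shares telescope to zero. A minimal, self-contained sketch with plain torch generators; the party count, seeds, and tensor size below are made up for illustration.

```python
import torch

# Hypothetical 3-party ring: seed i is shared between party i and party
# (i + 1) % n, standing in for the pairwise-seeded generators["prev"] /
# generators["next"] used by the real code.
n_parties = 3
pair_seeds = [101, 202, 303]


def przs_share(rank, size):
    """Party `rank`'s zero share: a sample drawn with the seed shared with the
    previous party, minus a sample drawn with the seed shared with the next."""
    g_prev = torch.Generator().manual_seed(pair_seeds[(rank - 1) % n_parties])
    g_next = torch.Generator().manual_seed(pair_seeds[rank])
    prev_sample = torch.randint(-2**31, 2**31, size, generator=g_prev, dtype=torch.int64)
    next_sample = torch.randint(-2**31, 2**31, size, generator=g_next, dtype=torch.int64)
    return prev_sample - next_sample


# Each share alone looks uniformly random, but summed over the ring the
# samples telescope: (s2 - s0) + (s0 - s1) + (s1 - s2) == 0 elementwise.
shares = [przs_share(r, (2, 2)) for r in range(n_parties)]
print(torch.stack(shares).sum(dim=0))  # tensor([[0, 0], [0, 0]])
```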
33 changes: 16 additions & 17 deletions python/fate/arch/protocol/mpc/primitives/beaver.py
@@ -29,7 +29,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
tensor.encoder._scale = self.encodings_cache[i]


-def __beaver_protocol(op, x, y, *args, **kwargs):
+def __beaver_protocol(ctx, op, x, y, *args, **kwargs):
"""Performs Beaver protocol for additively secret-shared tensors x and y
1. Obtain uniformly random sharings [a],[b] and [c] = [a * b]
@@ -49,7 +49,7 @@ def __beaver_protocol(op, x, y, *args, **kwargs):
raise ValueError(f"x lives on device {x.device} but y on device {y.device}")

provider = mpc.get_default_provider()
-    a, b, c = provider.generate_additive_triple(x.size(), y.size(), op, device=x.device, *args, **kwargs)
+    a, b, c = provider.generate_additive_triple(ctx, x.size(), y.size(), op, device=x.device, *args, **kwargs)

from .arithmetic import ArithmeticSharedTensor

@@ -58,9 +58,9 @@
Reference: "Multiparty Computation from Somewhat Homomorphic Encryption"
Link: https://eprint.iacr.org/2011/535.pdf
"""
-        f, g, h = provider.generate_additive_triple(x.size(), y.size(), op, device=x.device, *args, **kwargs)
+        f, g, h = provider.generate_additive_triple(ctx, x.size(), y.size(), op, device=x.device, *args, **kwargs)

-        t = ArithmeticSharedTensor.PRSS(a.size(), device=x.device)
+        t = ArithmeticSharedTensor.PRSS(ctx, a.size(), device=x.device)
t_plain_text = t.get_plain_text()

rho = (t_plain_text * a - f).get_plain_text()
@@ -83,29 +83,28 @@ def __beaver_protocol(op, x, y, *args, **kwargs):
return c


-def mul(x, y):
-    raise NotImplementedError("mul not implemented.")
-    return __beaver_protocol("mul", x, y)
+def mul(ctx, x, y):
+    return __beaver_protocol(ctx, "mul", x, y)


-def matmul(x, y):
-    return __beaver_protocol("matmul", x, y)
+def matmul(ctx, x, y):
+    return __beaver_protocol(ctx, "matmul", x, y)


-def conv1d(x, y, **kwargs):
-    return __beaver_protocol("conv1d", x, y, **kwargs)
+def conv1d(ctx, x, y, **kwargs):
+    return __beaver_protocol(ctx, "conv1d", x, y, **kwargs)


-def conv2d(x, y, **kwargs):
-    return __beaver_protocol("conv2d", x, y, **kwargs)
+def conv2d(ctx, x, y, **kwargs):
+    return __beaver_protocol(ctx, "conv2d", x, y, **kwargs)


-def conv_transpose1d(x, y, **kwargs):
-    return __beaver_protocol("conv_transpose1d", x, y, **kwargs)
+def conv_transpose1d(ctx, x, y, **kwargs):
+    return __beaver_protocol(ctx, "conv_transpose1d", x, y, **kwargs)


-def conv_transpose2d(x, y, **kwargs):
-    return __beaver_protocol("conv_transpose2d", x, y, **kwargs)
+def conv_transpose2d(ctx, x, y, **kwargs):
+    return __beaver_protocol(ctx, "conv_transpose2d", x, y, **kwargs)


def square(x):
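Aside (not part of the commit): the Beaver protocol these wrappers dispatch to reduces a secret-shared multiplication to one round of openings. Given a dealer-supplied triple c = a·b, the parties open ε = x − a and δ = y − b, then locally compute [x·y] = [c] + ε·[b] + δ·[a] + ε·δ, which is correct because c + εb + δa + εδ = xy. A toy two-party sketch over plain integers (every name here is illustrative, not fate's API; the real code works in a fixed-point ring):

```python
import random


def share(v):
    """Split v into two additive shares that sum to v."""
    r = random.randint(-10**6, 10**6)
    return (r, v - r)


def reveal(shares):
    return sum(shares)


x, y = 7, 12
x_sh, y_sh = share(x), share(y)

# A dealer (the "provider" in the diff) hands out a Beaver triple c = a * b.
a, b = random.randint(-100, 100), random.randint(-100, 100)
a_sh, b_sh, c_sh = share(a), share(b), share(a * b)

# The only communication: open epsilon = x - a and delta = y - b.
eps = reveal([x_sh[i] - a_sh[i] for i in range(2)])
delta = reveal([y_sh[i] - b_sh[i] for i in range(2)])

# Local recombination: [xy] = [c] + eps*[b] + delta*[a], plus the public
# term eps*delta added by exactly one party.
xy_sh = [c_sh[i] + eps * b_sh[i] + delta * a_sh[i] for i in range(2)]
xy_sh[0] += eps * delta

assert reveal(xy_sh) == x * y
print(reveal(xy_sh))  # 84
```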