From 4b01c3c37792a3b8e6de15521725fe7f6e7a1a8a Mon Sep 17 00:00:00 2001 From: weiwee Date: Thu, 22 Sep 2022 03:06:15 -0800 Subject: [PATCH] feat: add basic tensor and context impl Signed-off-by: weiwee --- python/fate/arch/context/__init__.py | 3 + python/fate/arch/context/_context.py | 241 +++++++++ python/fate/arch/context/_federation.py | 270 ++++++++++ python/fate/arch/context/_namespace.py | 29 + python/fate/arch/federation/_parties.py | 84 +++ python/fate/arch/tensor/__init__.py | 20 + python/fate/arch/tensor/_dataloader.py | 55 ++ python/fate/arch/tensor/_federation.py | 7 + python/fate/arch/tensor/_parties.py | 84 +++ python/fate/arch/tensor/_tensor.py | 508 ++++++++++++++++++ python/fate/arch/tensor/abc/block.py | 119 ++++ python/fate/arch/tensor/abc/tensor.py | 118 ++++ python/fate/arch/tensor/functional.py | 5 + python/fate/arch/tensor/impl/__init__.py | 0 .../arch/tensor/impl/blocks/_metaclass.py | 365 +++++++++++++ .../tensor/impl/blocks/cpu_paillier_block.py | 50 ++ .../blocks/multithread_cpu_paillier_block.py | 50 ++ .../blocks/python_paillier_block/__init__.py | 7 + .../python_paillier_block/_fate_paillier.py | 350 ++++++++++++ .../python_paillier_block/_fixedpoint.py | 384 +++++++++++++ .../python_paillier_block/_gmpy_math.py | 134 +++++ .../_python_paillier_block.py | 206 +++++++ .../arch/tensor/impl/tensor/_metaclass.py | 169 ++++++ .../arch/tensor/impl/tensor/distributed.py | 280 ++++++++++ .../impl/tensor/multithread_cpu_tensor.py | 38 ++ .../tensor/impl/tensor/row_distributed.py | 236 ++++++++ python/fate/arch/tensor/ops/__init__.py | 2 + 27 files changed, 3814 insertions(+) create mode 100644 python/fate/arch/context/__init__.py create mode 100644 python/fate/arch/context/_context.py create mode 100644 python/fate/arch/context/_federation.py create mode 100644 python/fate/arch/context/_namespace.py create mode 100644 python/fate/arch/federation/_parties.py create mode 100644 python/fate/arch/tensor/__init__.py create mode 100644 
python/fate/arch/tensor/_dataloader.py create mode 100644 python/fate/arch/tensor/_federation.py create mode 100644 python/fate/arch/tensor/_parties.py create mode 100644 python/fate/arch/tensor/_tensor.py create mode 100644 python/fate/arch/tensor/abc/block.py create mode 100644 python/fate/arch/tensor/abc/tensor.py create mode 100644 python/fate/arch/tensor/functional.py create mode 100644 python/fate/arch/tensor/impl/__init__.py create mode 100644 python/fate/arch/tensor/impl/blocks/_metaclass.py create mode 100644 python/fate/arch/tensor/impl/blocks/cpu_paillier_block.py create mode 100644 python/fate/arch/tensor/impl/blocks/multithread_cpu_paillier_block.py create mode 100644 python/fate/arch/tensor/impl/blocks/python_paillier_block/__init__.py create mode 100644 python/fate/arch/tensor/impl/blocks/python_paillier_block/_fate_paillier.py create mode 100644 python/fate/arch/tensor/impl/blocks/python_paillier_block/_fixedpoint.py create mode 100644 python/fate/arch/tensor/impl/blocks/python_paillier_block/_gmpy_math.py create mode 100644 python/fate/arch/tensor/impl/blocks/python_paillier_block/_python_paillier_block.py create mode 100644 python/fate/arch/tensor/impl/tensor/_metaclass.py create mode 100644 python/fate/arch/tensor/impl/tensor/distributed.py create mode 100644 python/fate/arch/tensor/impl/tensor/multithread_cpu_tensor.py create mode 100644 python/fate/arch/tensor/impl/tensor/row_distributed.py create mode 100644 python/fate/arch/tensor/ops/__init__.py diff --git a/python/fate/arch/context/__init__.py b/python/fate/arch/context/__init__.py new file mode 100644 index 0000000000..e3784c80f8 --- /dev/null +++ b/python/fate/arch/context/__init__.py @@ -0,0 +1,3 @@ +from ._context import Context, Metric, MetricMeta, Namespace + +__all__ = ["Context", "Namespace", "MetricMeta", "Metric"] diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py new file mode 100644 index 0000000000..fce1b79ca3 --- /dev/null +++ 
b/python/fate/arch/context/_context.py @@ -0,0 +1,241 @@ +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from logging import Logger, disable, getLogger +from typing import List, Literal, Optional, Tuple, Iterator + +from fate.interface import LOGMSG, Anonymous, Cache, CheckpointManager +from fate.interface import Context as ContextInterface +from fate.interface import Logger as LoggerInterface +from fate.interface import Metric as MetricInterface +from fate.interface import MetricMeta as MetricMetaInterface +from fate.interface import Metrics, Summary +from fate.interface import ComputingEngine +from ..session import Session + +from ._federation import GC, FederationEngine +from ._namespace import Namespace +from ..common._parties import PartiesInfo, Party + + +@dataclass +class Metric(MetricInterface): + key: str + value: float + timestamp: Optional[float] = None + + +class MetricMeta(MetricMetaInterface): + def __init__(self, name: str, metric_type: str, extra_metas: Optional[dict] = None): + self.name = name + self.metric_type = metric_type + self.metas = {} + self.extra_metas = extra_metas + + def update_metas(self, metas: dict): + self.metas.update(metas) + + +class DummySummary(Summary): + """ + dummy summary save nowhre + """ + + def __init__(self) -> None: + self._summary = {} + + @property + def summary(self): + return self._summary + + def save(self): + pass + + def reset(self, summary: dict): + self._summary = summary + + def add(self, key: str, value): + self._summary[key] = value + + +class DummyMetrics(Metrics): + def __init__(self) -> None: + self._data = [] + self._meta = [] + + def log(self, name: str, namespace: str, data: List[Metric]): + self._data.append((name, namespace, data)) + + def log_meta(self, name: str, namespace: str, meta: MetricMeta): + self._meta.append((name, namespace, meta)) + + def log_warmstart_init_iter(self, iter_num): # FIXME: strange here + ... 
+ + +class DummyCache(Cache): + def __init__(self) -> None: + self.cache = [] + + def add_cache(self, key, value): + self.cache.append((key, value)) + + +# FIXME: vary complex to use, may take times to fix +class DummyAnonymous(Anonymous): + ... + + +class DummyCheckpointManager(CheckpointManager): + ... + + +class DummyLogger(LoggerInterface): + def __init__( + self, + context_name: Optional[str] = None, + namespace: Optional[Namespace] = None, + level=logging.DEBUG, + disable_buildin=True, + ) -> None: + if disable_buildin: + self._disable_buildin() + + self.logger = getLogger("fate.dummy") + self.namespace = namespace + self.context_name = context_name + + self.logger.setLevel(level) + + formats = [] + if self.context_name is not None: + formats.append("%(context_name)s") + if self.namespace is not None: + formats.append("%(namespace)s") + formats.append("%(pathname)s:%(lineno)s - %(levelname)s - %(message)s") + formatter = logging.Formatter(" - ".join(formats)) + + # console + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + @classmethod + def _disable_buildin(cls): + from ..common.log import getLogger + + logger = getLogger() + logger.disabled = True + + def log(self, level: int, msg: LOGMSG): + if Logger.isEnabledFor(self.logger, level): + if callable(msg): + msg = msg() + extra = {} + if self.namespace is not None: + extra["namespace"] = self.namespace.namespace + if self.context_name is not None: + extra["context_name"] = self.context_name + self.logger.log(level, msg, stacklevel=3, extra=extra) + + def info(self, msg: LOGMSG): + return self.log(logging.INFO, msg) + + def debug(self, msg: LOGMSG): + return self.log(logging.DEBUG, msg) + + def error(self, msg: LOGMSG): + return self.log(logging.ERROR, msg) + + def warning(self, msg: LOGMSG): + return self.log(logging.WARNING, msg) + + +class Context(ContextInterface): + """ + implement 
fate.interface.ContextInterface + + Note: most parameters has default dummy value, + which is convenient when used in script. + please pass in custom implements as you wish + """ + + def __init__( + self, + context_name: Optional[str] = None, + computing: Optional[ComputingEngine] = None, + federation: Optional[FederationEngine] = None, + summary: Summary = DummySummary(), + metrics: Metrics = DummyMetrics(), + cache: Cache = DummyCache(), + anonymous_generator: Anonymous = DummyAnonymous(), + checkpoint_manager: CheckpointManager = DummyCheckpointManager(), + log: Optional[LoggerInterface] = None, + disable_buildin_logger=True, # FIXME: just clear old loggers, remove in future + namespace: Optional[Namespace] = None, + ) -> None: + self.context_name = context_name + self.summary = summary + self.metrics = metrics + self.cache = cache + self.anonymous_generator = anonymous_generator + self.checkpoint_manager = checkpoint_manager + + if namespace is None: + namespace = Namespace() + self.namespace = namespace + + if log is None: + log = DummyLogger( + context_name, self.namespace, disable_buildin=disable_buildin_logger + ) + self.log = log + + self._computing = computing + self._federation = federation + self._session = Session() + self._gc = GC() + + def init_computing(self, computing_session_id=None): + self._session.init_computing(computing_session_id=computing_session_id) + + def init_federation( + self, + federation_id, + local_party: Tuple[Literal["guest", "host", "arbiter"], str], + parties: List[Tuple[Literal["guest", "host", "arbiter"], str]], + ): + if self._federation is None: + self._federation = FederationEngine( + federation_id, local_party, parties, self, self._session, self.namespace + ) + + @contextmanager + def sub_ctx(self, namespace) -> Iterator["Context"]: + with self.namespace.into_subnamespace(namespace): + try: + yield self + finally: + ... 
+ + @property + def guest(self): + return self._get_party_util().guest + + @property + def hosts(self): + return self._get_party_util().hosts + + @property + def arbiter(self): + return self._get_party_util().arbiter + + @property + def parties(self): + return self._get_party_util().parties + + def _get_party_util(self) -> FederationEngine: + if self._federation is None: + raise RuntimeError("federation session not init") + return self._federation diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py new file mode 100644 index 0000000000..3273f5362b --- /dev/null +++ b/python/fate/arch/context/_federation.py @@ -0,0 +1,270 @@ +from typing import Callable, List, Literal, Optional, Tuple, TypeVar + +from fate.interface import FPTensor +from fate.interface import Future as FutureInterface +from fate.interface import Futures as FuturesInterface +from fate.interface import Parties as PartiesInterface +from fate.interface import Party as PartyInterface +from fate.interface import PHEEncryptor, PHETensor +from fate.interface import FederationEngine as FederationEngineInterface + +from ..common import Party as PartyMeta +from ..federation.transfer_variable import IterationGC +from ._namespace import Namespace + + +class FederationDeserializer: + def do_deserialize(self, ctx, party): + ... 
+ + @classmethod + def make_frac_key(cls, base_key, frac_key): + return f"{base_key}__frac__{frac_key}" + + +T = TypeVar("T") + + +class Future(FutureInterface): + def __init__(self, inside) -> None: + self._inside = inside + + def unwrap_tensor(self) -> "FPTensor": + + assert isinstance(self._inside, FPTensor) + return self._inside + + def unwrap_phe_encryptor(self) -> "PHEEncryptor": + assert isinstance(self._inside, PHEEncryptor) + return self._inside + + def unwrap_phe_tensor(self) -> "PHETensor": + + assert isinstance(self._inside, PHETensor) + return self._inside + + def unwrap(self, check: Optional[Callable[[T], bool]] = None) -> T: + if check is not None and not check(self._inside): + raise TypeError(f"`{self._inside}` check failed") + return self._inside + + +class Futures(FuturesInterface): + def __init__(self, insides) -> None: + self._insides = insides + + def unwrap_tensors(self) -> List["FPTensor"]: + + for t in self._insides: + assert isinstance(t, FPTensor) + return self._insides + + def unwrap_phe_tensors(self) -> List["PHETensor"]: + + for t in self._insides: + assert isinstance(t, PHETensor) + return self._insides + + def unwrap(self, check: Optional[Callable[[T], bool]] = None) -> List[T]: + if check is not None: + for i, t in enumerate(self._insides): + if not check(t): + raise TypeError(f"{i}th element `{self._insides}` check failed") + return self._insides + + +class GC: + def __init__(self) -> None: + self._push_gc_dict = {} + self._pull_gc_dict = {} + + def get_or_set_push_gc(self, key): + if key not in self._push_gc_dict: + self._push_gc_dict[key] = IterationGC() + return self._push_gc_dict[key] + + def get_or_set_pull_gc(self, key): + if key not in self._pull_gc_dict: + self._pull_gc_dict[key] = IterationGC() + return self._pull_gc_dict[key] + + +class FederationParty(PartyInterface): + def __init__( + self, ctx, federation, party: Tuple[str, str], namespace, gc: GC + ) -> None: + self.ctx = ctx + self.federation = federation + self.party 
= PartyMeta(party[0], party[1]) + self.namespace = namespace + self.gc = gc + + def push(self, name: str, value): + return _push( + self.ctx, + self.federation, + name, + self.namespace, + [self.party], + self.gc, + value, + ) + + def pull(self, name: str) -> Future: + return Future( + _pull( + self.ctx, self.federation, name, self.namespace, [self.party], self.gc + )[0] + ) + + +class FederationParties(PartiesInterface): + def __init__( + self, + ctx, + federation, + parties: List[Tuple[str, str]], + namespace: Namespace, + gc: GC, + ) -> None: + self.ctx = ctx + self.federation = federation + self.parties = [PartyMeta(party[0], party[1]) for party in parties] + self.namespace = namespace + self.gc = gc + + def __call__(self, key: int) -> FederationParty: + return FederationParty( + self.ctx, + self.federation, + self.parties[key].as_tuple(), + self.namespace, + self.gc, + ) + + def push(self, name: str, value): + return _push( + self.ctx, + self.federation, + name, + self.namespace, + self.parties, + self.gc, + value, + ) + + def pull(self, name: str) -> Futures: + return Futures( + _pull( + self.ctx, self.federation, name, self.namespace, self.parties, self.gc + ) + ) + + +def _push( + ctx, + federation, + name: str, + namespace: Namespace, + parties: List[PartyMeta], + gc: GC, + value, +): + if hasattr(value, "__federation_hook__"): + value.__federation_hook__(ctx, name, parties) + else: + federation.remote( + v=value, + name=name, + tag=namespace.fedeation_tag(), + parties=parties, + gc=gc.get_or_set_push_gc(name), + ) + + +def _pull( + ctx, federation, name: str, namespace: Namespace, parties: List[PartyMeta], gc: GC +): + raw_values = federation.get( + name=name, + tag=namespace.fedeation_tag(), + parties=parties, + gc=gc.get_or_set_pull_gc(name), + ) + values = [] + for party, raw_value in zip(parties, raw_values): + if isinstance(raw_value, FederationDeserializer): + values.append(raw_value.do_deserialize(ctx, party)) + else: + values.append(raw_value) + 
return values + + +class FederationEngine(FederationEngineInterface): + def __init__( + self, + federation_id: str, + local_party: Tuple[Literal["guest", "host", "arbiter"], str], + parties: Optional[List[Tuple[Literal["guest", "host", "arbiter"], str]]], + ctx, + session, # should remove + namespace: Namespace, + ): + if parties is None: + parties = [] + if local_party not in parties: + parties.append(local_party) + self._local = local_party + self._parties = parties + self._role_to_parties = {} + for (role, party_id) in self._parties: + self._role_to_parties.setdefault(role, []).append(party_id) + + # walkround, temp + from ..common._parties import Party, PartiesInfo + + local = Party(local_party[0], local_party[1]) + role_to_parties = {} + for role, party_id in [local_party, *parties]: + role_to_parties.setdefault(role, []).append(Party(role, party_id)) + session.init_federation( + federation_session_id=federation_id, + parties_info=PartiesInfo(local, role_to_parties), + ) + self.federation = session.federation + + self.ctx = ctx + self.namespace = namespace + self.gc = GC() + + @property + def guest(self) -> PartyInterface: + party = self._role("guest")[0] + return FederationParty( + self.ctx, self.federation, party, self.namespace, self.gc + ) + + @property + def hosts(self) -> PartiesInterface: + parties = self._role("host") + return FederationParties( + self.ctx, self.federation, parties, self.namespace, self.gc + ) + + @property + def arbiter(self) -> PartyInterface: + party = self._role("arbiter")[0] + return FederationParty( + self.ctx, self.federation, party, self.namespace, self.gc + ) + + @property + def parties(self) -> PartiesInterface: + return FederationParties( + self.ctx, self.federation, self._parties, self.namespace, self.gc + ) + + def _role(self, role: str) -> List: + if role not in self._role_to_parties: + raise RuntimeError(f"no {role} party has configurated") + return [(role, party_id) for party_id in self._role_to_parties[role]] diff 
--git a/python/fate/arch/context/_namespace.py b/python/fate/arch/context/_namespace.py new file mode 100644 index 0000000000..068c941946 --- /dev/null +++ b/python/fate/arch/context/_namespace.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + + +class Namespace: + """ + Summary, Metrics may be namespace awared: + ``` + namespace = Namespace() + ctx = Context(...summary=XXXSummary(namespace)) + ``` + """ + + def __init__(self) -> None: + self.namespaces = [] + + @contextmanager + def into_subnamespace(self, subnamespace: str): + self.namespaces.append(subnamespace) + try: + yield self + finally: + self.namespaces.pop() + + @property + def namespace(self): + return ".".join(self.namespaces) + + def fedeation_tag(self) -> str: + return ".".join(self.namespaces) diff --git a/python/fate/arch/federation/_parties.py b/python/fate/arch/federation/_parties.py new file mode 100644 index 0000000000..ffe0343c22 --- /dev/null +++ b/python/fate/arch/federation/_parties.py @@ -0,0 +1,84 @@ +import enum +from typing import List + +from fate.arch.common import Party +from fate.arch.session import get_parties + + +class Parties: + def __init__(self, flag) -> None: + self._flag = flag + + def contains_hosts(self) -> bool: + return bool(self._flag & 1) + + def contains_arbiter(self) -> bool: + return bool(self._flag & 2) + + def contains_guest(self) -> bool: + return bool(self._flag & 4) + + def contains_host(self) -> bool: + return bool(self._flag & 8) + + @property + def indexes(self) -> List[int]: + return [i for i, e in enumerate(bin(self._flag)[::-1]) if e == "1"] + + @classmethod + def get_name(cls, i): + if i < 4: + return {0: "HOSTS", 1: "ARBITER", 2: "GUEST", 3: "HOST"}[i] + else: + return f"HOST{i-3}" + + def __or__(self, other): + return Parties(self._flag | other._flag) + + def __ror__(self, other): + return Parties(self._flag | other._flag) + + def __hash__(self) -> int: + return self._flag + + def __eq__(self, o) -> bool: + return self._flag == o._flag + + 
def __str__(self): + readable = "|".join([self.get_name(i) for i in self.indexes]) + return f"4b}): {readable}>" + + def __repr__(self): + return self.__str__() + + def __getitem__(self, key): + if self._flag == 1 and isinstance(key, int) and key >= 0: + return Parties(1 << (key + 3)) + raise TypeError("not subscriptable") + + def _get_role_parties(self, role: str): + return get_parties().roles_to_parties([role], strict=False) + + def get_parties(self) -> List[Party]: + parties = [] + if self._flag & 2: + parties.extend(self._get_role_parties("arbiter")) + if self._flag & 4: + parties.extend(self._get_role_parties("guest")) + if self._flag & 1: + parties.extend(self._get_role_parties("host")) + else: + host_bit_int = self._flag >> 3 + if host_bit_int: + hosts = self._get_role_parties("host") + for i, e in enumerate(bin(host_bit_int)[::-1]): + if e == "1": + parties.append(hosts[i]) + return parties + + +class PreludeParty(Parties, enum.Flag): + HOSTS = 1 + ARBITER = 2 + GUEST = 4 + HOST = 8 diff --git a/python/fate/arch/tensor/__init__.py b/python/fate/arch/tensor/__init__.py new file mode 100644 index 0000000000..2ab55e0f39 --- /dev/null +++ b/python/fate/arch/tensor/__init__.py @@ -0,0 +1,20 @@ +from ._dataloader import LabeledDataloaderWrapper, UnlabeledDataloaderWrapper +from ._parties import Parties, PreludeParty +from ._tensor import CipherKind, Context, FPTensor, PHETensor + +ARBITER = PreludeParty.ARBITER +GUEST = PreludeParty.GUEST +HOST = PreludeParty.HOST + +__all__ = [ + "FPTensor", + "PHETensor", + "Parties", + "ARBITER", + "GUEST", + "HOST", + "Context", + "LabeledDataloaderWrapper", + "UnlabeledDataloaderWrapper", + "CipherKind", +] diff --git a/python/fate/arch/tensor/_dataloader.py b/python/fate/arch/tensor/_dataloader.py new file mode 100644 index 0000000000..1172c3f0f2 --- /dev/null +++ b/python/fate/arch/tensor/_dataloader.py @@ -0,0 +1,55 @@ +import typing + +from ._tensor import FPTensor + + +class LabeledDataloaderWrapper: + """ + wrapper to 
transform data_instance to tensor-frendly Dataloader + """ + + def __init__( + self, + data_instance, + max_iter, + batch_size=-1, + with_intercept=False, + shuffle=False, + ): + ... + + @property + def shape(self) -> typing.Tuple[int, int]: + ... + + def next_batch(self) -> typing.Tuple[FPTensor, FPTensor]: + ... + + def has_next(self) -> bool: + ... + + +class UnlabeledDataloaderWrapper: + """ + wrapper to transform data_instance to tensor-frendly Dataloader + """ + + def __init__( + self, + data_instance, + max_iter, + batch_size=-1, + with_intercept=False, + shuffle=False, + ): + ... + + @property + def shape(self) -> typing.Tuple[int, int]: + ... + + def next_batch(self) -> FPTensor: + ... + + def has_next(self) -> bool: + ... diff --git a/python/fate/arch/tensor/_federation.py b/python/fate/arch/tensor/_federation.py new file mode 100644 index 0000000000..3687d00f50 --- /dev/null +++ b/python/fate/arch/tensor/_federation.py @@ -0,0 +1,7 @@ +class FederationDeserializer: + def do_deserialize(self, ctx, party): + ... 
+ + @classmethod + def make_frac_key(cls, base_key, frac_key): + return f"{base_key}__frac__{frac_key}" diff --git a/python/fate/arch/tensor/_parties.py b/python/fate/arch/tensor/_parties.py new file mode 100644 index 0000000000..ffe0343c22 --- /dev/null +++ b/python/fate/arch/tensor/_parties.py @@ -0,0 +1,84 @@ +import enum +from typing import List + +from fate.arch.common import Party +from fate.arch.session import get_parties + + +class Parties: + def __init__(self, flag) -> None: + self._flag = flag + + def contains_hosts(self) -> bool: + return bool(self._flag & 1) + + def contains_arbiter(self) -> bool: + return bool(self._flag & 2) + + def contains_guest(self) -> bool: + return bool(self._flag & 4) + + def contains_host(self) -> bool: + return bool(self._flag & 8) + + @property + def indexes(self) -> List[int]: + return [i for i, e in enumerate(bin(self._flag)[::-1]) if e == "1"] + + @classmethod + def get_name(cls, i): + if i < 4: + return {0: "HOSTS", 1: "ARBITER", 2: "GUEST", 3: "HOST"}[i] + else: + return f"HOST{i-3}" + + def __or__(self, other): + return Parties(self._flag | other._flag) + + def __ror__(self, other): + return Parties(self._flag | other._flag) + + def __hash__(self) -> int: + return self._flag + + def __eq__(self, o) -> bool: + return self._flag == o._flag + + def __str__(self): + readable = "|".join([self.get_name(i) for i in self.indexes]) + return f"4b}): {readable}>" + + def __repr__(self): + return self.__str__() + + def __getitem__(self, key): + if self._flag == 1 and isinstance(key, int) and key >= 0: + return Parties(1 << (key + 3)) + raise TypeError("not subscriptable") + + def _get_role_parties(self, role: str): + return get_parties().roles_to_parties([role], strict=False) + + def get_parties(self) -> List[Party]: + parties = [] + if self._flag & 2: + parties.extend(self._get_role_parties("arbiter")) + if self._flag & 4: + parties.extend(self._get_role_parties("guest")) + if self._flag & 1: + 
parties.extend(self._get_role_parties("host")) + else: + host_bit_int = self._flag >> 3 + if host_bit_int: + hosts = self._get_role_parties("host") + for i, e in enumerate(bin(host_bit_int)[::-1]): + if e == "1": + parties.append(hosts[i]) + return parties + + +class PreludeParty(Parties, enum.Flag): + HOSTS = 1 + ARBITER = 2 + GUEST = 4 + HOST = 8 diff --git a/python/fate/arch/tensor/_tensor.py b/python/fate/arch/tensor/_tensor.py new file mode 100644 index 0000000000..b02890c13b --- /dev/null +++ b/python/fate/arch/tensor/_tensor.py @@ -0,0 +1,508 @@ +import json +import typing +from contextlib import contextmanager +from enum import Enum +from typing import ( + Any, + Callable, + Generator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, + overload, +) + +import torch +from typing_extensions import Literal + +from fate.arch.common import Party +from fate.arch.federation.transfer_variable import IterationGC +from fate.arch.session import get_session + +from ..federation._parties import Parties, PreludeParty +from ._federation import FederationDeserializer +from .abc.tensor import PHEDecryptorABC, PHEEncryptorABC, PHETensorABC + + +class NamespaceState: + def __init__(self, namespace) -> None: + self._namespace = namespace + + def get_namespce(self) -> str: + return self._namespace + + def sub_namespace(self, namespace): + return f"{self._namespace}.{namespace}" + + +class FitState(NamespaceState): + ... + + +class PredictState(NamespaceState): + ... + + +class IterationState(NamespaceState): + ... 
+ + +class CipherKind(Enum): + PHE = 1 + PHE_PAILLIER = 2 + + +class Device(Enum): + CPU = 1 + GPU = 2 + FPGA = 3 + CPU_Intel = 4 + + +class Distributed(Enum): + NONE = 1 + EGGROLL = 2 + SPARK = 3 + + +T = TypeVar("T") + + +class _ContextInside: + def __init__(self, cpn_input) -> None: + self._push_gc_dict = {} + self._pull_gc_dict = {} + + self._flowid = None + + self._roles = cpn_input.roles + self._job_parameters = cpn_input.job_parameters + self._parameters = cpn_input.parameters + self._flow_feeded_parameters = cpn_input.flow_feeded_parameters + + self._device = Device.CPU + self._distributed = Distributed.EGGROLL + + @property + def device(self): + return self._device + + @property + def is_guest(self): + return self._roles["local"]["role"] == "guest" + + @property + def is_host(self): + return self._roles["local"]["role"] == "host" + + @property + def is_arbiter(self): + return self._roles["local"]["role"] == "arbiter" + + @property + def party(self): + role = self._roles["local"]["role"] + party_id = self._roles["local"]["party_id"] + return Party(role, party_id) + + def get_or_set_push_gc(self, key): + if key not in self._push_gc_dict: + self._push_gc_dict[key] = IterationGC() + return self._push_gc_dict[key] + + def get_or_set_pull_gc(self, key): + if key not in self._push_gc_dict: + self._pull_gc_dict[key] = IterationGC() + return self._pull_gc_dict[key] + + def describe(self): + return dict( + party=self.party, + job_parameters=self._job_parameters, + parameters=self._parameters, + flow_feeded_parameters=self._flow_feeded_parameters, + ) + + +class Context: + def __init__(self, inside: _ContextInside, namespace: str) -> None: + self._inside = inside + self._namespace_state = NamespaceState(namespace) + + @classmethod + def from_cpn_input(cls, cpn_input): + states = _ContextInside(cpn_input) + namespace = "fate" + return Context(states, namespace) + + def describe(self): + return json.dumps( + dict( + states=self._inside.describe(), + ) + ) + + @property 
+ def party(self): + return self._inside.party + + @property + def role(self): + return self.party.role + + @property + def party_id(self): + return self.party.party_id + + @property + def is_guest(self): + return self._inside.is_guest + + @property + def is_host(self): + return self._inside.is_guest + + @property + def is_arbiter(self): + return self._inside.is_guest + + @property + def device(self) -> Device: + return self._inside.device + + @property + def distributed(self) -> Distributed: + return self._inside._distributed + + def current_namespace(self): + return self._namespace_state.get_namespce() + + @overload + def keygen( + self, kind: Literal[CipherKind.PHE], key_length: int + ) -> Tuple["PHEEncryptor", "PHEDecryptor"]: + ... + + @overload + def keygen(self, kind: CipherKind, **kwargs) -> Any: + ... + + def keygen(self, kind, key_length: int, **kwargs): + # TODO: exploring expansion eechanisms + if kind == CipherKind.PHE or kind == CipherKind.PHE_PAILLIER: + if self.distributed == Distributed.NONE: + if self.device == Device.CPU: + from .impl.tensor.multithread_cpu_tensor import ( + PaillierPHECipherLocal, + ) + + encryptor, decryptor = PaillierPHECipherLocal().keygen( + key_length=key_length + ) + return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + if self.distributed == Distributed.EGGROLL: + if self.device == Device.CPU: + from .impl.tensor.distributed import PaillierPHECipherDistributed + + encryptor, decryptor = PaillierPHECipherDistributed().keygen( + key_length=key_length + ) + return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + + raise NotImplementedError( + f"keygen for kind<{kind}>-distributed<{self.distributed}>-device<{self.device}> is not implemented" + ) + + def random_tensor(self, shape, num_partition=1) -> "FPTensor": + if self.distributed == Distributed.NONE: + return FPTensor(self, torch.rand(shape)) + else: + from fate.arch.tensor.impl.tensor.distributed import FPTensorDistributed + + from ..session import computing_session 
+ + parts = [] + first_dim_approx = shape[0] // num_partition + last_part_first_dim = shape[0] - (num_partition - 1) * first_dim_approx + assert first_dim_approx > 0 + for i in range(num_partition): + if i == num_partition - 1: + parts.append( + torch.rand( + ( + last_part_first_dim, + *shape[1:], + ) + ) + ) + else: + parts.append(torch.rand((first_dim_approx, *shape[1:]))) + return FPTensor( + self, + FPTensorDistributed( + computing_session.parallelize( + parts, include_key=False, partition=num_partition + ) + ), + ) + + def create_tensor(self, tensor: torch.Tensor) -> "FPTensor": + + return FPTensor(self, tensor) + + @contextmanager + def sub_namespace(self, namespace): + """ + into sub_namespace ``, suffix federation namespace with `namespace` + + Examples: + ``` + with ctx.sub_namespace("fit"): + ctx.push(..., trans_key, obj) + + with ctx.sub_namespace("predict"): + ctx.push(..., trans_key, obj2) + ``` + `obj1` and `obj2` are pushed with different namespace + without conflic. + """ + + prev_namespace_state = self._namespace_state + + # into subnamespace + self._namespace_state = NamespaceState( + self._namespace_state.sub_namespace(namespace) + ) + + # return sub_ctx + # ```python + # with ctx.sub_namespace(xxx) as sub_ctx: + # ... + # ``` + # + yield self + + # restore namespace state when leaving with context + self._namespace_state = prev_namespace_state + + @overload + @contextmanager + def iter_namespaces( + self, start: int, stop: int, *, prefix_name="" + ) -> Generator[Generator["Context", None, None], None, None]: + ... + + @overload + @contextmanager + def iter_namespaces( + self, stop: int, *, prefix_name="" + ) -> Generator[Generator["Context", None, None], None, None]: + ... 
+ + @contextmanager + def iter_namespaces(self, *args, prefix_name=""): + assert 0 < len(args) <= 2, "position argument should be 1 or 2" + if len(args) == 1: + start, stop = 0, args[0] + if len(args) == 2: + start, stop = args[0], args[1] + + prev_namespace_state = self._namespace_state + + def _state_iterator() -> Generator[Context, None, None]: + for i in range(start, stop): + # the tags in the iteration need to be distinguishable + template_formated = f"{prefix_name}iter_{i}" + self._namespace_state = IterationState( + prev_namespace_state.sub_namespace(template_formated) + ) + yield self + + # with context returns iterator of Contexts + # namespaec state inside context is changed alone with iterator comsued + yield _state_iterator() + + # restore namespace state when leaving with context + self._namespace_state = prev_namespace_state + + +class PHEEncryptor: + def __init__(self, encryptor: PHEEncryptorABC) -> None: + self._encryptor = encryptor + + def encrypt(self, tensor: "FPTensor"): + + return PHETensor(tensor._ctx, self._encryptor.encrypt(tensor._tensor)) + + +class PHEDecryptor: + def __init__(self, decryptor: PHEDecryptorABC) -> None: + self._decryptor = decryptor + + def decrypt(self, tensor: "PHETensor") -> "FPTensor": + + return FPTensor(tensor._ctx, self._decryptor.decrypt(tensor._tensor)) + + +class FPTensor: + def __init__(self, ctx: Context, tensor) -> None: + self._ctx = ctx + self._tensor = tensor + + @property + def shape(self): + return self._tensor.shape + + def __add__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__add__"): + return NotImplemented + return self._binary_op(other, self._tensor.__add__) + + def __radd__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__radd__"): + return self.__add__(other) + return self._binary_op(other, self._tensor.__add__) + + def __sub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not 
hasattr(self._tensor, "__sub__"): + return NotImplemented + return self._binary_op(other, self._tensor.__sub__) + + def __rsub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__rsub__"): + return self.__mul__(-1).__add__(other) + return self._binary_op(other, self._tensor.__rsub__) + + def __mul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__mul__"): + return NotImplemented + return self._binary_op(other, self._tensor.__mul__) + + def __rmul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__rmul__"): + return self.__mul__(other) + return self._binary_op(other, self._tensor.__rmul__) + + def __matmul__(self, other: "FPTensor") -> "FPTensor": + if not hasattr(self._tensor, "__matmul__"): + return NotImplemented + if isinstance(other, FPTensor): + return FPTensor(self._ctx, self._tensor.__matmul__(other._tensor)) + else: + return NotImplemented + + def __rmatmul__(self, other: "FPTensor") -> "FPTensor": + if not hasattr(self._tensor, "__rmatmul__"): + return NotImplemented + if isinstance(other, FPTensor): + return FPTensor(self._ctx, self._tensor.__rmatmul__(other._tensor)) + else: + return NotImplemented + + def _binary_op(self, other, func): + if isinstance(other, FPTensor): + return FPTensor(self._ctx, func(other._tensor)) + elif isinstance(other, (int, float)): + return FPTensor(self._ctx, func(other)) + else: + return NotImplemented + + @property + def T(self): + return FPTensor(self._ctx, self._tensor.T) + + def __federation_hook__(self, ctx, key, parties): + deserializer = FPTensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. 
remote table + ctx._push(parties, deserializer.table_key, self._tensor) + + +class PHETensor: + def __init__(self, ctx: Context, tensor: PHETensorABC) -> None: + self._tensor = tensor + self._ctx = ctx + + @property + def shape(self): + return self._tensor.shape + + def __add__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__add__) + + def __radd__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__radd__) + + def __sub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__sub__) + + def __rsub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__rsub__) + + def __mul__(self, other: Union[FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__mul__) + + def __rmul__(self, other: Union[FPTensor, int, float]) -> "PHETensor": + return self._binary_op(other, self._tensor.__rmul__) + + def __matmul__(self, other: FPTensor) -> "PHETensor": + if isinstance(other, FPTensor): + return PHETensor(self._ctx, self._tensor.__matmul__(other._tensor)) + else: + return NotImplemented + + def __rmatmul__(self, other: FPTensor) -> "PHETensor": + if isinstance(other, FPTensor): + return PHETensor(self._ctx, self._tensor.__rmatmul__(other._tensor)) + else: + return NotImplemented + + def T(self) -> "PHETensor": + return PHETensor(self._ctx, self._tensor.T()) + + @overload + def decrypt(self, decryptor: "PHEDecryptor") -> FPTensor: + ... + + @overload + def decrypt(self, decryptor) -> Any: + ... 
+ + def decrypt(self, decryptor): + return decryptor.decrypt(self) + + def _binary_op(self, other, func): + if isinstance(other, (PHETensor, FPTensor)): + return PHETensor(self._ctx, func(other._tensor)) + elif isinstance(other, (int, float)): + return PHETensor(self._ctx, func(other)) + return NotImplemented + + def __federation_hook__(self, ctx, key, parties): + deserializer = PHETensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._tensor) + + +class PHETensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "tensor") + + def do_deserialize(self, ctx: Context, party: Party) -> PHETensor: + tensor = ctx._pull([party], self.table_key)[0] + return PHETensor(ctx, tensor) + + +class FPTensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "tensor") + + def do_deserialize(self, ctx: Context, party: Party) -> FPTensor: + tensor = ctx._pull([party], self.table_key)[0] + return FPTensor(ctx, tensor) diff --git a/python/fate/arch/tensor/abc/block.py b/python/fate/arch/tensor/abc/block.py new file mode 100644 index 0000000000..5741011dda --- /dev/null +++ b/python/fate/arch/tensor/abc/block.py @@ -0,0 +1,119 @@ +import abc +import typing + + +class FPBlockABC: + @classmethod + def zeors(cls, shape) -> "FPBlockABC": + ... + + @abc.abstractmethod + def __add__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC": + ... + + @abc.abstractmethod + def __radd__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC": + ... + + @abc.abstractmethod + def __sub__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC": + ... + + @abc.abstractmethod + def __rsub__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC": + ... 
import abc
import typing
from typing import Protocol


class FPBlockABC:
    """Interface of a plaintext (fixed-point) block."""

    # NOTE(review): "zeors" looks like a typo for "zeros", but the name is
    # already part of the public interface; renaming would break callers,
    # so it is kept and only flagged here.
    @classmethod
    def zeors(cls, shape) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __add__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __radd__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __sub__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __rsub__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __mul__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __rmul__(self, other: typing.Union["FPBlockABC", float, int]) -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __matmul__(self, other: "FPBlockABC") -> "FPBlockABC":
        ...

    @abc.abstractmethod
    def __rmatmul__(self, other: "FPBlockABC") -> "FPBlockABC":
        ...


class PHEBlockABC:
    """Block implementing a Partially Homomorphic Encryption schema:
    1. decrypt(encrypt(a) + encrypt(b)) = a + b
    2. decrypt(encrypt(a) * b) = a * b
    """

    @abc.abstractmethod
    def __add__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __radd__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __sub__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __rsub__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __mul__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __rmul__(
        self, other: typing.Union["PHEBlockABC", "FPBlockABC", float, int]
    ) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __matmul__(self, other: FPBlockABC) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def __rmatmul__(self, other: FPBlockABC) -> "PHEBlockABC":
        ...

    @abc.abstractmethod
    def serialize(self):
        ...

    # kept non-abstract as in the original (the decorator was commented out)
    def T(self) -> "PHEBlockABC":
        ...


class PHEBlockEncryptorABC:
    """Encrypts plaintext blocks into PHE blocks."""

    @abc.abstractmethod
    def encrypt(self, tensor: FPBlockABC) -> PHEBlockABC:
        ...


class PHEBlockDecryptorABC:
    """Decrypts PHE blocks back into plaintext blocks."""

    @abc.abstractmethod
    def decrypt(self, tensor: PHEBlockABC) -> FPBlockABC:
        ...


class PHEBlockCipherABC:
    """Key generation entry point for block-level PHE ciphers."""

    # BUG FIX: ``abc.abstractclassmethod`` has been deprecated since Python
    # 3.3; stacking ``classmethod`` over ``abstractmethod`` is the supported
    # spelling with identical semantics.
    @classmethod
    @abc.abstractmethod
    def keygen(
        cls, **kwargs
    ) -> typing.Tuple[PHEBlockEncryptorABC, PHEBlockDecryptorABC]:
        ...


class FPTensorProtocol(Protocol):
    """Structural (duck-typed) interface for plaintext tensors."""

    def __add__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __radd__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __sub__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __rsub__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __mul__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __rmul__(
        self, other: typing.Union["FPTensorProtocol", float, int]
    ) -> "FPTensorProtocol":
        ...

    def __matmul__(self, other: "FPTensorProtocol") -> "FPTensorProtocol":
        ...

    def __rmatmul__(self, other: "FPTensorProtocol") -> "FPTensorProtocol":
        ...


class PHETensorABC(abc.ABC):
    """Tensor implementing a Partially Homomorphic Encryption schema:
    1. decrypt(encrypt(a) + encrypt(b)) = a + b
    2. decrypt(encrypt(a) * b) = a * b
    """

    @abc.abstractmethod
    def __add__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __radd__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __sub__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __rsub__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __mul__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __rmul__(
        self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int]
    ) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __matmul__(self, other: FPTensorProtocol) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def __rmatmul__(self, other: FPTensorProtocol) -> "PHETensorABC":
        ...

    @abc.abstractmethod
    def serialize(self):
        ...

    @abc.abstractmethod
    def T(self) -> "PHETensorABC":
        ...


class PHEEncryptorABC(abc.ABC):
    """Encrypts plaintext tensors into PHE tensors."""

    @abc.abstractmethod
    def encrypt(self, tensor: FPTensorProtocol) -> PHETensorABC:
        ...


class PHEDecryptorABC(abc.ABC):
    """Decrypts PHE tensors back into plaintext tensors."""

    @abc.abstractmethod
    def decrypt(self, tensor: PHETensorABC) -> FPTensorProtocol:
        ...


class PHECipherABC(abc.ABC):
    """Key generation entry point for tensor-level PHE ciphers."""

    # BUG FIX: ``abc.abstractclassmethod`` is deprecated; see PHEBlockCipherABC.
    @classmethod
    @abc.abstractmethod
    def keygen(cls, **kwargs) -> typing.Tuple[PHEEncryptorABC, PHEDecryptorABC]:
        ...


def weighted_mean(X: "FPTensor", d: "PHETensor") -> "PHETensor":
    """Placeholder for a federated weighted mean; not yet implemented.

    Annotations are kept as strings because FPTensor/PHETensor live in a
    sibling module.
    """
    ...
import pickle

import numpy as np
import torch

# NOTE(review): the original module imported PHEBlockABC, PHEBlockCipherABC,
# PHEBlockDecryptorABC and PHEBlockEncryptorABC from ``...abc.block``, but
# none of those names are referenced anywhere in this file, so the unused
# import was removed.


def _impl_ops(class_obj, method_name, ops):
    """Build a binary dunder that applies ``ops`` to the raw cipherblock and
    wraps a successful result back into ``class_obj``."""

    def func(self, other):
        cb = ops(self._cb, other, class_obj)
        if cb is NotImplemented:
            return NotImplemented
        return class_obj(cb)

    func.__name__ = method_name
    return func


def _impl_init():
    """Build an ``__init__`` storing the raw cipherblock on ``self._cb``."""

    def __init__(self, cb):
        self._cb = cb

    return __init__


def _impl_encryptor_init():
    """Build an ``__init__`` storing the public key on ``self._pk``."""

    def __init__(self, pk):
        self._pk = pk

    return __init__


def _impl_decryptor_init():
    """Build an ``__init__`` storing the secret key on ``self._sk``."""

    def __init__(self, sk):
        self._sk = sk

    return __init__


def _impl_encrypt(pheblock_cls, fpbloke_cls, encrypt_op):
    """Build ``encrypt``: plaintext block (torch.Tensor-like) -> PHE block."""

    def encrypt(self, other) -> pheblock_cls:
        if isinstance(other, fpbloke_cls):
            return pheblock_cls(encrypt_op(self._pk, other.numpy()))

        raise NotImplementedError(f"type {other} not supported")

    return encrypt


def _impl_decrypt(pheblock_cls, fpbloke_cls, decrypt_op):
    """Build ``decrypt``: PHE block -> torch.Tensor of the requested dtype."""

    def decrypt(self, other, dtype=np.float64) -> fpbloke_cls:
        if isinstance(other, pheblock_cls):
            return torch.from_numpy(decrypt_op(self._sk, other._cb, dtype))
        raise NotImplementedError(f"type {other} not supported")

    return decrypt


def _impl_serialize():
    """Build ``serialize``: pickle the raw cipherblock."""

    def serialize(self) -> bytes:
        return pickle.dumps(self._cb)

    return serialize


def _impl_keygen(encrypt_cls, decrypt_cls, keygen_op):
    """Build a ``keygen`` classmethod returning (encryptor, decryptor)."""

    @classmethod
    def keygen(cls, key_length=1024):
        pk, sk = keygen_op(bit_size=key_length)
        return (encrypt_cls(pk), decrypt_cls(sk))

    return keygen


def _maybe_setattr(obj, name, value):
    # only inject the attribute when the class did not define it itself
    if not hasattr(obj, name):
        setattr(obj, name, value)


def phe_keygen_metaclass(encrypt_cls, decrypt_cls, keygen_op):
    """Metaclass factory injecting a ``keygen`` classmethod."""

    class PHEKeygenMetaclass(type):
        def __new__(cls, name, bases, dict):
            keygen_cls = super().__new__(cls, name, bases, dict)

            setattr(
                keygen_cls, "keygen", _impl_keygen(encrypt_cls, decrypt_cls, keygen_op)
            )
            return keygen_cls

    return PHEKeygenMetaclass


def phe_decryptor_metaclass(pheblock_cls, fpblock_cls):
    """Metaclass factory injecting ``__init__`` and dtype-dispatching ``decrypt``."""

    class PHEDecryptorMetaclass(type):
        def __new__(cls, name, bases, dict):
            decryptor_cls = super().__new__(cls, name, bases, dict)

            setattr(decryptor_cls, "__init__", _impl_decryptor_init())
            setattr(
                decryptor_cls,
                "decrypt",
                _impl_decrypt(
                    pheblock_cls, fpblock_cls, PHEDecryptorMetaclass._decrypt_numpy
                ),
            )
            return decryptor_cls

        @staticmethod
        def _decrypt_numpy(sk, cb, dtype):
            if dtype == np.float64:
                return sk.decrypt_f64(cb)
            if dtype == np.float32:
                return sk.decrypt_f32(cb)
            if dtype == np.int64:
                return sk.decrypt_i64(cb)
            if dtype == np.int32:
                return sk.decrypt_i32(cb)
            # BUG FIX: the message was missing its f-prefix and printed the
            # literal text "{dtype}"
            raise NotImplementedError(f"dtype = {dtype}")

    return PHEDecryptorMetaclass


def phe_encryptor_metaclass(pheblock_cls, fpblock_cls):
    """Metaclass factory injecting ``__init__`` and dtype-dispatching ``encrypt``."""

    class PHEEncryptorMetaclass(type):
        def __new__(cls, name, bases, dict):
            encryptor_cls = super().__new__(cls, name, bases, dict)

            setattr(encryptor_cls, "__init__", _impl_encryptor_init())
            setattr(
                encryptor_cls,
                "encrypt",
                _impl_encrypt(
                    pheblock_cls, fpblock_cls, PHEEncryptorMetaclass._encrypt_numpy
                ),
            )
            return encryptor_cls

        @staticmethod
        def _encrypt_numpy(pk, other):
            if is_ndarray(other):
                if is_nd_float64(other):
                    return pk.encrypt_f64(other)
                if is_nd_float32(other):
                    return pk.encrypt_f32(other)
                if is_nd_int64(other):
                    return pk.encrypt_i64(other)
                if is_nd_int32(other):
                    return pk.encrypt_i32(other)
            raise NotImplementedError(f"type {other} {other.dtype} not supported")

    return PHEEncryptorMetaclass


class PHEBlockMetaclass(type):
    """Injects ``__init__``, ``shape``, ``serialize`` and the arithmetic
    dunders into a cipherblock wrapper class, dispatching on operand type
    and dtype to the underlying native cipherblock methods."""

    def __new__(cls, name, bases, dict):
        class_obj = super().__new__(cls, name, bases, dict)

        setattr(class_obj, "__init__", _impl_init())

        @property
        def shape(self):
            return self._cb.shape

        setattr(class_obj, "shape", shape)
        _maybe_setattr(class_obj, "serialize", _impl_serialize())
        for impl_name, ops in {
            "__add__": PHEBlockMetaclass._add,
            "__radd__": PHEBlockMetaclass._radd,
            "__sub__": PHEBlockMetaclass._sub,
            "__rsub__": PHEBlockMetaclass._rsub,
            "__mul__": PHEBlockMetaclass._mul,
            "__rmul__": PHEBlockMetaclass._rmul,
            "__matmul__": PHEBlockMetaclass._matmul,
            "__rmatmul__": PHEBlockMetaclass._rmatmul,
        }.items():
            _maybe_setattr(class_obj, impl_name, _impl_ops(class_obj, impl_name, ops))

        return class_obj

    @staticmethod
    def _rmatmul(cb, other, class_obj):
        if isinstance(other, torch.Tensor):
            other = other.numpy()
        # consistency: use the module helper like the other ops do
        if is_ndarray(other):
            if len(other.shape) == 2:
                if is_nd_float64(other):
                    return cb.rmatmul_plaintext_ix2_f64(other)
                if is_nd_float32(other):
                    return cb.rmatmul_plaintext_ix2_f32(other)
                if is_nd_int64(other):
                    return cb.rmatmul_plaintext_ix2_i64(other)
                if is_nd_int32(other):
                    return cb.rmatmul_plaintext_ix2_i32(other)
            if len(other.shape) == 1:
                if is_nd_float64(other):
                    return cb.rmatmul_plaintext_ix1_f64(other)
                if is_nd_float32(other):
                    return cb.rmatmul_plaintext_ix1_f32(other)
                if is_nd_int64(other):
                    return cb.rmatmul_plaintext_ix1_i64(other)
                if is_nd_int32(other):
                    return cb.rmatmul_plaintext_ix1_i32(other)
        return NotImplemented

    @staticmethod
    def _matmul(cb, other, class_obj):
        if isinstance(other, torch.Tensor):
            other = other.numpy()
        if is_ndarray(other):
            if len(other.shape) == 2:
                if is_nd_float64(other):
                    return cb.matmul_plaintext_ix2_f64(other)
                if is_nd_float32(other):
                    return cb.matmul_plaintext_ix2_f32(other)
                if is_nd_int64(other):
                    return cb.matmul_plaintext_ix2_i64(other)
                if is_nd_int32(other):
                    return cb.matmul_plaintext_ix2_i32(other)
            if len(other.shape) == 1:
                if is_nd_float64(other):
                    return cb.matmul_plaintext_ix1_f64(other)
                if is_nd_float32(other):
                    return cb.matmul_plaintext_ix1_f32(other)
                if is_nd_int64(other):
                    return cb.matmul_plaintext_ix1_i64(other)
                if is_nd_int32(other):
                    return cb.matmul_plaintext_ix1_i32(other)
        return NotImplemented

    @staticmethod
    def _mul(cb, other, class_obj):
        if isinstance(other, torch.Tensor):
            other = other.numpy()
        if is_ndarray(other):
            if is_nd_float64(other):
                return cb.mul_plaintext_f64(other)
            if is_nd_float32(other):
                return cb.mul_plaintext_f32(other)
            if is_nd_int64(other):
                return cb.mul_plaintext_i64(other)
            if is_nd_int32(other):
                return cb.mul_plaintext_i32(other)
            # BUG FIX: was ``raise NotImplemented`` which raises
            # ``TypeError: exceptions must derive from BaseException``;
            # return the sentinel like the sibling ops do.
            return NotImplemented
        if is_float(other):
            return cb.mul_plaintext_scalar_f64(other)
        if is_float32(other):
            return cb.mul_plaintext_scalar_f32(other)
        if is_int(other):
            return cb.mul_plaintext_scalar_i64(other)
        if is_int32(other):
            return cb.mul_plaintext_scalar_i32(other)
        return NotImplemented

    @staticmethod
    def _sub(cb, other, class_obj):
        if isinstance(other, torch.Tensor):
            other = other.numpy()
        if is_ndarray(other):
            if is_nd_float64(other):
                return cb.sub_plaintext_f64(other)
            if is_nd_float32(other):
                return cb.sub_plaintext_f32(other)
            if is_nd_int64(other):
                return cb.sub_plaintext_i64(other)
            if is_nd_int32(other):
                return cb.sub_plaintext_i32(other)
            return NotImplemented

        if isinstance(other, class_obj):
            return cb.sub_cipherblock(other._cb)
        if is_float(other):
            return cb.sub_plaintext_scalar_f64(other)
        if is_float32(other):
            return cb.sub_plaintext_scalar_f32(other)
        if is_int(other):
            return cb.sub_plaintext_scalar_i64(other)
        if is_int32(other):
            return cb.sub_plaintext_scalar_i32(other)

        return NotImplemented

    @staticmethod
    def _add(cb, other, class_obj):
        if isinstance(other, torch.Tensor):
            other = other.numpy()
        if is_ndarray(other):
            if is_nd_float64(other):
                return cb.add_plaintext_f64(other)
            if is_nd_float32(other):
                return cb.add_plaintext_f32(other)
            if is_nd_int64(other):
                return cb.add_plaintext_i64(other)
            if is_nd_int32(other):
                return cb.add_plaintext_i32(other)
            return NotImplemented

        if isinstance(other, class_obj):
            return cb.add_cipherblock(other._cb)
        if is_float(other):
            return cb.add_plaintext_scalar_f64(other)
        if is_float32(other):
            return cb.add_plaintext_scalar_f32(other)
        if is_int(other):
            return cb.add_plaintext_scalar_i64(other)
        if is_int32(other):
            return cb.add_plaintext_scalar_i32(other)

        return NotImplemented

    @staticmethod
    def _radd(cb, other, class_obj):
        # addition is commutative for cipherblocks
        return PHEBlockMetaclass._add(cb, other, class_obj)

    @staticmethod
    def _rsub(cb, other, class_obj):
        # other - cb  ==  (-cb) + other
        return PHEBlockMetaclass._add(
            PHEBlockMetaclass._mul(cb, -1, class_obj), other, class_obj
        )

    @staticmethod
    def _rmul(cb, other, class_obj):
        # multiplication by plaintext is commutative
        return PHEBlockMetaclass._mul(cb, other, class_obj)


def is_ndarray(v):
    return isinstance(v, np.ndarray)


def is_float(v):
    return isinstance(v, (float, np.float64))


def is_float32(v):
    return isinstance(v, np.float32)


def is_int(v):
    return isinstance(v, (int, np.int64))


def is_int32(v):
    return isinstance(v, np.int32)


def is_nd_float64(v):
    return v.dtype == np.float64


def is_nd_float32(v):
    return v.dtype == np.float32


def is_nd_int64(v):
    return v.dtype == np.int64


def is_nd_int32(v):
    return v.dtype == np.int32
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# ---- file: python/fate/arch/tensor/impl/blocks/cpu_paillier_block.py ----
# Paillier block primitives backed by the single-threaded `rust_paillier`
# native extension.  All behavior is injected by the metaclasses.

import rust_paillier
import torch

from ._metaclass import (
    PHEBlockMetaclass,
    phe_decryptor_metaclass,
    phe_encryptor_metaclass,
    phe_keygen_metaclass,
)


class PaillierBlock(metaclass=PHEBlockMetaclass):
    # arithmetic dunders, `shape` and `serialize` are injected by PHEBlockMetaclass
    pass


class BlockPaillierEncryptor(
    metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor)
):
    # `encrypt(torch.Tensor) -> PaillierBlock` is injected by the metaclass
    pass


class BlockPaillierDecryptor(
    metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor)
):
    # `decrypt(PaillierBlock, dtype) -> torch.Tensor` is injected by the metaclass
    pass


class BlockPaillierCipher(
    metaclass=phe_keygen_metaclass(
        BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.keygen
    )
):
    # `keygen(key_length)` classmethod is injected by the metaclass
    pass


# ---- file: .../multithread_cpu_paillier_block.py (same Apache-2.0 header) ----
# Identical wiring, but key generation uses the multi-threaded
# `rust_paillier.par` backend.

import rust_paillier.par
import torch

from ._metaclass import (
    PHEBlockMetaclass,
    phe_decryptor_metaclass,
    phe_encryptor_metaclass,
    phe_keygen_metaclass,
)


class PaillierBlock(metaclass=PHEBlockMetaclass):
    pass


class BlockPaillierEncryptor(
    metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor)
):
    pass


class BlockPaillierDecryptor(
    metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor)
):
    pass


class BlockPaillierCipher(
    metaclass=phe_keygen_metaclass(
        BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.par.keygen
    )
):
    pass


# ---- file: .../python_paillier_block/__init__.py ----
# Public surface of the pure-Python Paillier block package.

from ._python_paillier_block import (
    BlockPaillierCipher,
    BlockPaillierDecryptor,
    BlockPaillierEncryptor,
)

__all__ = ["BlockPaillierCipher", "BlockPaillierEncryptor", "BlockPaillierDecryptor"]


# ---- file: .../python_paillier_block/_fate_paillier.py (module head) ----
"""Paillier encryption library for partially homomorphic encryption."""

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
"""Paillier encryption library for partially homomorphic encryption."""

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import random

from ._fixedpoint import FixedPointNumber
from ._gmpy_math import getprimeover, invert, mpz, powmod


class PaillierKeypair(object):
    """Factory for Paillier public/private key pairs."""

    def __init__(self):
        pass

    @staticmethod
    def generate_keypair(n_length=1024):
        """Return a new :class:`PaillierPublicKey` and :class:`PaillierPrivateKey`.

        Draws two distinct primes of ``n_length // 2`` bits and retries until
        the modulus ``n = p * q`` has exactly ``n_length`` bits.
        """
        p = q = n = None
        n_len = 0

        while n_len != n_length:
            p = getprimeover(n_length // 2)
            q = p
            while q == p:
                q = getprimeover(n_length // 2)
            n = p * q
            n_len = n.bit_length()

        public_key = PaillierPublicKey(n)
        private_key = PaillierPrivateKey(public_key, p, q)

        return public_key, private_key


class PaillierPublicKey(object):
    """Contains a public key and associated encryption methods."""

    def __init__(self, n):
        self.g = n + 1
        self.n = n
        self.nsquare = n * n
        self.max_int = n // 3 - 1

    def __repr__(self):
        # BUGFIX: the format string was "" so repr() always returned an
        # empty string; restore the class-name + hash template.
        hashcode = hex(hash(self))[2:]
        return "<PaillierPublicKey {}>".format(hashcode[:10])

    def __eq__(self, other):
        return self.n == other.n

    def __hash__(self):
        return hash(self.n)

    def apply_obfuscator(self, ciphertext, random_value=None):
        """Multiply ciphertext by r**n mod n^2 for a random (or given) r."""
        r = random_value or random.SystemRandom().randrange(1, self.n)
        obfuscator = powmod(r, self.n, self.nsquare)

        return (ciphertext * obfuscator) % self.nsquare

    def raw_encrypt(self, plaintext, random_value=None):
        """Encrypt an already-encoded integer plaintext."""
        if not isinstance(plaintext, int):
            raise TypeError("plaintext should be int, but got: %s" % type(plaintext))

        if plaintext >= (self.n - self.max_int) and plaintext < self.n:
            # Very large plaintext, take a sneaky shortcut using inverses
            neg_plaintext = self.n - plaintext  # = abs(plaintext - nsquare)
            neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
            ciphertext = invert(neg_ciphertext, self.nsquare)
        else:
            ciphertext = (self.n * plaintext + 1) % self.nsquare

        ciphertext = self.apply_obfuscator(ciphertext, random_value)

        return ciphertext

    def encrypt(self, value, precision=None, random_value=None):
        """Encode and Paillier encrypt a real number value."""
        if isinstance(value, FixedPointNumber):
            value = value.decode()
        encoding = FixedPointNumber.encode(value, self.n, self.max_int, precision)
        obfuscator = random_value or 1
        ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
        encryptednumber = PaillierEncryptedNumber(self, ciphertext, encoding.exponent)
        if random_value is None:
            encryptednumber.apply_obfuscator()

        return encryptednumber


class PaillierPrivateKey(object):
    """Contains a private key and associated decryption method."""

    def __init__(self, public_key, p, q):
        if not p * q == public_key.n:
            raise ValueError("given public key does not match the given p and q")
        if p == q:
            raise ValueError("p and q have to be different")
        self.public_key = public_key
        # store the smaller prime in p so CRT coefficients are canonical
        if q < p:
            self.p = q
            self.q = p
        else:
            self.p = p
            self.q = q
        self.psquare = self.p * self.p
        self.qsquare = self.q * self.q
        self.q_inverse = invert(self.q, self.p)
        self.hp = self.h_func(self.p, self.psquare)
        self.hq = self.h_func(self.q, self.qsquare)

    def __eq__(self, other):
        return self.p == other.p and self.q == other.q

    def __hash__(self):
        return hash((self.p, self.q))

    def __repr__(self):
        # BUGFIX: same empty-format-string defect as PaillierPublicKey.
        hashcode = hex(hash(self))[2:]

        return "<PaillierPrivateKey {}>".format(hashcode[:10])

    def h_func(self, x, xsquare):
        """Computes the h-function as defined in Paillier's paper page."""
        return invert(self.l_func(powmod(self.public_key.g, x - 1, xsquare), x), x)

    def l_func(self, x, p):
        """computes the L function as defined in Paillier's paper."""

        return (x - 1) // p

    def crt(self, mp, mq):
        """the Chinese Remainder Theorem as needed for decryption.
        return the solution modulo n=pq.
        """
        u = (mp - mq) * self.q_inverse % self.p
        x = (mq + (u * self.q)) % self.public_key.n

        return x

    def raw_decrypt(self, ciphertext):
        """return raw plaintext."""
        if not isinstance(ciphertext, int):
            raise TypeError("ciphertext should be an int, not: %s" % type(ciphertext))

        mp = (
            self.l_func(powmod(ciphertext, self.p - 1, self.psquare), self.p)
            * self.hp
            % self.p
        )

        mq = (
            self.l_func(powmod(ciphertext, self.q - 1, self.qsquare), self.q)
            * self.hq
            % self.q
        )

        return self.crt(mp, mq)

    def decrypt(self, encrypted_number):
        """return the decrypted & decoded plaintext of encrypted_number."""
        if not isinstance(encrypted_number, PaillierEncryptedNumber):
            raise TypeError(
                "encrypted_number should be an PaillierEncryptedNumber, \
                not: %s"
                % type(encrypted_number)
            )

        if self.public_key != encrypted_number.public_key:
            raise ValueError("encrypted_number was encrypted against a different key!")

        encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
        encoded = FixedPointNumber(
            encoded,
            encrypted_number.exponent,
            self.public_key.n,
            self.public_key.max_int,
        )
        decrypt_value = encoded.decode()

        return decrypt_value


class PaillierEncryptedNumber(object):
    """Represents the Paillier encryption of a float or int."""

    def __init__(self, public_key, ciphertext, exponent=0):
        self.public_key = public_key
        self.__ciphertext = ciphertext
        self.exponent = exponent
        self.__is_obfuscator = False

        if not isinstance(self.__ciphertext, int):
            raise TypeError(
                "ciphertext should be an int, not: %s" % type(self.__ciphertext)
            )

        if not isinstance(self.public_key, PaillierPublicKey):
            raise TypeError(
                "public_key should be a PaillierPublicKey, not: %s"
                % type(self.public_key)
            )

    def ciphertext(self, be_secure=True):
        """return the ciphertext of the PaillierEncryptedNumber."""
        if be_secure and not self.__is_obfuscator:
            self.apply_obfuscator()

        return self.__ciphertext

    def apply_obfuscator(self):
        """ciphertext by multiplying by r ** n with random r"""
        self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
        self.__is_obfuscator = True

    def __add__(self, other):
        if isinstance(other, PaillierEncryptedNumber):
            return self.__add_encryptednumber(other)
        else:
            return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self + (other * -1)

    def __rsub__(self, other):
        return other + (self * -1)

    def __rmul__(self, scalar):
        return self.__mul__(scalar)

    def __truediv__(self, scalar):
        return self.__mul__(1 / scalar)

    def __mul__(self, scalar):
        """return Multiply by an scalar(such as int, float)"""
        if isinstance(scalar, FixedPointNumber):
            scalar = scalar.decode()
        encode = FixedPointNumber.encode(
            scalar, self.public_key.n, self.public_key.max_int
        )
        plaintext = encode.encoding

        if plaintext < 0 or plaintext >= self.public_key.n:
            raise ValueError("Scalar out of bounds: %i" % plaintext)

        if plaintext >= self.public_key.n - self.public_key.max_int:
            # Very large plaintext, play a sneaky trick using inverses
            neg_c = invert(self.ciphertext(False), self.public_key.nsquare)
            neg_scalar = self.public_key.n - plaintext
            ciphertext = powmod(neg_c, neg_scalar, self.public_key.nsquare)
        else:
            ciphertext = powmod(
                self.ciphertext(False), plaintext, self.public_key.nsquare
            )

        exponent = self.exponent + encode.exponent

        return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)

    def increase_exponent_to(self, new_exponent):
        """return PaillierEncryptedNumber:
        new PaillierEncryptedNumber with same value but having great exponent.
        """
        if new_exponent < self.exponent:
            raise ValueError(
                "New exponent %i should be great than old exponent %i"
                % (new_exponent, self.exponent)
            )

        factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
        new_encryptednumber = self.__mul__(factor)
        new_encryptednumber.exponent = new_exponent

        return new_encryptednumber

    def __align_exponent(self, x, y):
        """return x,y with same exponet"""
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __add_scalar(self, scalar):
        """return PaillierEncryptedNumber: z = E(x) + y"""
        if isinstance(scalar, FixedPointNumber):
            scalar = scalar.decode()
        encoded = FixedPointNumber.encode(
            scalar,
            self.public_key.n,
            self.public_key.max_int,
            max_exponent=self.exponent,
        )
        return self.__add_fixpointnumber(encoded)

    def __add_fixpointnumber(self, encoded):
        """return PaillierEncryptedNumber: z = E(x) + FixedPointNumber(y)"""
        if self.public_key.n != encoded.n:
            raise ValueError(
                "Attempted to add numbers encoded against different public keys!"
            )

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, encoded)

        encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
        encryptednumber = self.__raw_add(
            x.ciphertext(False), encrypted_scalar, x.exponent
        )

        return encryptednumber

    def __add_encryptednumber(self, other):
        """return PaillierEncryptedNumber: z = E(x) + E(y)"""
        if self.public_key != other.public_key:
            raise ValueError("add two numbers have different public key!")

        # their exponents must match, and align.
        x, y = self.__align_exponent(self, other)

        encryptednumber = self.__raw_add(
            x.ciphertext(False), y.ciphertext(False), x.exponent
        )

        return encryptednumber

    def __raw_add(self, e_x, e_y, exponent):
        """return the integer E(x + y) given ints E(x) and E(y)."""
        ciphertext = mpz(e_x) * mpz(e_y) % self.public_key.nsquare

        return PaillierEncryptedNumber(self.public_key, int(ciphertext), exponent)
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import math
import sys

import numpy as np


class FixedPointNumber(object):
    """Represents a float or int fixedpoint encoding;."""

    BASE = 16
    LOG2_BASE = math.log(BASE, 2)
    FLOAT_MANTISSA_BITS = sys.float_info.mant_dig

    # default modulus used when no Paillier n is supplied
    Q = 293973345475167247070445277780365744413**2

    def __init__(self, encoding, exponent, n=None, max_int=None):
        if n is None:
            self.n = FixedPointNumber.Q
            self.max_int = self.n // 2
        else:
            self.n = n
            if max_int is None:
                self.max_int = self.n // 2
            else:
                self.max_int = max_int

        self.encoding = encoding
        self.exponent = exponent

    @classmethod
    def calculate_exponent_from_precision(cls, precision):
        """Return the BASE exponent needed to represent ``precision``."""
        exponent = math.floor(math.log(precision, cls.BASE))
        return exponent

    @classmethod
    def encode(cls, scalar, n=None, max_int=None, precision=None, max_exponent=None):
        """return an encoding of an int or float."""
        # Calculate the maximum exponent for desired precision
        exponent = None

        # Too low value preprocess;
        # avoid "OverflowError: int too large to convert to float"

        if np.abs(scalar) < 1e-200:
            scalar = 0

        if n is None:
            n = cls.Q
            max_int = n // 2

        if precision is None:
            if isinstance(scalar, (int, np.int16, np.int32, np.int64)):
                exponent = 0
            elif isinstance(scalar, (float, np.float16, np.float32, np.float64)):
                # keep every bit of the float mantissa representable
                flt_exponent = math.frexp(scalar)[1]
                lsb_exponent = cls.FLOAT_MANTISSA_BITS - flt_exponent
                exponent = math.floor(lsb_exponent / cls.LOG2_BASE)
            else:
                raise TypeError("Don't know the precision of type %s." % type(scalar))
        else:
            exponent = cls.calculate_exponent_from_precision(precision)

        if max_exponent is not None:
            exponent = max(max_exponent, exponent)

        int_fixpoint = int(round(scalar * pow(cls.BASE, exponent)))

        if abs(int_fixpoint) > max_int:
            raise ValueError(
                f"Integer needs to be within +/- {max_int},but got {int_fixpoint},"
                f"basic info, scalar={scalar}, base={cls.BASE}, exponent={exponent}"
            )

        return cls(int_fixpoint % n, exponent, n, max_int)

    def decode(self):
        """return decode plaintext."""
        if self.encoding >= self.n:
            # Should be mod n
            raise ValueError("Attempted to decode corrupted number")
        elif self.encoding <= self.max_int:
            # Positive
            mantissa = self.encoding
        elif self.encoding >= self.n - self.max_int:
            # Negative
            mantissa = self.encoding - self.n
        else:
            raise OverflowError(
                f"Overflow detected in decode number, encoding: {self.encoding},"
                f"{self.exponent}"
                f" {self.n}"
            )

        return mantissa * pow(self.BASE, -self.exponent)

    def increase_exponent_to(self, new_exponent):
        """return FixedPointNumber: new encoding with same value but having great exponent."""
        if new_exponent < self.exponent:
            raise ValueError(
                "New exponent %i should be greater than"
                "old exponent %i" % (new_exponent, self.exponent)
            )

        factor = pow(self.BASE, new_exponent - self.exponent)
        new_encoding = self.encoding * factor % self.n

        return FixedPointNumber(new_encoding, new_exponent, self.n, self.max_int)

    def __align_exponent(self, x, y):
        """return x,y with same exponent"""
        if x.exponent < y.exponent:
            x = x.increase_exponent_to(y.exponent)
        elif x.exponent > y.exponent:
            y = y.increase_exponent_to(x.exponent)

        return x, y

    def __truncate(self, a):
        scalar = a.decode()
        return FixedPointNumber.encode(scalar, n=self.n, max_int=self.max_int)

    def __add__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__add_fixedpointnumber(other)
        elif type(other).__name__ == "PaillierEncryptedNumber":
            return other + self.decode()
        else:
            return self.__add_scalar(other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__sub_fixedpointnumber(other)
        elif type(other).__name__ == "PaillierEncryptedNumber":
            return (other - self.decode()) * -1
        else:
            return self.__sub_scalar(other)

    def __rsub__(self, other):
        if type(other).__name__ == "PaillierEncryptedNumber":
            return other - self.decode()

        x = self.__sub__(other)
        x = -1 * x.decode()
        return self.encode(x, n=self.n, max_int=self.max_int)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        if isinstance(other, FixedPointNumber):
            return self.__mul_fixedpointnumber(other)
        elif type(other).__name__ == "PaillierEncryptedNumber":
            return other * self.decode()
        else:
            return self.__mul_scalar(other)

    def __truediv__(self, other):
        if isinstance(other, FixedPointNumber):
            scalar = other.decode()
        else:
            scalar = other

        return self.__mul__(1 / scalar)

    def __rtruediv__(self, other):
        res = 1.0 / self.__truediv__(other).decode()
        return FixedPointNumber.encode(res, n=self.n, max_int=self.max_int)

    def _other_as_scalar(self, other):
        """Decode ``other`` if it is a FixedPointNumber, else return it as-is."""
        return other.decode() if isinstance(other, FixedPointNumber) else other

    def __lt__(self, other):
        return self.decode() < self._other_as_scalar(other)

    def __gt__(self, other):
        return self.decode() > self._other_as_scalar(other)

    def __le__(self, other):
        return self.decode() <= self._other_as_scalar(other)

    def __ge__(self, other):
        return self.decode() >= self._other_as_scalar(other)

    def __eq__(self, other):
        return self.decode() == self._other_as_scalar(other)

    def __ne__(self, other):
        return self.decode() != self._other_as_scalar(other)

    def __add_fixedpointnumber(self, other):
        if self.n != other.n:
            other = self.encode(other.decode(), n=self.n, max_int=self.max_int)
        x, y = self.__align_exponent(self, other)
        encoding = (x.encoding + y.encoding) % self.n
        return FixedPointNumber(encoding, x.exponent, n=self.n, max_int=self.max_int)

    def __add_scalar(self, scalar):
        encoded = self.encode(scalar, n=self.n, max_int=self.max_int)
        return self.__add_fixedpointnumber(encoded)

    def __sub_fixedpointnumber(self, other):
        if self.n != other.n:
            other = self.encode(other.decode(), n=self.n, max_int=self.max_int)
        x, y = self.__align_exponent(self, other)
        encoding = (x.encoding - y.encoding) % self.n

        return FixedPointNumber(encoding, x.exponent, n=self.n, max_int=self.max_int)

    def __sub_scalar(self, scalar):
        scalar = -1 * scalar
        return self.__add_scalar(scalar)

    def __mul_fixedpointnumber(self, other):
        return self.__mul_scalar(other.decode())

    def __mul_scalar(self, scalar):
        val = self.decode()
        z = val * scalar
        z_encode = FixedPointNumber.encode(z, n=self.n, max_int=self.max_int)
        return z_encode

    def __abs__(self):
        if self.encoding <= self.max_int:
            # Positive
            return self
        elif self.encoding >= self.n - self.max_int:
            # Negative
            return self * -1
        # BUGFIX: previously fell through and silently returned None for
        # encodings in the forbidden middle zone; fail loudly instead.
        raise OverflowError(
            f"Overflow detected in abs, encoding: {self.encoding}, {self.n}"
        )

    def __mod__(self, other):
        return FixedPointNumber(
            self.encoding % other, self.exponent, n=self.n, max_int=self.max_int
        )


class FixedPointEndec(object):
    """Element-wise encoder/decoder between scalars/arrays/tables and FixedPointNumber."""

    def __init__(self, n=None, max_int=None, precision=None, *args, **kwargs):
        if n is None:
            self.n = FixedPointNumber.Q
            self.max_int = self.n // 2
        else:
            self.n = n
            if max_int is None:
                self.max_int = self.n // 2
            else:
                self.max_int = max_int

        self.precision = precision

    @classmethod
    def _transform_op(cls, tensor, op):
        """Apply ``op`` element-wise over a scalar, ndarray, or distributed table."""

        def _transform(x):
            arr = np.zeros(shape=x.shape, dtype=object)
            view = arr.view().reshape(-1)
            x_array = x.view().reshape(-1)
            for i in range(arr.size):
                view[i] = op(x_array[i])

            return arr

        if isinstance(
            tensor,
            (
                int,
                np.int16,
                np.int32,
                np.int64,
                float,
                np.float16,
                np.float32,
                np.float64,
                FixedPointNumber,
            ),
        ):
            return op(tensor)

        if isinstance(tensor, np.ndarray):
            return _transform(tensor)

        # imported lazily so that scalar / ndarray inputs do not require the
        # distributed session module to be importable
        from .....session import is_table

        if is_table(tensor):
            # BUGFIX: the former functools.partial(_transform) wrapped the
            # function with no bound arguments — a no-op; pass it directly.
            return tensor.mapValues(_transform)

        raise ValueError(f"unsupported type: {type(tensor)}")

    def _encode(self, scalar):
        return FixedPointNumber.encode(
            scalar, n=self.n, max_int=self.max_int, precision=self.precision
        )

    def _decode(self, number):
        return number.decode()

    def _truncate(self, number):
        scalar = number.decode()
        return FixedPointNumber.encode(scalar, n=self.n, max_int=self.max_int)

    def encode(self, float_tensor):
        """Encode a scalar/ndarray/table of numbers into FixedPointNumber form."""
        return self._transform_op(float_tensor, op=self._encode)

    def decode(self, integer_tensor):
        """Decode a scalar/ndarray/table of FixedPointNumber back to numbers."""
        return self._transform_op(integer_tensor, op=self._decode)

    def truncate(self, integer_tensor, *args, **kwargs):
        """Re-encode each element at this endec's modulus, truncating precision."""
        return self._transform_op(integer_tensor, op=self._truncate)
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import random

import gmpy2

# operands below this size use the builtin pow, which is faster than
# crossing into gmpy2 for small integers
POWMOD_GMP_SIZE = pow(2, 64)


def powmod(a, b, c):
    """
    return int: (a ** b) % c
    """

    if a == 1:
        # BUGFIX: pow(1, b, 1) == 0, so the shortcut must reduce mod c
        return 1 % c

    if max(a, b, c) < POWMOD_GMP_SIZE:
        return pow(a, b, c)

    else:
        return int(gmpy2.powmod(a, b, c))


def crt_coefficient(p, q):
    """
    return crt coefficient
    """
    tq = gmpy2.invert(p, q)
    tp = gmpy2.invert(q, p)
    return tp * q, tq * p


def powmod_crt(x, d, n, p, q, cp, cq):
    """
    return int: (x ** d) % n, computed via the Chinese Remainder Theorem
    """

    rp = gmpy2.powmod(x, d % (p - 1), p)
    rq = gmpy2.powmod(x, d % (q - 1), q)
    return int((rp * cp + rq * cq) % n)


def invert(a, b):
    """return int: x, where a * x == 1 mod b"""
    x = int(gmpy2.invert(a, b))

    if x == 0:
        raise ZeroDivisionError("invert(a, b) no inverse exists")

    return x


def getprimeover(n):
    """return a random n-bit prime number"""
    r = gmpy2.mpz(random.SystemRandom().getrandbits(n))
    r = gmpy2.bit_set(r, n - 1)

    return int(gmpy2.next_prime(r))


def isqrt(n):
    """return the integer square root of N"""

    return int(gmpy2.isqrt(n))


def is_prime(n):
    """
    true if n is probably a prime, false otherwise
    :param n:
    :return:
    """
    return gmpy2.is_prime(int(n))


def legendre(a, p):
    """Legendre symbol a|p via Euler's criterion (p an odd prime)."""
    return pow(a, (p - 1) // 2, p)


def tonelli(n, p):
    """Tonelli-Shanks: return r with r*r == n (mod p); n must be a quadratic residue."""
    # assert legendre(n, p) == 1, "not a square (mod p)"
    q = p - 1
    s = 0
    while q % 2 == 0:
        q //= 2
        s += 1
    if s == 1:
        return pow(n, (p + 1) // 4, p)
    for z in range(2, p):
        if p - 1 == legendre(z, p):
            break
    c = pow(z, q, p)
    r = pow(n, (q + 1) // 2, p)
    t = pow(n, q, p)
    m = s
    while (t - 1) % p != 0:
        t2 = (t * t) % p
        for i in range(1, m):
            if (t2 - 1) % p == 0:
                break
            t2 = (t2 * t2) % p
        b = pow(c, 1 << (m - i - 1), p)
        r = (r * b) % p
        c = (b * b) % p
        t = (t * c) % p
        m = i
    return r


def gcd(a, b):
    return int(gmpy2.gcd(a, b))


def next_prime(n):
    return int(gmpy2.next_prime(n))


def mpz(n):
    return gmpy2.mpz(n)
# maybe need wrap?
FPBlock = Tensor

T = typing.TypeVar("T", bound=PaillierEncryptedNumber)

UnEncryptedNumeric = typing.Union[int, float, FPBlock]
EncryptedNumeric = typing.Union[PaillierEncryptedNumber, "PaillierBlock[T]"]
Numeric = typing.Union[UnEncryptedNumeric, EncryptedNumeric]


class PaillierBlock(typing.Generic[T], PHEBlockABC):
    """
    use list of list to mimic tensor
    """

    def __init__(self, inner: typing.List[typing.List[T]]) -> None:
        self._xsize = len(inner)
        self._ysize = len(inner[0])
        self.shape = (self._xsize, self._ysize)
        self._inner = inner

    @typing.overload
    def __getitem__(self, item: typing.Tuple[int, int]) -> T:
        ...

    @typing.overload
    def __getitem__(self, item: typing.Tuple[int, slice]) -> "PaillierBlock[T]":
        ...

    @typing.overload
    def __getitem__(self, item: typing.Tuple[slice, int]) -> "PaillierBlock[T]":
        ...

    @typing.overload
    def __getitem__(self, item: typing.Tuple[slice, slice]) -> "PaillierBlock[T]":
        ...

    @typing.overload
    def __getitem__(self, item: slice) -> "PaillierBlock[T]":
        ...

    def __getitem__(self, item):
        if isinstance(item, tuple) and len(item) == 2:
            xind, yind = item
            if isinstance(xind, int):
                return self._inner[xind][yind]
            elif isinstance(xind, slice):
                if isinstance(yind, slice):
                    return PaillierBlock([row[yind] for row in self._inner[xind]])
                if isinstance(yind, int):
                    return PaillierBlock([[row[yind]] for row in self._inner[xind]])
        elif isinstance(item, slice):
            return PaillierBlock(self._inner[item])
        return NotImplemented

    def _binary_op(
        self, other, tensor_types: typing.Tuple, scale_types: typing.Tuple, op
    ):
        """Element-wise ``op`` against a same-shape block or a broadcast scalar."""
        # BUGFIX: the original appended an empty row and then assigned
        # out[i][j], which raises IndexError on the first element; build
        # each row by appending (here via comprehensions) instead.
        if isinstance(other, tensor_types):
            assert self.shape == other.shape
            out = [
                [op(self[i, j], other[i, j]) for j in range(self._ysize)]
                for i in range(self._xsize)
            ]
            return PaillierBlock[T](out)
        elif isinstance(other, scale_types):
            out = [
                [op(self[i, j], other) for j in range(self._ysize)]
                for i in range(self._xsize)
            ]
            return PaillierBlock[T](out)
        return NotImplemented

    def _binary_paillier_not_supported_op(self, other, op):
        return self._binary_op(other, (FPBlock,), (int, float), op)

    def _binary_paillier_supported_op(self, other, op):
        return self._binary_op(
            other, (PaillierBlock, FPBlock), (int, float, PaillierEncryptedNumber), op
        )

    def __add__(self, other: Numeric) -> "PaillierBlock":
        return self._binary_paillier_supported_op(other, operator.add)

    def __radd__(self, other: Numeric) -> "PaillierBlock":
        return self._binary_paillier_supported_op(other, operator.add)

    def __sub__(self, other: Numeric) -> "PaillierBlock":
        return self._binary_paillier_supported_op(other, operator.sub)

    def __rsub__(self, other: Numeric) -> "PaillierBlock":
        return self._binary_paillier_supported_op(other, lambda x, y: x.__rsub__(y))

    def __mul__(self, other: UnEncryptedNumeric) -> "PaillierBlock":
        return self._binary_paillier_not_supported_op(other, operator.mul)

    def __rmul__(self, other: UnEncryptedNumeric) -> "PaillierBlock":
        return self._binary_paillier_not_supported_op(other, lambda x, y: x.__rmul__(y))

    def __matmul__(self, other: FPBlock) -> "PaillierBlock":
        """self @ other where other is a plaintext tensor block."""
        if not isinstance(other, FPBlock):
            return NotImplemented
        assert self.shape[1] == other.shape[0]
        out = []
        for i in range(self.shape[0]):
            # BUGFIX: rows were appended empty and then indexed; append the
            # accumulated dot product instead.
            row = []
            for j in range(other.shape[1]):
                c = self[i, 0] * other[0, j]
                for k in range(1, self.shape[1]):
                    c += self[i, k] * other[k, j]
                row.append(c)
            out.append(row)
        return PaillierBlock(out)

    def __rmatmul__(self, other: FPBlock) -> "PaillierBlock":
        """other @ self where other is a plaintext tensor block."""
        if not isinstance(other, FPBlock):
            return NotImplemented
        assert self.shape[0] == other.shape[1]
        out = []
        for i in range(other.shape[0]):
            row = []
            for j in range(self.shape[1]):
                c = other[i, 0] * self[0, j]
                for k in range(1, other.shape[1]):
                    c += other[i, k] * self[k, j]
                row.append(c)
            out.append(row)
        return PaillierBlock(out)

    def serialize(self) -> bytes:
        return pickle.dumps(self._inner)

    def T(self) -> "PHEBlockABC":
        # todo: transpose could be lazy
        # BUGFIX: the original rebuilt rows in the same [x][y] order, so T()
        # returned an untransposed copy; swap the iteration order.
        return PaillierBlock(
            [
                [self._inner[x][y] for x in range(self._xsize)]
                for y in range(self._ysize)
            ]
        )


class BlockPaillierEncryptor(PHEBlockEncryptorABC):
    def __init__(self, pubkey: PaillierPublicKey) -> None:
        self._pubkey = pubkey

    def encrypt(self, tensor: FPBlock) -> PaillierBlock:
        return PaillierBlock(
            [[self._pubkey.encrypt(x) for x in row] for row in tensor.tolist()],
        )


class BlockPaillierDecryptor(PHEBlockDecryptorABC):
    def __init__(self, prikey: PaillierPrivateKey) -> None:
        self._prikey = prikey

    def decrypt(self, tensor: PaillierBlock) -> FPBlock:
        return FPBlock(
            [[self._prikey.decrypt(x) for x in row] for row in tensor._inner]
        )


class BlockPaillierCipher(PHEBlockCipherABC):
    @classmethod
    def keygen(
        cls, n_length=1024, **kwargs
    ) -> typing.Tuple[BlockPaillierEncryptor, BlockPaillierDecryptor]:
        pubkey, prikey = PaillierKeypair.generate_keypair(n_length=n_length)
        return (BlockPaillierEncryptor(pubkey), BlockPaillierDecryptor(prikey))
import typing

from ...abc.tensor import PHECipherABC, PHEDecryptorABC, PHEEncryptorABC, PHETensorABC


class Local:
    """Mixin marking a tensor whose single block lives in local memory."""

    @property
    def block(self):
        ...

    def is_distributed(self):
        return False


def phe_tensor_metaclass(fp_cls):
    """Build a metaclass that equips a PHE tensor class wrapping blocks of ``fp_cls``."""

    class PHETensorMetaclass(type):
        def __new__(cls, name, bases, dict):
            phe_cls = super().__new__(cls, name, (*bases, Local), dict)

            def __init__(self, block) -> None:
                self._block = block
                self._is_transpose = False

            setattr(phe_cls, "__init__", __init__)

            @property
            def shape(self):
                return self._block.shape

            setattr(phe_cls, "shape", shape)

            @property
            def T(self) -> phe_cls:
                # lazy transpose: share the block, flip the flag
                transposed = phe_cls(self._block)
                transposed._is_transpose = not self._is_transpose
                return transposed

            setattr(phe_cls, "T", T)

            def serialize(self) -> bytes:
                # todo: impl me
                ...

            setattr(phe_cls, "serialize", serialize)

            def __add__(self, other):
                if isinstance(other, phe_cls):
                    other = other._block

                if isinstance(other, (phe_cls, fp_cls)):
                    return phe_cls(self._block + other)
                elif isinstance(other, (int, float)):
                    return phe_cls(self._block + other)
                else:
                    return NotImplemented

            def __radd__(self, other):
                # BUGFIX: __add__(other, self) unwrapped the wrong operand and
                # always failed for scalar + tensor; addition commutes, so
                # delegate with self first.
                return __add__(self, other)

            setattr(phe_cls, "__add__", __add__)
            setattr(phe_cls, "__radd__", __radd__)

            def __sub__(self, other):
                if isinstance(other, phe_cls):
                    other = other._block

                if isinstance(other, (phe_cls, fp_cls)):
                    return phe_cls(self._block - other)
                elif isinstance(other, (int, float)):
                    return phe_cls(self._block - other)
                else:
                    return NotImplemented

            def __rsub__(self, other):
                # BUGFIX: __sub__(other, self) failed for scalar - tensor;
                # compute other - self as (self * -1) + other instead.
                return __add__(__mul__(self, -1), other)

            setattr(phe_cls, "__sub__", __sub__)
            setattr(phe_cls, "__rsub__", __rsub__)

            def __mul__(self, other):
                if isinstance(other, fp_cls):
                    return phe_cls(self._block * other)
                elif isinstance(other, (int, float)):
                    return phe_cls(self._block * other)
                else:
                    return NotImplemented

            def __rmul__(self, other):
                # BUGFIX: __mul__(other, self) unwrapped the wrong operand;
                # multiplication by a plaintext commutes.
                return __mul__(self, other)

            setattr(phe_cls, "__mul__", __mul__)
            setattr(phe_cls, "__rmul__", __rmul__)

            def __matmul__(self, other):
                if isinstance(other, fp_cls):
                    return phe_cls(self._block @ other)
                return NotImplemented

            def __rmatmul__(self, other):
                if isinstance(other, fp_cls):
                    return phe_cls(other @ self._block)
                return NotImplemented

            setattr(phe_cls, "__matmul__", __matmul__)
            setattr(phe_cls, "__rmatmul__", __rmatmul__)

            return phe_cls

    return PHETensorMetaclass


def phe_tensor_encryptor_metaclass(phe_cls, fp_cls):
    """Build a metaclass for the encryptor pairing ``phe_cls`` with ``fp_cls``."""

    class PHETensorEncryptorMetaclass(type):
        def __new__(cls, name, bases, dict):
            phe_encrypt_cls = super().__new__(cls, name, bases, dict)

            def __init__(self, block_encryptor):
                self._block_encryptor = block_encryptor

            def encrypt(self, tensor: fp_cls) -> phe_cls:
                return phe_cls(self._block_encryptor.encrypt(tensor))

            setattr(phe_encrypt_cls, "__init__", __init__)
            setattr(phe_encrypt_cls, "encrypt", encrypt)
            return phe_encrypt_cls

    return PHETensorEncryptorMetaclass


def phe_tensor_decryptor_metaclass(phe_cls, fp_cls):
    """Build a metaclass for the decryptor pairing ``phe_cls`` with ``fp_cls``."""

    class PHETensorDecryptorMetaclass(type):
        def __new__(cls, name, bases, dict):
            phe_decrypt_cls = super().__new__(cls, name, bases, dict)

            def __init__(self, block_decryptor) -> None:
                self._block_decryptor = block_decryptor

            def decrypt(self, tensor: phe_cls) -> fp_cls:
                return self._block_decryptor.decrypt(tensor._block)

            setattr(phe_decrypt_cls, "__init__", __init__)
            setattr(phe_decrypt_cls, "decrypt", decrypt)
            return phe_decrypt_cls

    return PHETensorDecryptorMetaclass


def phe_tensor_cipher_metaclass(
    phe_cls,
    phe_encrypt_cls,
    phe_decrypt_cls,
    block_cipher,
):
    """Build a metaclass whose classes expose keygen() returning (encryptor, decryptor)."""

    class PHETensorCipherMetaclass(type):
        def __new__(cls, name, bases, dict):
            phe_cipher_cls = super().__new__(cls, name, bases, dict)

            @classmethod
            def keygen(cls, **kwargs) -> typing.Tuple[phe_encrypt_cls, phe_decrypt_cls]:
                block_encrytor, block_decryptor = block_cipher.keygen(**kwargs)
                return (
                    phe_encrypt_cls(block_encrytor),
                    phe_decrypt_cls(block_decryptor),
                )

            setattr(phe_cipher_cls, "keygen", keygen)
            return phe_cipher_cls

    return PHETensorCipherMetaclass
class Distributed:
    """Mixin for tensors whose data is partitioned into blocks held in a
    distributed table keyed by block id."""

    @property
    def blocks(self) -> CTableABC:
        ...

    def is_distributed(self):
        # Lets callers distinguish distributed tensors from local ones.
        return True


class FPTensorDistributed(FPTensorProtocol, Distributed):
    """
    Demo of a distributed fixed-precision tensor.

    Blocks are assumed to be stacked vertically: each block holds a
    contiguous range of rows of the full matrix, keyed by block id.
    """

    def __init__(self, blocks_table, shape=None):
        """
        Args:
            blocks_table: table storing blocks as (blockid, block) pairs.
            shape: optional (rows, cols). When omitted it is computed from
                the blocks, which triggers a collect — pass it when known.
        """
        self._blocks_table = blocks_table

        if shape is None:
            # collect() yields (blockid, shape) pairs.
            # BUGFIX: the shape is pair[1]; the old code summed the block
            # *ids* (pair[0]) and stored a whole torch.Size as the column
            # count.  Indexing now matches PHETensorDistributed below.
            shapes = list(self._blocks_table.mapValues(lambda cb: cb.shape).collect())
            self.shape = (sum(s[1][0] for s in shapes), shapes[0][1][1])
        else:
            self.shape = shape

    @property
    def blocks(self):
        return self._blocks_table

    def _binary_op(self, other, func_name):
        """Apply ``self.<func_name>(other)`` block by block."""
        if isinstance(other, FPTensorDistributed):
            # BUGFIX: join from self's table so operand order is preserved;
            # joining from `other` flipped non-commutative ops such as sub.
            return FPTensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                ),
                self.shape,
            )
        elif isinstance(other, (int, float)):
            return FPTensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other)),
                self.shape,
            )
        return NotImplemented

    def collect(self):
        """Gather all blocks to the driver, ordered by block id, and
        concatenate them row-wise into one torch tensor."""
        blocks = sorted(self._blocks_table.collect())
        return torch.cat([pair[1] for pair in blocks])

    def __add__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__add__")

    def __radd__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__radd__")

    def __sub__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__sub__")

    def __rsub__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__rsub__")

    def __mul__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__mul__")

    def __rmul__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__rmul__")

    def __matmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed":
        """Matrix-vector product; *other* must be one-dimensional.

        BUGFIX: the old implementation closed over
        ``other._blocks_table.collect()`` (an iterator of pairs, not a
        tensor), called ``mapValues()`` with no argument and returned None.
        """
        assert self.shape[1] == other.shape[0]
        # one-dimensional right operand only, as in the original contract
        assert len(other.shape) == 1
        rhs = other.collect()  # (m,) vector, broadcast to every block
        return FPTensorDistributed(
            self._blocks_table.mapValues(lambda cb: cb @ rhs),
            (self.shape[0],),
        )

    def __rmatmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed":
        # TODO: requires per-block row offsets to compute vec @ matrix.
        # Raising keeps the failure explicit instead of silently
        # returning None as before.
        raise NotImplementedError("FPTensorDistributed.__rmatmul__ is not implemented yet")

    def __federation_hook__(self, ctx, key, parties):
        deserializer = FPTensorFederationDeserializer(key)
        # 1. remote the deserializer with the ordinary object channel
        ctx._push(parties, key, deserializer)
        # 2. remote the underlying table under the derived key
        ctx._push(parties, deserializer.table_key, self._blocks_table)


class PHETensorDistributed(PHETensorABC):
    """Distributed PHE (Paillier) encrypted tensor; blocks are stored as
    (blockid, encrypted_block) pairs and stacked vertically."""

    def __init__(self, blocks_table, shape=None):
        """
        Args:
            blocks_table: table storing (blockid, encrypted_block) pairs.
            shape: optional (rows, cols); computed from the blocks when
                omitted (triggers a collect).
        """
        self._blocks_table = blocks_table
        self._is_transpose = False

        if shape is None:
            # collect() yields (blockid, shape) pairs; sum block row counts.
            shapes = list(self._blocks_table.mapValues(lambda cb: cb.shape).collect())
            self.shape = (sum(s[1][0] for s in shapes), shapes[0][1][1])
        else:
            self.shape = shape

    def collect(self):
        """Gather blocks to the driver, ordered by block id, row-concatenated."""
        blocks = sorted(self._blocks_table.collect())
        return torch.cat([pair[1] for pair in blocks])

    def __add__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__add__")

    def __radd__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__radd__")

    def __sub__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__sub__")

    def __rsub__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__rsub__")

    def __mul__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        # ciphertext * ciphertext is not supported, hence the limited op
        return self._binary_op_limited(other, "__mul__")

    def __rmul__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op_limited(other, "__rmul__")

    def __matmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed":
        # TODO: implement encrypted matmul; explicit failure instead of
        # silently returning None.
        raise NotImplementedError("PHETensorDistributed.__matmul__ is not implemented yet")

    def __rmatmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed":
        # TODO: implement encrypted rmatmul.
        raise NotImplementedError("PHETensorDistributed.__rmatmul__ is not implemented yet")

    def T(self) -> "PHETensorDistributed":
        """Lazy transpose: the blocks are shared, only the flag flips."""
        transposed = PHETensorDistributed(self._blocks_table)
        transposed._is_transpose = not self._is_transpose
        return transposed

    def serialize(self):
        # The table itself is the wire representation; the transpose flag
        # travels through __getstates__/the federation deserializer.
        return self._blocks_table

    def deserialize(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getstates__(self):
        # NOTE(review): likely intended to be __getstate__; kept as-is
        # because the serialization machinery may look up this exact name —
        # confirm before renaming.
        return {"_is_transpose": self._is_transpose}

    def _binary_op(self, other, func_name):
        """Apply ``self.<func_name>(other)`` block by block (full variant:
        accepts PHE tensors, FP tensors and scalars)."""
        if isinstance(other, (FPTensorDistributed, PHETensorDistributed)):
            return PHETensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                ),
                self.shape,
            )
        elif isinstance(other, (int, float)):
            # BUGFIX: dispatch on func_name; the old code always called
            # __add__, so e.g. `phe_tensor - 2` silently computed `+ 2`.
            return PHETensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other)),
                self.shape,
            )
        return NotImplemented

    def _binary_op_limited(self, other, func_name):
        """Like _binary_op but rejects PHE operands (used for mul, where
        ciphertext * ciphertext is undefined)."""
        if isinstance(other, FPTensorDistributed):
            return PHETensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                ),
                self.shape,
            )
        elif isinstance(other, (int, float)):
            # BUGFIX: dispatch on func_name instead of hard-coded __add__.
            return PHETensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other)),
                self.shape,
            )
        return NotImplemented

    def __federation_hook__(self, ctx, key, parties):
        deserializer = PHETensorFederationDeserializer(key, self._is_transpose)
        # 1. remote the deserializer with the ordinary object channel
        ctx._push(parties, key, deserializer)
        # 2. remote the underlying table under the derived key
        ctx._push(parties, deserializer.table_key, self._blocks_table)


class PaillierPHEEncryptorDistributed(PHEEncryptorABC):
    """Encrypts an FP distributed tensor block by block."""

    def __init__(self, block_encryptor) -> None:
        self._block_encryptor = block_encryptor

    def encrypt(self, tensor: FPTensorDistributed) -> PHETensorDistributed:
        # Pass the known shape through to avoid a recomputing collect.
        return PHETensorDistributed(
            tensor._blocks_table.mapValues(lambda x: self._block_encryptor.encrypt(x)),
            tensor.shape,
        )


class PaillierPHEDecryptorDistributed(PHEDecryptorABC):
    """Decrypts a PHE distributed tensor block by block."""

    def __init__(self, block_decryptor) -> None:
        self._block_decryptor = block_decryptor

    def decrypt(self, tensor: PHETensorDistributed) -> FPTensorDistributed:
        return FPTensorDistributed(
            tensor._blocks_table.mapValues(lambda x: self._block_decryptor.decrypt(x)),
            tensor.shape,
        )


class PaillierPHECipherDistributed(PHECipherABC):
    """Factory producing a matched (encryptor, decryptor) pair."""

    @classmethod
    def keygen(
        cls, **kwargs
    ) -> typing.Tuple[PaillierPHEEncryptorDistributed, PaillierPHEDecryptorDistributed]:
        from ..blocks.cpu_paillier_block import BlockPaillierCipher

        block_encryptor, block_decryptor = BlockPaillierCipher.keygen(**kwargs)
        return (
            PaillierPHEEncryptorDistributed(block_encryptor),
            PaillierPHEDecryptorDistributed(block_decryptor),
        )


class PHETensorFederationDeserializer(FederationDeserializer):
    """Rebuilds a PHETensorDistributed on the receiving party."""

    def __init__(self, key, is_transpose) -> None:
        self.table_key = self.make_frac_key(key, "table")
        self.is_transpose = is_transpose

    def do_deserialize(self, ctx: Context, party: Party) -> PHETensorDistributed:
        table = ctx._pull([party], self.table_key)[0]
        tensor = PHETensorDistributed(table)
        tensor._is_transpose = self.is_transpose
        return tensor


class FPTensorFederationDeserializer(FederationDeserializer):
    """Rebuilds an FPTensorDistributed on the receiving party."""

    def __init__(self, key) -> None:
        self.table_key = self.make_frac_key(key, "table")

    def do_deserialize(self, ctx: Context, party: Party) -> FPTensorDistributed:
        table = ctx._pull([party], self.table_key)[0]
        return FPTensorDistributed(table)
import torch

from ..blocks.multithread_cpu_paillier_block import BlockPaillierCipher
from ._metaclass import (
    phe_tensor_cipher_metaclass,
    phe_tensor_decryptor_metaclass,
    phe_tensor_encryptor_metaclass,
    phe_tensor_metaclass,
)

# The local (single-process) fixed-precision tensor is a plain torch.Tensor.
FPTensorLocal = torch.Tensor


class PHETensorLocal(metaclass=phe_tensor_metaclass(FPTensorLocal)):
    # Body intentionally empty: __init__, serialize and the arithmetic
    # dunders (__add__, __mul__, __matmul__, ...) are injected by the
    # metaclass built from FPTensorLocal.
    ...


class PaillierPHEEncryptorLocal(
    metaclass=phe_tensor_encryptor_metaclass(PHETensorLocal, FPTensorLocal)
):
    # __init__(block_encryptor) and encrypt(fp) -> PHETensorLocal are
    # injected by the metaclass.
    ...


class PaillierPHEDecryptorLocal(
    metaclass=phe_tensor_decryptor_metaclass(PHETensorLocal, FPTensorLocal)
):
    # __init__(block_decryptor) and decrypt(phe) -> FPTensorLocal are
    # injected by the metaclass.
    ...


class PaillierPHECipherLocal(
    metaclass=phe_tensor_cipher_metaclass(
        PHETensorLocal,
        PaillierPHEEncryptorLocal,
        PaillierPHEDecryptorLocal,
        BlockPaillierCipher,
    )
):
    # keygen(**kwargs) -> (encryptor, decryptor) is injected by the
    # metaclass, delegating to the multithreaded CPU Paillier block cipher.
    ...
import typing
from typing import Union

from ..._federation import FederationDeserializer
from ..._tensor import Context, Party
from ...abc.tensor import (
    FPTensorProtocol,
    PHECipherABC,
    PHEDecryptorABC,
    PHEEncryptorABC,
    PHETensorABC,
)

Numeric = typing.Union[int, float]


class FPTensorDistributed(FPTensorProtocol):
    """
    Demo of a distributed fixed-precision tensor (row-partitioned variant).

    Blocks are stored in a table keyed by block id; this variant tracks no
    global shape.
    """

    def __init__(self, blocks_table):
        """blocks_table stores blocks as (blockid, block) pairs."""
        self._blocks_table = blocks_table

    def _binary_op(self, other, func_name):
        """Apply ``self.<func_name>(other)`` block by block."""
        if isinstance(other, FPTensorDistributed):
            # BUGFIX: join from self's table so operand order is preserved;
            # joining from `other` flipped non-commutative ops such as sub.
            return FPTensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                )
            )
        elif isinstance(other, (int, float)):
            return FPTensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other))
            )
        return NotImplemented

    def __add__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__add__")

    def __radd__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__radd__")

    def __sub__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__sub__")

    def __rsub__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__rsub__")

    def __mul__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__mul__")

    def __rmul__(self, other: Union["FPTensorDistributed", int, float]) -> "FPTensorDistributed":
        return self._binary_op(other, "__rmul__")

    def __matmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed":
        # TODO: implement; raising keeps the failure explicit instead of
        # silently returning None as before.
        raise NotImplementedError("FPTensorDistributed.__matmul__ is not implemented yet")

    def __rmatmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed":
        # TODO: implement.
        raise NotImplementedError("FPTensorDistributed.__rmatmul__ is not implemented yet")

    def __federation_hook__(self, ctx, key, parties):
        deserializer = FPTensorFederationDeserializer(key)
        # 1. remote the deserializer with the ordinary object channel
        ctx._push(parties, key, deserializer)
        # 2. remote the underlying table under the derived key
        ctx._push(parties, deserializer.table_key, self._blocks_table)


class PHETensorDistributed(PHETensorABC):
    """Distributed PHE (Paillier) encrypted tensor; blocks stored as
    (blockid, encrypted_block) pairs."""

    def __init__(self, blocks_table) -> None:
        self._blocks_table = blocks_table
        self._is_transpose = False

    def __add__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__add__")

    def __radd__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__radd__")

    def __sub__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__sub__")

    def __rsub__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op(other, "__rsub__")

    def __mul__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        # ciphertext * ciphertext is not supported, hence the limited op
        return self._binary_op_limited(other, "__mul__")

    def __rmul__(self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float]) -> "PHETensorDistributed":
        return self._binary_op_limited(other, "__rmul__")

    def __matmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed":
        # TODO: implement encrypted matmul.
        raise NotImplementedError("PHETensorDistributed.__matmul__ is not implemented yet")

    def __rmatmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed":
        # TODO: implement encrypted rmatmul.
        raise NotImplementedError("PHETensorDistributed.__rmatmul__ is not implemented yet")

    def T(self) -> "PHETensorDistributed":
        """Lazy transpose: the blocks are shared, only the flag flips."""
        transposed = PHETensorDistributed(self._blocks_table)
        transposed._is_transpose = not self._is_transpose
        return transposed

    def serialize(self):
        # The table itself is the wire representation; the transpose flag
        # travels through __getstates__/the federation deserializer.
        return self._blocks_table

    def deserialize(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getstates__(self):
        # NOTE(review): likely intended to be __getstate__; kept as-is
        # because the serialization machinery may look up this exact name —
        # confirm before renaming.
        return {"_is_transpose": self._is_transpose}

    def _binary_op(self, other, func_name):
        """Apply ``self.<func_name>(other)`` block by block (full variant:
        accepts PHE tensors, FP tensors and scalars)."""
        if isinstance(other, (FPTensorDistributed, PHETensorDistributed)):
            return PHETensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                )
            )
        elif isinstance(other, (int, float)):
            # BUGFIX: dispatch on func_name; the old code always called
            # __add__, so e.g. `phe_tensor - 2` silently computed `+ 2`.
            return PHETensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other))
            )
        return NotImplemented

    def _binary_op_limited(self, other, func_name):
        """Like _binary_op but rejects PHE operands (used for mul, where
        ciphertext * ciphertext is undefined)."""
        if isinstance(other, FPTensorDistributed):
            return PHETensorDistributed(
                self._blocks_table.join(
                    other._blocks_table, lambda x, y: getattr(x, func_name)(y)
                )
            )
        elif isinstance(other, (int, float)):
            # BUGFIX: dispatch on func_name instead of hard-coded __add__.
            return PHETensorDistributed(
                self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other))
            )
        return NotImplemented

    def __federation_hook__(self, ctx, key, parties):
        deserializer = PHETensorFederationDeserializer(key, self._is_transpose)
        # 1. remote the deserializer with the ordinary object channel
        ctx._push(parties, key, deserializer)
        # 2. remote the underlying table under the derived key
        ctx._push(parties, deserializer.table_key, self._blocks_table)


class PaillierPHEEncryptorDistributed(PHEEncryptorABC):
    """Encrypts an FP distributed tensor block by block."""

    def __init__(self, block_encryptor) -> None:
        self._block_encryptor = block_encryptor

    def encrypt(self, tensor: FPTensorDistributed) -> PHETensorDistributed:
        return PHETensorDistributed(
            tensor._blocks_table.mapValues(lambda x: self._block_encryptor.encrypt(x))
        )


class PaillierPHEDecryptorDistributed(PHEDecryptorABC):
    """Decrypts a PHE distributed tensor block by block."""

    def __init__(self, block_decryptor) -> None:
        self._block_decryptor = block_decryptor

    def decrypt(self, tensor: PHETensorDistributed) -> FPTensorDistributed:
        return FPTensorDistributed(
            tensor._blocks_table.mapValues(lambda x: self._block_decryptor.decrypt(x))
        )


class PaillierPHECipherDistributed(PHECipherABC):
    """Factory producing a matched (encryptor, decryptor) pair."""

    @classmethod
    def keygen(
        cls, **kwargs
    ) -> typing.Tuple[PaillierPHEEncryptorDistributed, PaillierPHEDecryptorDistributed]:
        from ..blocks.cpu_paillier_block import BlockPaillierCipher

        block_encryptor, block_decryptor = BlockPaillierCipher.keygen(**kwargs)
        return (
            PaillierPHEEncryptorDistributed(block_encryptor),
            PaillierPHEDecryptorDistributed(block_decryptor),
        )


class PHETensorFederationDeserializer(FederationDeserializer):
    """Rebuilds a PHETensorDistributed on the receiving party."""

    def __init__(self, key, is_transpose) -> None:
        self.table_key = self.make_frac_key(key, "table")
        self.is_transpose = is_transpose

    def do_deserialize(self, ctx: Context, party: Party) -> PHETensorDistributed:
        table = ctx._pull([party], self.table_key)[0]
        tensor = PHETensorDistributed(table)
        tensor._is_transpose = self.is_transpose
        return tensor


class FPTensorFederationDeserializer(FederationDeserializer):
    """Rebuilds an FPTensorDistributed on the receiving party."""

    def __init__(self, key) -> None:
        self.table_key = self.make_frac_key(key, "table")

    def do_deserialize(self, ctx: Context, party: Party) -> FPTensorDistributed:
        table = ctx._pull([party], self.table_key)[0]
        return FPTensorDistributed(table)
def broadcast_matmul(matrix, bc_matrix):
    """Right-multiply every block of *matrix* by the broadcast operand.

    *bc_matrix* is sent whole to each partition, so each stored block is
    replaced by ``block @ bc_matrix``; the result is the transformed table.
    """

    def _matmul_block(block):
        return block @ bc_matrix

    return matrix.blocks.mapValues(_matmul_block)