Skip to content

Commit

Permalink
feat: redesign Tensor and add ops
Browse files Browse the repository at this point in the history
Signed-off-by: sage0615 <sagewb@outlook.com>
  • Loading branch information
sage0615 committed Oct 13, 2022
1 parent ed30121 commit eccc30c
Show file tree
Hide file tree
Showing 32 changed files with 1,215 additions and 330 deletions.
File renamed without changes.
Empty file added python/fate/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions python/fate/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
#

from .context import Context
from .unify import Backend, Device
from .unify import Backend, device

__all__ = ["Backend", "Device", "Context"]
__all__ = ["Backend", "device", "Context"]
5 changes: 4 additions & 1 deletion python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def namespace(self):

def __del__(self):
if self._need_cleanup:
self.destroy()
try:
self.destroy()
except:
pass

def __str__(self):
return f"<Table {self._namespace}|{self._name}|{self._partitions}|{self._need_cleanup}>"
Expand Down
54 changes: 3 additions & 51 deletions python/fate/arch/context/_cipher.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,14 @@
from enum import Enum
from typing import Tuple

from fate.interface import CipherKit as CipherKitInterface
from fate.interface import PHECipher as PHECipherInterface

from ..tensor._tensor import PHEDecryptor, PHEEncryptor
from ..unify import Backend, Device
from ..tensor._phe import PHECipher
from ..unify import Backend, device


class CipherKit(CipherKitInterface):
def __init__(self, backend: Backend, device: Device) -> None:
def __init__(self, backend: Backend, device: device) -> None:
self.backend = backend
self.device = device

@property
def phe(self):
return PHECipher(self.backend, self.device)


class PHEKind(Enum):
AUTO = "auto"
PAILLIER = "Paillier"
RUST_PAILLIER = "rust_paillier"
INTEL_PAILLIER = "intel_paillier"


class PHECipher(PHECipherInterface):
def __init__(self, backend: Backend, device: Device) -> None:
self.backend = backend
self.device = device

def keygen(
self, kind: PHEKind = PHEKind.AUTO, options={}
) -> Tuple["PHEEncryptor", "PHEDecryptor"]:
if kind == PHEKind.AUTO or PHEKind.PAILLIER:

if self.backend == Backend.LOCAL:
from ..tensor.impl.tensor.multithread_cpu_tensor import (
PaillierPHECipherLocal,
)

key_length = options.get("key_length", 1024)
encryptor, decryptor = PaillierPHECipherLocal().keygen(
key_length=key_length
)
return PHEEncryptor(encryptor), PHEDecryptor(decryptor)

if self.backend in {Backend.STANDALONE, Backend.SPARK, Backend.EGGROLL}:
from ..tensor.impl.tensor.distributed import (
PaillierPHECipherDistributed,
)

key_length = options.get("key_length", 1024)
encryptor, decryptor = PaillierPHECipherDistributed().keygen(
key_length=key_length
)
return PHEEncryptor(encryptor), PHEDecryptor(decryptor)

raise NotImplementedError(
f"keygen for kind<{kind}>-distributed<{self.backend}>-device<{self.device}> is not implemented"
)
17 changes: 15 additions & 2 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from fate.interface import MetricMeta as MetricMetaInterface
from fate.interface import Metrics, PartyMeta, Summary

from ..unify import Backend, Device
from ..unify import Backend, device
from ._cipher import CipherKit
from ._federation import GC, Parties, Party
from ._io import ReadKit, WriteKit
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(
self,
context_name: Optional[str] = None,
backend: Backend = Backend.LOCAL,
device: Device = Device.CPU,
device: device = device.CPU,
computing: Optional[ComputingEngine] = None,
federation: Optional[FederationEngine] = None,
summary: Summary = DummySummary(),
Expand Down Expand Up @@ -199,6 +199,14 @@ def __init__(

self._gc = GC()

@property
def computing(self):
return self._get_computing()

@property
def federation(self):
return self._get_federation()

@contextmanager
def sub_ctx(self, namespace: str) -> Iterator["Context"]:
with self.namespace.into_subnamespace(namespace):
Expand Down Expand Up @@ -266,3 +274,8 @@ def _get_federation(self):
if self._federation is None:
raise RuntimeError(f"federation not set")
return self._federation

def _get_computing(self):
if self._computing is None:
raise RuntimeError(f"computing not set")
return self._computing
6 changes: 3 additions & 3 deletions python/fate/arch/context/_tensor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch

from ..tensor import FPTensor
from ..unify import Backend, Device
from ..tensor import Tensor as FPTensor
from ..unify import Backend, device


class TensorKit:
def __init__(self, computing, backend: Backend, device: Device) -> None:
def __init__(self, computing, backend: Backend, device: device) -> None:
self.computing = computing
self.backend = backend
self.device = device
Expand Down
6 changes: 2 additions & 4 deletions python/fate/arch/federation/eggroll/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(

self.get_gc: GarbageCollector = GarbageCollector()
self.remote_gc: GarbageCollector = GarbageCollector()
self.party = party
self.local_party = party
self.parties = parties
self._rsc = RollSiteContext(rs_session_id, rp_ctx=rp_ctx, options=options)
LOGGER.debug(f"[federation.eggroll]init federation context done")
Expand Down Expand Up @@ -131,9 +131,7 @@ def _push_with_exception_handle(rsc, v, name: str, tag: str, parties: List[Party
def _remote_exception_re_raise(f, p: PartyMeta):
try:
f.result()
LOGGER.debug(
f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} done"
)
LOGGER.debug(f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} done")
except Exception as e:
pid = os.getpid()
LOGGER.exception(
Expand Down
10 changes: 5 additions & 5 deletions python/fate/arch/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._tensor import FPTensor, PHETensor
from ._tensor import tensor, distributed_tensor
from ._ops import *
from ._unary_ops import *
from ._binary_ops import *

__all__ = [
"FPTensor",
"PHETensor",
]
__all__ = ["tensor", "distributed_tensor"]
181 changes: 181 additions & 0 deletions python/fate/arch/tensor/__tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from typing import Any, List, Union, overload

from fate.interface import FederationDeserializer as FederationDeserializerInterface
from fate.interface import FederationEngine, PartyMeta

from .abc.tensor import PHEDecryptorABC, PHEEncryptorABC, PHETensorABC


class FPTensor:
def __init__(self, tensor) -> None:
self._tensor = tensor

@property
def shape(self):
return self._tensor.shape

def __add__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__add__"):
return NotImplemented
return self._binary_op(other, self._tensor.__add__)

def __radd__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__radd__"):
return self.__add__(other)
return self._binary_op(other, self._tensor.__add__)

def __sub__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__sub__"):
return NotImplemented
return self._binary_op(other, self._tensor.__sub__)

def __rsub__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__rsub__"):
return self.__mul__(-1).__add__(other)
return self._binary_op(other, self._tensor.__rsub__)

def __mul__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__mul__"):
return NotImplemented
return self._binary_op(other, self._tensor.__mul__)

def __rmul__(self, other: Union["FPTensor", float, int]) -> "FPTensor":
if not hasattr(self._tensor, "__rmul__"):
return self.__mul__(other)
return self._binary_op(other, self._tensor.__rmul__)

def __matmul__(self, other: "FPTensor") -> "FPTensor":
if not hasattr(self._tensor, "__matmul__"):
return NotImplemented
if isinstance(other, FPTensor):
return FPTensor(self._tensor.__matmul__(other._tensor))
else:
return NotImplemented

def __rmatmul__(self, other: "FPTensor") -> "FPTensor":
if not hasattr(self._tensor, "__rmatmul__"):
return NotImplemented
if isinstance(other, FPTensor):
return FPTensor(self._tensor.__rmatmul__(other._tensor))
else:
return NotImplemented

def _binary_op(self, other, func):
if isinstance(other, FPTensor):
return FPTensor(func(other._tensor))
elif isinstance(other, (int, float)):
return FPTensor(func(other))
else:
return NotImplemented

@property
def T(self):
return FPTensor(self._tensor.T)

def __federation_hook__(
self,
federation: FederationEngine,
name: str,
tag: str,
parties: List[PartyMeta],
):
deserializer = FPTensorFederationDeserializer(name)
# 1. remote deserializer with objs
federation.push(deserializer, name, tag, parties)
# 2. remote table
federation.push(self._tensor, deserializer.table_key, tag, parties)


class PHETensor:
def __init__(self, tensor: PHETensorABC) -> None:
self._tensor = tensor

@property
def shape(self):
return self._tensor.shape

def __add__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__add__)

def __radd__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__radd__)

def __sub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__sub__)

def __rsub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__rsub__)

def __mul__(self, other: Union[FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__mul__)

def __rmul__(self, other: Union[FPTensor, int, float]) -> "PHETensor":
return self._binary_op(other, self._tensor.__rmul__)

def __matmul__(self, other: FPTensor) -> "PHETensor":
if isinstance(other, FPTensor):
return PHETensor(self._tensor.__matmul__(other._tensor))
else:
return NotImplemented

def __rmatmul__(self, other: FPTensor) -> "PHETensor":
if isinstance(other, FPTensor):
return PHETensor(self._tensor.__rmatmul__(other._tensor))
else:
return NotImplemented

def T(self) -> "PHETensor":
return PHETensor(self._tensor.T())

@overload
def decrypt(self, decryptor: "PHEDecryptor") -> FPTensor:
...

@overload
def decrypt(self, decryptor) -> Any:
...

def decrypt(self, decryptor):
return decryptor.decrypt(self)

def _binary_op(self, other, func):
if isinstance(other, (PHETensor, FPTensor)):
return PHETensor(func(other._tensor))
elif isinstance(other, (int, float)):
return PHETensor(func(other))
return NotImplemented

def __federation_hook__(self, ctx, key, parties):
deserializer = PHETensorFederationDeserializer(key)
# 1. remote deserializer with objs
ctx._push(parties, key, deserializer)
# 2. remote table
ctx._push(parties, deserializer.table_key, self._tensor)


class PHETensorFederationDeserializer(FederationDeserializerInterface):
def __init__(self, key) -> None:
self.table_key = f"__phetensor_{key}__"

def __do_deserialize__(
self,
federation: FederationEngine,
tag: str,
party: PartyMeta,
) -> PHETensor:
tensor = federation.pull(name=self.table_key, tag=tag, parties=[party])[0]
return PHETensor(tensor)


class FPTensorFederationDeserializer(FederationDeserializerInterface):
def __init__(self, key) -> None:
self.table_key = f"__tensor_{key}__"

def __do_deserialize__(
self,
federation: FederationEngine,
tag: str,
party: PartyMeta,
) -> FPTensor:
tensor = federation.pull(name=self.table_key, tag=tag, parties=[party])[0]
return FPTensor(tensor)
Loading

0 comments on commit eccc30c

Please sign in to comment.