-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: sage0615 <sagewb@outlook.com>
- Loading branch information
Showing
32 changed files
with
1,215 additions
and
330 deletions.
There are no files selected for viewing
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,14 @@ | ||
from enum import Enum | ||
from typing import Tuple | ||
|
||
from fate.interface import CipherKit as CipherKitInterface | ||
from fate.interface import PHECipher as PHECipherInterface | ||
|
||
from ..tensor._tensor import PHEDecryptor, PHEEncryptor | ||
from ..unify import Backend, Device | ||
from ..tensor._phe import PHECipher | ||
from ..unify import Backend, device | ||
|
||
|
||
class CipherKit(CipherKitInterface): | ||
def __init__(self, backend: Backend, device: Device) -> None: | ||
def __init__(self, backend: Backend, device: device) -> None: | ||
self.backend = backend | ||
self.device = device | ||
|
||
@property | ||
def phe(self): | ||
return PHECipher(self.backend, self.device) | ||
|
||
|
||
class PHEKind(Enum): | ||
AUTO = "auto" | ||
PAILLIER = "Paillier" | ||
RUST_PAILLIER = "rust_paillier" | ||
INTEL_PAILLIER = "intel_paillier" | ||
|
||
|
||
class PHECipher(PHECipherInterface): | ||
def __init__(self, backend: Backend, device: Device) -> None: | ||
self.backend = backend | ||
self.device = device | ||
|
||
def keygen( | ||
self, kind: PHEKind = PHEKind.AUTO, options={} | ||
) -> Tuple["PHEEncryptor", "PHEDecryptor"]: | ||
if kind == PHEKind.AUTO or PHEKind.PAILLIER: | ||
|
||
if self.backend == Backend.LOCAL: | ||
from ..tensor.impl.tensor.multithread_cpu_tensor import ( | ||
PaillierPHECipherLocal, | ||
) | ||
|
||
key_length = options.get("key_length", 1024) | ||
encryptor, decryptor = PaillierPHECipherLocal().keygen( | ||
key_length=key_length | ||
) | ||
return PHEEncryptor(encryptor), PHEDecryptor(decryptor) | ||
|
||
if self.backend in {Backend.STANDALONE, Backend.SPARK, Backend.EGGROLL}: | ||
from ..tensor.impl.tensor.distributed import ( | ||
PaillierPHECipherDistributed, | ||
) | ||
|
||
key_length = options.get("key_length", 1024) | ||
encryptor, decryptor = PaillierPHECipherDistributed().keygen( | ||
key_length=key_length | ||
) | ||
return PHEEncryptor(encryptor), PHEDecryptor(decryptor) | ||
|
||
raise NotImplementedError( | ||
f"keygen for kind<{kind}>-distributed<{self.backend}>-device<{self.device}> is not implemented" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from ._tensor import FPTensor, PHETensor | ||
from ._tensor import tensor, distributed_tensor | ||
from ._ops import * | ||
from ._unary_ops import * | ||
from ._binary_ops import * | ||
|
||
__all__ = [ | ||
"FPTensor", | ||
"PHETensor", | ||
] | ||
__all__ = ["tensor", "distributed_tensor"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
from typing import Any, List, Union, overload | ||
|
||
from fate.interface import FederationDeserializer as FederationDeserializerInterface | ||
from fate.interface import FederationEngine, PartyMeta | ||
|
||
from .abc.tensor import PHEDecryptorABC, PHEEncryptorABC, PHETensorABC | ||
|
||
|
||
class FPTensor: | ||
def __init__(self, tensor) -> None: | ||
self._tensor = tensor | ||
|
||
@property | ||
def shape(self): | ||
return self._tensor.shape | ||
|
||
def __add__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__add__"): | ||
return NotImplemented | ||
return self._binary_op(other, self._tensor.__add__) | ||
|
||
def __radd__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__radd__"): | ||
return self.__add__(other) | ||
return self._binary_op(other, self._tensor.__add__) | ||
|
||
def __sub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__sub__"): | ||
return NotImplemented | ||
return self._binary_op(other, self._tensor.__sub__) | ||
|
||
def __rsub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__rsub__"): | ||
return self.__mul__(-1).__add__(other) | ||
return self._binary_op(other, self._tensor.__rsub__) | ||
|
||
def __mul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__mul__"): | ||
return NotImplemented | ||
return self._binary_op(other, self._tensor.__mul__) | ||
|
||
def __rmul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": | ||
if not hasattr(self._tensor, "__rmul__"): | ||
return self.__mul__(other) | ||
return self._binary_op(other, self._tensor.__rmul__) | ||
|
||
def __matmul__(self, other: "FPTensor") -> "FPTensor": | ||
if not hasattr(self._tensor, "__matmul__"): | ||
return NotImplemented | ||
if isinstance(other, FPTensor): | ||
return FPTensor(self._tensor.__matmul__(other._tensor)) | ||
else: | ||
return NotImplemented | ||
|
||
def __rmatmul__(self, other: "FPTensor") -> "FPTensor": | ||
if not hasattr(self._tensor, "__rmatmul__"): | ||
return NotImplemented | ||
if isinstance(other, FPTensor): | ||
return FPTensor(self._tensor.__rmatmul__(other._tensor)) | ||
else: | ||
return NotImplemented | ||
|
||
def _binary_op(self, other, func): | ||
if isinstance(other, FPTensor): | ||
return FPTensor(func(other._tensor)) | ||
elif isinstance(other, (int, float)): | ||
return FPTensor(func(other)) | ||
else: | ||
return NotImplemented | ||
|
||
@property | ||
def T(self): | ||
return FPTensor(self._tensor.T) | ||
|
||
def __federation_hook__( | ||
self, | ||
federation: FederationEngine, | ||
name: str, | ||
tag: str, | ||
parties: List[PartyMeta], | ||
): | ||
deserializer = FPTensorFederationDeserializer(name) | ||
# 1. remote deserializer with objs | ||
federation.push(deserializer, name, tag, parties) | ||
# 2. remote table | ||
federation.push(self._tensor, deserializer.table_key, tag, parties) | ||
|
||
|
||
class PHETensor: | ||
def __init__(self, tensor: PHETensorABC) -> None: | ||
self._tensor = tensor | ||
|
||
@property | ||
def shape(self): | ||
return self._tensor.shape | ||
|
||
def __add__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__add__) | ||
|
||
def __radd__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__radd__) | ||
|
||
def __sub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__sub__) | ||
|
||
def __rsub__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__rsub__) | ||
|
||
def __mul__(self, other: Union[FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__mul__) | ||
|
||
def __rmul__(self, other: Union[FPTensor, int, float]) -> "PHETensor": | ||
return self._binary_op(other, self._tensor.__rmul__) | ||
|
||
def __matmul__(self, other: FPTensor) -> "PHETensor": | ||
if isinstance(other, FPTensor): | ||
return PHETensor(self._tensor.__matmul__(other._tensor)) | ||
else: | ||
return NotImplemented | ||
|
||
def __rmatmul__(self, other: FPTensor) -> "PHETensor": | ||
if isinstance(other, FPTensor): | ||
return PHETensor(self._tensor.__rmatmul__(other._tensor)) | ||
else: | ||
return NotImplemented | ||
|
||
def T(self) -> "PHETensor": | ||
return PHETensor(self._tensor.T()) | ||
|
||
@overload | ||
def decrypt(self, decryptor: "PHEDecryptor") -> FPTensor: | ||
... | ||
|
||
@overload | ||
def decrypt(self, decryptor) -> Any: | ||
... | ||
|
||
def decrypt(self, decryptor): | ||
return decryptor.decrypt(self) | ||
|
||
def _binary_op(self, other, func): | ||
if isinstance(other, (PHETensor, FPTensor)): | ||
return PHETensor(func(other._tensor)) | ||
elif isinstance(other, (int, float)): | ||
return PHETensor(func(other)) | ||
return NotImplemented | ||
|
||
def __federation_hook__(self, ctx, key, parties): | ||
deserializer = PHETensorFederationDeserializer(key) | ||
# 1. remote deserializer with objs | ||
ctx._push(parties, key, deserializer) | ||
# 2. remote table | ||
ctx._push(parties, deserializer.table_key, self._tensor) | ||
|
||
|
||
class PHETensorFederationDeserializer(FederationDeserializerInterface): | ||
def __init__(self, key) -> None: | ||
self.table_key = f"__phetensor_{key}__" | ||
|
||
def __do_deserialize__( | ||
self, | ||
federation: FederationEngine, | ||
tag: str, | ||
party: PartyMeta, | ||
) -> PHETensor: | ||
tensor = federation.pull(name=self.table_key, tag=tag, parties=[party])[0] | ||
return PHETensor(tensor) | ||
|
||
|
||
class FPTensorFederationDeserializer(FederationDeserializerInterface): | ||
def __init__(self, key) -> None: | ||
self.table_key = f"__tensor_{key}__" | ||
|
||
def __do_deserialize__( | ||
self, | ||
federation: FederationEngine, | ||
tag: str, | ||
party: PartyMeta, | ||
) -> FPTensor: | ||
tensor = federation.pull(name=self.table_key, tag=tag, parties=[party])[0] | ||
return FPTensor(tensor) |
Oops, something went wrong.