Skip to content

Commit

Permalink
add detach to encrypt
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jul 5, 2023
1 parent d96fe50 commit cdca4cb
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 86 deletions.
2 changes: 1 addition & 1 deletion python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def destroy(self):
self._session.cleanup(namespace=self._session_id, name="*")

# noinspection PyUnusedLocal
def remote(self, v, name: str, tag: str, parties: List[Tuple[str, str]]):
def remote(self, v, name: str, tag: str, parties: List[PartyMeta]):
log_str = f"federation.standalone.remote.{name}.{tag}"

if v is None:
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._federation import FederationEngine, GarbageCollector
from ._party import Parties, Party, PartyMeta
from ._party import PartyMeta
from ._table import CSessionABC, CTableABC
9 changes: 1 addition & 8 deletions python/fate/arch/abc/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from typing import List, Optional, Protocol

from ._party import Parties, Party, PartyMeta
from ._party import PartyMeta


class GarbageCollector(Protocol):
Expand Down Expand Up @@ -46,10 +46,3 @@ def push(

def destroy(self):
...


class FederationWrapper(Protocol):
guest: Party
hosts: Parties
arbiter: Party
parties: Parties
57 changes: 1 addition & 56 deletions python/fate/arch/abc/_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Literal, Optional, Protocol, Tuple, TypeVar, overload

T = TypeVar("T")


class _KeyedParty(Protocol):
def put(self, value):
...

def get(self) -> Any:
...


class Party(Protocol):
def get(self, name: str) -> Any:
...

@overload
def put(self, name: str, value):
...

@overload
def put(self, **kwargs):
...

def __call__(self, key: str) -> _KeyedParty:
...


class Parties(Protocol):
def get(self, name: str) -> List:
...

@overload
def put(self, name: str, value):
...

@overload
def put(self, **kwargs):
...

def __getitem__(self, key: int) -> Party:
...

def get_neighbor(self, shift: int, module: bool = False) -> Party:
...

def get_neighbors(self) -> "Parties":
...

def get_local_index(self) -> Optional[int]:
...

def __call__(self, key: str) -> _KeyedParty:
...

from typing import Literal, Tuple

PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str]
17 changes: 7 additions & 10 deletions python/fate/arch/context/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import pickle
from typing import Any, List, Optional, TypeVar, Union

from fate.arch.abc import FederationEngine
from fate.arch.abc import Parties as PartiesInterface
from fate.arch.abc import Party as PartyInterface
from fate.arch.abc import PartyMeta
from fate.arch.abc import FederationEngine, PartyMeta

from ..computing import is_table
from ..federation._gc import IterationGC
Expand Down Expand Up @@ -56,7 +53,7 @@ def get(self):
return self.party.get(self.key)


class Party(PartyInterface):
class Party:
def __init__(self, federation, party: PartyMeta, namespace: NS, key=None) -> None:
self.federation = federation
self.party = party
Expand All @@ -81,7 +78,7 @@ def get(self, name: str):
return _pull(self.federation, name, self.namespace, [self.party])[0]


class Parties(PartiesInterface):
class Parties:
def __init__(
self,
federation: FederationEngine,
Expand Down Expand Up @@ -129,7 +126,7 @@ def get_local_index(self) -> Optional[int]:
def put(self, *args, **kwargs):
if args:
assert len(args) == 2 and isinstance(args[0], str), "invalid position parameter"
assert not kwargs, "keywords paramters not allowed when position parameter provided"
assert not kwargs, "keywords parameters not allowed when position parameter provided"
kvs = [args]
else:
kvs = kwargs.items()
Expand Down Expand Up @@ -169,7 +166,7 @@ def _pull(
return values


class _TablePersistantId:
class _TablePersistentId:
def __init__(self, key) -> None:
self.key = key

Expand Down Expand Up @@ -201,7 +198,7 @@ def persistent_id(self, obj: Any) -> Any:
key = self._get_next_table_key()
self._federation.push(v=obj, name=key, tag=self._tag, parties=self._parties)
self._table_index += 1
return _TablePersistantId(key)
return _TablePersistentId(key)

@classmethod
def push(
Expand Down Expand Up @@ -234,7 +231,7 @@ def __init__(
super().__init__(f)

def persistent_load(self, pid: Any) -> Any:
if isinstance(pid, _TablePersistantId):
if isinstance(pid, _TablePersistentId):
table = self._federation.pull(pid.key, self._tag, [self._party])[0]
return table

Expand Down
7 changes: 0 additions & 7 deletions python/fate/arch/federation/_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
#


class FederationEngine(object):
EGGROLL = "EGGROLL"
RABBITMQ = "RABBITMQ"
STANDALONE = "STANDALONE"
PULSAR = "PULSAR"


class FederationDataType(object):
OBJECT = "obj"
TABLE = "Table"
Expand Down
6 changes: 3 additions & 3 deletions python/fate/arch/tensor/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def encrypt_f(tensor, encryptor):
if isinstance(tensor, torch.Tensor):
return encryptor.encrypt(tensor)
return encryptor.encrypt(tensor.detach())
else:
# torch tensor-like
if hasattr(tensor, "__torch_function__"):
Expand All @@ -12,8 +12,8 @@ def encrypt_f(tensor, encryptor):


def decrypt_f(tensor, decryptor):
if isinstance(tensor, torch.Tensor):
return decryptor.encrypt(tensor)
if isinstance(tensor, torch.Tensor.detach):
return decryptor.encrypt(tensor.detach())
else:
# torch tensor-like
if hasattr(tensor, "__torch_function__"):
Expand Down

0 comments on commit cdca4cb

Please sign in to comment.