diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index f1d490a664..4309aaf655 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -224,6 +224,7 @@ def _get_parties(self, role: Optional[Literal["guest", "host", "arbiter"]] = Non parties.extend(self._role_to_parties[role]) parties.sort(key=lambda x: x[0]) return Parties( + self, self._get_federation(), parties, self._namespace, diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index aeb957325d..c85ca7f2e0 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -14,17 +14,19 @@ # limitations under the License. import io import pickle +import struct +import typing from typing import Any, List, Tuple, TypeVar, Union from fate.arch.abc import FederationEngine, PartyMeta - +from ._namespace import NS from ..computing import is_table from ..federation._gc import IterationGC -from ._namespace import NS T = TypeVar("T") -import struct +if typing.TYPE_CHECKING: + from fate.arch.context import Context class GC: @@ -45,27 +47,31 @@ def get_or_set_pull_gc(self, key): class _KeyedParty: def __init__(self, party: Union["Party", "Parties"], key) -> None: - self.party = party - self.key = key + self._party = party + self._key = key def put(self, value): - return self.party.put(self.key, value) + return self._party.put(self._key, value) def get(self): - return self.party.get(self.key) + return self._party.get(self._key) class Party: - def __init__(self, federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None: + def __init__(self, ctx: "Context", federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None: + self._ctx = ctx + self._party = party self.federation = federation - self.party = party self.rank = rank self.namespace = namespace - self.key = key def __call__(self, key: str) -> "_KeyedParty": return _KeyedParty(self, key) + @property + def party(self) -> PartyMeta: + return self._party + @property def role(self) -> str: return self.party[0] @@ -86,7 +92,7 @@ def put(self, *args, **kwargs): return _push(self.federation, k, self.namespace, [self.party], v) def get(self, name: str): - return _pull(self.federation, name, self.namespace, [self.party])[0] + return _pull(self._ctx, self.federation, name, self.namespace, [self.party])[0] def get_int(self, name: str): ... @@ -94,11 +100,13 @@ def get_int(self, name: str): class Parties: def __init__( - self, - federation: FederationEngine, - parties: List[Tuple[int, PartyMeta]], - namespace: NS, + self, + ctx: "Context", + federation: FederationEngine, + parties: List[Tuple[int, PartyMeta]], + namespace: NS, ) -> None: + self._ctx = ctx self.federation = federation self.parties = parties self.namespace = namespace @@ -109,10 +117,10 @@ def ranks(self): def __getitem__(self, key: int) -> Party: rank, party = self.parties[key] - return Party(self.federation, party, rank, self.namespace) + return Party(self._ctx, self.federation, party, rank, self.namespace) def __iter__(self): - return iter([Party(self.federation, party, rank, self.namespace) for rank, party in self.parties]) + return iter([Party(self._ctx, self.federation, party, rank, self.namespace) for rank, party in self.parties]) def __len__(self) -> int: return len(self.parties) @@ -131,15 +139,15 @@ def put(self, *args, **kwargs): return _push(self.federation, k, self.namespace, [p[1] for p in self.parties], v) def get(self, name: str): - return _pull(self.federation, name, self.namespace, [p[1] for p in self.parties]) + return _pull(self._ctx, self.federation, name, self.namespace, [p[1] for p in self.parties]) def _push( - federation: FederationEngine, - name: str, - namespace: NS, - parties: List[PartyMeta], - value, + federation: FederationEngine, + name: str, + namespace: NS, + parties: List[PartyMeta], + value, ): tag = namespace.federation_tag _TableRemotePersistentPickler.push(value, federation, name, tag, parties) @@ -162,7 +170,7 @@ def encode_str(cls, value: str) -> bytes: @classmethod def decode_str(cls, value: bytes) -> str: length = struct.unpack("!I", value[:4])[0] # get length of string - return value[4 : 4 + length].decode("utf-8") # decode string + return value[4: 4 + length].decode("utf-8") # decode string @classmethod def encode_bytes(cls, value: bytes) -> bytes: @@ -171,7 +179,7 @@ def encode_bytes(cls, value: bytes) -> bytes: @classmethod def decode_bytes(cls, value: bytes) -> bytes: length = struct.unpack("!I", value[:4])[0] # get length of bytes - return value[4 : 4 + length] # extract bytes + return value[4: 4 + length] # extract bytes @classmethod def encode_float(cls, value: float) -> bytes: @@ -184,14 +192,15 @@ def decode_float(cls, value: bytes) -> float: def _push_int(federation: FederationEngine, name: str, namespace: NS, parties: List[PartyMeta], value: int): tag = namespace.federation_tag - federation.push(v=f.getvalue(), name=name, tag=tag, parties=parties) + federation.push(v=Serde.encode_int(value), name=name, tag=tag, parties=parties) def _pull( - federation: FederationEngine, - name: str, - namespace: NS, - parties: List[PartyMeta], + ctx: "Context", + federation: FederationEngine, + name: str, + namespace: NS, + parties: List[PartyMeta], ): tag = namespace.federation_tag raw_values = federation.pull( @@ -201,7 +210,7 @@ def _pull( ) values = [] for party, buffers in zip(parties, raw_values): - values.append(_TableRmotePersistentUnpickler.pull(buffers, federation, name, tag, party)) + values.append(_TableRemotePersistentUnpickler.pull(buffers, ctx, federation, name, tag, party)) return values @@ -210,14 +219,19 @@ def __init__(self, key) -> None: self.key = key +class _ContextPersistentId: + def __init__(self, key) -> None: + self.key = key + + class _TableRemotePersistentPickler(pickle.Pickler): def __init__( - self, - federation: FederationEngine, - name: str, - tag: str, - parties: List[PartyMeta], - f, + self, + federation: FederationEngine, + name: str, + tag: str, + parties: List[PartyMeta], + f, ) -> None: self._federation = federation self._name = name @@ -233,20 +247,24 @@ def _get_next_table_key(self): return f"{self._name}__table_persistent_{self._table_index}__" def persistent_id(self, obj: Any) -> Any: + from fate.arch.context import Context if is_table(obj): key = self._get_next_table_key() self._federation.push(v=obj, name=key, tag=self._tag, parties=self._parties) self._table_index += 1 return _TablePersistentId(key) + if isinstance(obj, Context): + key = f"{self._name}__context__" + return _ContextPersistentId(key) @classmethod def push( - cls, - value, - federation: FederationEngine, - name: str, - tag: str, - parties: List[PartyMeta], + cls, + value, + federation: FederationEngine, + name: str, + tag: str, + parties: List[PartyMeta], ): with io.BytesIO() as f: pickler = _TableRemotePersistentPickler(federation, name, tag, parties, f) @@ -254,15 +272,17 @@ def push( federation.push(v=f.getvalue(), name=name, tag=tag, parties=parties) -class _TableRmotePersistentUnpickler(pickle.Unpickler): +class _TableRemotePersistentUnpickler(pickle.Unpickler): def __init__( - self, - federation: FederationEngine, - name: str, - tag: str, - party: PartyMeta, - f, + self, + ctx: "Context", + federation: FederationEngine, + name: str, + tag: str, + party: PartyMeta, + f, ): + self._ctx = ctx self._federation = federation self._name = name self._tag = tag @@ -273,16 +293,19 @@ def persistent_load(self, pid: Any) -> Any: if isinstance(pid, _TablePersistentId): table = self._federation.pull(pid.key, self._tag, [self._party])[0] return table + if isinstance(pid, _ContextPersistentId): + return self._ctx @classmethod def pull( - cls, - buffers, - federation: FederationEngine, - name: str, - tag: str, - party: PartyMeta, + cls, + buffers, + ctx: "Context", + federation: FederationEngine, + name: str, + tag: str, + party: PartyMeta, ): with io.BytesIO(buffers) as f: - unpickler = _TableRmotePersistentUnpickler(federation, name, tag, party, f) + unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f) return unpickler.load()