Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

serde support ctx #5020

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def _get_parties(self, role: Optional[Literal["guest", "host", "arbiter"]] = Non
parties.extend(self._role_to_parties[role])
parties.sort(key=lambda x: x[0])
return Parties(
self,
self._get_federation(),
parties,
self._namespace,
Expand Down
137 changes: 80 additions & 57 deletions python/fate/arch/context/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
# limitations under the License.
import io
import pickle
import struct
import typing
from typing import Any, List, Tuple, TypeVar, Union

from fate.arch.abc import FederationEngine, PartyMeta

from ._namespace import NS
from ..computing import is_table
from ..federation._gc import IterationGC
from ._namespace import NS

T = TypeVar("T")

import struct
if typing.TYPE_CHECKING:
from fate.arch.context import Context


class GC:
Expand All @@ -45,27 +47,31 @@ def get_or_set_pull_gc(self, key):

class _KeyedParty:
def __init__(self, party: Union["Party", "Parties"], key) -> None:
self.party = party
self.key = key
self._party = party
self._key = key

def put(self, value):
return self.party.put(self.key, value)
return self._party.put(self._key, value)

def get(self):
return self.party.get(self.key)
return self._party.get(self._key)


class Party:
def __init__(self, federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None:
def __init__(self, ctx: "Context", federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None:
self._ctx = ctx
self._party = party
self.federation = federation
self.party = party
self.rank = rank
self.namespace = namespace
self.key = key

def __call__(self, key: str) -> "_KeyedParty":
return _KeyedParty(self, key)

@property
def party(self) -> PartyMeta:
return self._party

@property
def role(self) -> str:
return self.party[0]
Expand All @@ -86,19 +92,21 @@ def put(self, *args, **kwargs):
return _push(self.federation, k, self.namespace, [self.party], v)

def get(self, name: str):
return _pull(self.federation, name, self.namespace, [self.party])[0]
return _pull(self._ctx, self.federation, name, self.namespace, [self.party])[0]

def get_int(self, name: str):
...


class Parties:
def __init__(
self,
federation: FederationEngine,
parties: List[Tuple[int, PartyMeta]],
namespace: NS,
self,
ctx: "Context",
federation: FederationEngine,
parties: List[Tuple[int, PartyMeta]],
namespace: NS,
) -> None:
self._ctx = ctx
self.federation = federation
self.parties = parties
self.namespace = namespace
Expand All @@ -109,10 +117,10 @@ def ranks(self):

def __getitem__(self, key: int) -> Party:
rank, party = self.parties[key]
return Party(self.federation, party, rank, self.namespace)
return Party(self._ctx, self.federation, party, rank, self.namespace)

def __iter__(self):
return iter([Party(self.federation, party, rank, self.namespace) for rank, party in self.parties])
return iter([Party(self._ctx, self.federation, party, rank, self.namespace) for rank, party in self.parties])

def __len__(self) -> int:
return len(self.parties)
Expand All @@ -131,15 +139,15 @@ def put(self, *args, **kwargs):
return _push(self.federation, k, self.namespace, [p[1] for p in self.parties], v)

def get(self, name: str):
return _pull(self.federation, name, self.namespace, [p[1] for p in self.parties])
return _pull(self._ctx, self.federation, name, self.namespace, [p[1] for p in self.parties])


def _push(
federation: FederationEngine,
name: str,
namespace: NS,
parties: List[PartyMeta],
value,
federation: FederationEngine,
name: str,
namespace: NS,
parties: List[PartyMeta],
value,
):
tag = namespace.federation_tag
_TableRemotePersistentPickler.push(value, federation, name, tag, parties)
Expand All @@ -162,7 +170,7 @@ def encode_str(cls, value: str) -> bytes:
@classmethod
def decode_str(cls, value: bytes) -> str:
length = struct.unpack("!I", value[:4])[0] # get length of string
return value[4 : 4 + length].decode("utf-8") # decode string
return value[4: 4 + length].decode("utf-8") # decode string

@classmethod
def encode_bytes(cls, value: bytes) -> bytes:
Expand All @@ -171,7 +179,7 @@ def encode_bytes(cls, value: bytes) -> bytes:
@classmethod
def decode_bytes(cls, value: bytes) -> bytes:
length = struct.unpack("!I", value[:4])[0] # get length of bytes
return value[4 : 4 + length] # extract bytes
return value[4: 4 + length] # extract bytes

@classmethod
def encode_float(cls, value: float) -> bytes:
Expand All @@ -184,14 +192,15 @@ def decode_float(cls, value: bytes) -> float:

def _push_int(federation: FederationEngine, name: str, namespace: NS, parties: List[PartyMeta], value: int):
tag = namespace.federation_tag
federation.push(v=f.getvalue(), name=name, tag=tag, parties=parties)
federation.push(v=Serde.encode_int(value), name=name, tag=tag, parties=parties)


def _pull(
federation: FederationEngine,
name: str,
namespace: NS,
parties: List[PartyMeta],
ctx: "Context",
federation: FederationEngine,
name: str,
namespace: NS,
parties: List[PartyMeta],
):
tag = namespace.federation_tag
raw_values = federation.pull(
Expand All @@ -201,7 +210,7 @@ def _pull(
)
values = []
for party, buffers in zip(parties, raw_values):
values.append(_TableRmotePersistentUnpickler.pull(buffers, federation, name, tag, party))
values.append(_TableRemotePersistentUnpickler.pull(buffers, ctx, federation, name, tag, party))
return values


Expand All @@ -210,14 +219,19 @@ def __init__(self, key) -> None:
self.key = key


class _ContextPersistentId:
def __init__(self, key) -> None:
self.key = key


class _TableRemotePersistentPickler(pickle.Pickler):
def __init__(
self,
federation: FederationEngine,
name: str,
tag: str,
parties: List[PartyMeta],
f,
self,
federation: FederationEngine,
name: str,
tag: str,
parties: List[PartyMeta],
f,
) -> None:
self._federation = federation
self._name = name
Expand All @@ -233,36 +247,42 @@ def _get_next_table_key(self):
return f"{self._name}__table_persistent_{self._table_index}__"

def persistent_id(self, obj: Any) -> Any:
from fate.arch.context import Context
if is_table(obj):
key = self._get_next_table_key()
self._federation.push(v=obj, name=key, tag=self._tag, parties=self._parties)
self._table_index += 1
return _TablePersistentId(key)
if isinstance(obj, Context):
key = f"{self._name}__context__"
return _ContextPersistentId(key)

@classmethod
def push(
cls,
value,
federation: FederationEngine,
name: str,
tag: str,
parties: List[PartyMeta],
cls,
value,
federation: FederationEngine,
name: str,
tag: str,
parties: List[PartyMeta],
):
with io.BytesIO() as f:
pickler = _TableRemotePersistentPickler(federation, name, tag, parties, f)
pickler.dump(value)
federation.push(v=f.getvalue(), name=name, tag=tag, parties=parties)


class _TableRmotePersistentUnpickler(pickle.Unpickler):
class _TableRemotePersistentUnpickler(pickle.Unpickler):
def __init__(
self,
federation: FederationEngine,
name: str,
tag: str,
party: PartyMeta,
f,
self,
ctx: "Context",
federation: FederationEngine,
name: str,
tag: str,
party: PartyMeta,
f,
):
self._ctx = ctx
self._federation = federation
self._name = name
self._tag = tag
Expand All @@ -273,16 +293,19 @@ def persistent_load(self, pid: Any) -> Any:
if isinstance(pid, _TablePersistentId):
table = self._federation.pull(pid.key, self._tag, [self._party])[0]
return table
if isinstance(pid, _ContextPersistentId):
return self._ctx

@classmethod
def pull(
cls,
buffers,
federation: FederationEngine,
name: str,
tag: str,
party: PartyMeta,
cls,
buffers,
ctx: "Context",
federation: FederationEngine,
name: str,
tag: str,
party: PartyMeta,
):
with io.BytesIO(buffers) as f:
unpickler = _TableRmotePersistentUnpickler(federation, name, tag, party, f)
unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f)
return unpickler.load()