Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev-2.0.0-rc' into dev-2.0.0-rc
Browse files Browse the repository at this point in the history
  • Loading branch information
talkingwallace committed Dec 14, 2023
2 parents fdd997b + 8821d64 commit da61b7f
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 133 deletions.
8 changes: 7 additions & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,10 @@ safety:
debug_mode: False
validation_mode: False
encoder:
precision_bits: 24
precision_bits: 24

federation:
split_large_object:
enable: True
max_message_size: 1048576
partition_num: 4
20 changes: 14 additions & 6 deletions python/fate/arch/computing/backends/standalone/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,20 @@ def _parallelize(
)
return Table(table)

def _info(self):
return {
"session_id": self.session_id,
"data_dir": self._session.data_dir,
"max_workers": self._session.max_workers,
}
def _info(self, level=0):

if level == 0:
return f"Standalone<session_id={self.session_id}, max_workers={self._session.max_workers}, data_dir={self._session.data_dir}>"

elif level == 1:
import inspect

return {
"session_id": self.session_id,
"data_dir": self._session.data_dir,
"max_workers": self._session.max_workers,
"code_path": inspect.getfile(self._session.__class__),
}

def cleanup(self, name, namespace):
return self._session.cleanup(name=name, namespace=namespace)
Expand Down
2 changes: 2 additions & 0 deletions python/fate/arch/computing/backends/standalone/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ def __init__(self, session_id, data_dir: str, max_workers=None, logger_config=No
self.session_id = session_id
self._data_dir = data_dir
self._max_workers = max_workers
if self._max_workers is None:
self._max_workers = os.cpu_count()
self._pool = Executor(
max_workers=max_workers,
initializer=_watch_thread_react_to_parent_die,
Expand Down
16 changes: 4 additions & 12 deletions python/fate/arch/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,6 @@ def __init__(self, config_file=None):
def mpc(self):
return self.config.mpc

@property
def debug(self):
return self.config.debug

@property
def encoder(self):
return self.config.encoder

@property
def functions(self):
return self.config.functions

@contextmanager
def temp_override(self, override_dict):
old_config = self.config
Expand All @@ -67,6 +55,10 @@ def get_option(self, options, key, default=...):
else:
return default

@property
def federation(self):
    """Expose the ``federation`` section of the loaded configuration."""
    settings = self.config
    return settings.federation

@property
def safety(self):
    """Expose the ``safety`` section of the loaded configuration."""
    settings = self.config
    return settings.safety
Expand Down
1 change: 1 addition & 0 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def _get_parties(self, role: Optional[Literal["guest", "host", "arbiter"]] = Non
return Parties(
self,
self._get_federation(),
self._get_computing(),
parties,
self._namespace,
)
Expand Down
153 changes: 129 additions & 24 deletions python/fate/arch/context/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import typing
from typing import Any, List, Tuple, TypeVar, Union

from fate.arch.federation.api import PartyMeta
from fate.arch.computing.api import is_table
from fate.arch.federation.api import PartyMeta
from ._namespace import NS

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +29,7 @@
if typing.TYPE_CHECKING:
from fate.arch.context import Context
from fate.arch.federation.api import Federation
from fate.arch.computing.api import KVTableContext


class _KeyedParty:
Expand All @@ -44,10 +45,20 @@ def get(self):


class Party:
def __init__(self, ctx: "Context", federation, party: PartyMeta, rank: int, namespace: NS, key=None) -> None:
def __init__(
self,
ctx: "Context",
federation: "Federation",
computing: "KVTableContext",
party: PartyMeta,
rank: int,
namespace: NS,
key=None,
) -> None:
self._ctx = ctx
self._party = party
self.federation = federation
self.computing = computing
self.rank = rank
self.namespace = namespace

Expand Down Expand Up @@ -82,7 +93,16 @@ def put(self, *args, **kwargs):
kvs = kwargs.items()

for k, v in kvs:
return _push(self.federation, k, self.namespace, [self.party], v)
return _push(
federation=self.federation,
computing=self.computing,
name=k,
namespace=self.namespace,
parties=[self.party],
value=v,
max_message_size=self.federation.get_default_max_message_size(),
num_partitions_of_slice_table=self.federation.get_default_partition_num(),
)

def get(self, name: str):
return _pull(self._ctx, self.federation, name, self.namespace, [self.party])[0]
Expand All @@ -96,11 +116,13 @@ def __init__(
self,
ctx: "Context",
federation: "Federation",
computing: "KVTableContext",
parties: List[Tuple[int, PartyMeta]],
namespace: NS,
) -> None:
self._ctx = ctx
self.federation = federation
self.computing = computing
self.parties = parties
self.namespace = namespace

Expand All @@ -113,10 +135,15 @@ def ranks(self):

def __getitem__(self, key: int) -> Party:
rank, party = self.parties[key]
return Party(self._ctx, self.federation, party, rank, self.namespace)
return Party(self._ctx, self.federation, self.computing, party, rank, self.namespace)

def __iter__(self):
return iter([Party(self._ctx, self.federation, party, rank, self.namespace) for rank, party in self.parties])
return iter(
[
Party(self._ctx, self.federation, self.computing, party, rank, self.namespace)
for rank, party in self.parties
]
)

def __len__(self) -> int:
return len(self.parties)
Expand All @@ -132,21 +159,35 @@ def put(self, *args, **kwargs):
else:
kvs = kwargs.items()
for k, v in kvs:
return _push(self.federation, k, self.namespace, [p[1] for p in self.parties], v)
return _push(
federation=self.federation,
computing=self.computing,
name=k,
namespace=self.namespace,
parties=[p[1] for p in self.parties],
value=v,
max_message_size=self.federation.get_default_max_message_size(),
num_partitions_of_slice_table=self.federation.get_default_partition_num(),
)

def get(self, name: str):
return _pull(self._ctx, self.federation, name, self.namespace, [p[1] for p in self.parties])


def _push(
federation: "Federation",
computing: "KVTableContext",
name: str,
namespace: NS,
parties: List[PartyMeta],
value,
max_message_size,
num_partitions_of_slice_table,
):
tag = namespace.federation_tag
_TableRemotePersistentPickler.push(value, federation, name, tag, parties)
_TableRemotePersistentPickler.push(
value, federation, computing, name, tag, parties, max_message_size, num_partitions_of_slice_table
)


class Serde:
Expand Down Expand Up @@ -186,11 +227,6 @@ def decode_float(cls, value: bytes) -> float:
return struct.unpack("!d", value)[0]


def _push_int(federation: "Federation", name: str, namespace: NS, parties: List[PartyMeta], value: int):
tag = namespace.federation_tag
federation.push(v=Serde.encode_int(value), name=name, tag=tag, parties=parties)


def _pull(
ctx: "Context",
federation: "Federation",
Expand Down Expand Up @@ -220,6 +256,35 @@ def __init__(self, key) -> None:
self.key = key


class _FederationBytesCoder:
@staticmethod
def encode_base(v: bytes) -> bytes:
return struct.pack("!B", 0) + v

@staticmethod
def encode_split(total_size: int, num_slice: int, slice_size: int) -> bytes:
return struct.pack("!B", 1) + struct.pack("!III", total_size, num_slice, slice_size)

@classmethod
def decode_mode(cls, v: bytes) -> int:
return struct.unpack("!B", v[:1])[0]

@classmethod
def decode_base(cls, v: bytes) -> bytes:
return v[1:]

@classmethod
def decode_split(cls, v: bytes) -> Tuple[int, int, int]:
total_size, num_slice, slice_size = struct.unpack("!III", v[1:])
return total_size, num_slice, slice_size


class _SplitTableUtil:
@staticmethod
def get_split_table_key(name):
return f"{name}__table_persistent_split__"


class _TableRemotePersistentPickler(pickle.Pickler):
def __init__(
self,
Expand Down Expand Up @@ -247,26 +312,53 @@ def persistent_id(self, obj: Any) -> Any:

if is_table(obj):
key = self._get_next_table_key()
self._federation.push_table(table=obj, name=key, tag=self._tag, parties=self._parties)
self._table_index += 1
return _TablePersistentId(key)
return _TablePersistentId(self._push_table(obj, key))
if isinstance(obj, Context):
key = f"{self._name}__context__"
return _ContextPersistentId(key)

def _push_table(self, table, key):
self._federation.push_table(table=table, name=key, tag=self._tag, parties=self._parties)
self._table_index += 1
return key

@classmethod
def push(
cls,
value,
federation: "Federation",
computing: "KVTableContext",
name: str,
tag: str,
parties: List[PartyMeta],
max_message_size: int,
num_partitions_of_slice_table: int,
):
with io.BytesIO() as f:
pickler = _TableRemotePersistentPickler(federation, name, tag, parties, f)
pickler.dump(value)
federation.push_bytes(v=f.getvalue(), name=name, tag=tag, parties=parties)
if f.tell() > max_message_size:
total_size = f.tell()
num_slice = (total_size - 1) // max_message_size + 1
# create a table to store the slice
f.seek(0)
slice_table = computing.parallelize(
((i, f.read(max_message_size)) for i in range(num_slice)), partition=num_partitions_of_slice_table
)
# push the slice table with a special key
pickler._push_table(slice_table, _SplitTableUtil.get_split_table_key(name))
# push the slice table info
federation.push_bytes(
v=_FederationBytesCoder.encode_split(total_size, num_slice, max_message_size),
name=name,
tag=tag,
parties=parties,
)

else:
federation.push_bytes(
v=_FederationBytesCoder.encode_base(f.getvalue()), name=name, tag=tag, parties=parties
)


class _TableRemotePersistentUnpickler(pickle.Unpickler):
Expand All @@ -293,11 +385,6 @@ def persistent_load(self, pid: Any) -> Any:
if isinstance(pid, _ContextPersistentId):
return self._ctx

# def load(self):
# out = super().load()
# logger.error(f"unpickled: {out.__class__.__module__}.{out.__class__.__name__}")
# return out

@classmethod
def pull(
cls,
Expand All @@ -308,6 +395,24 @@ def pull(
tag: str,
party: PartyMeta,
):
with io.BytesIO(buffers) as f:
unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f)
return unpickler.load()
mode = _FederationBytesCoder.decode_mode(buffers)
if mode == 0:
with io.BytesIO(_FederationBytesCoder.decode_base(buffers)) as f:
unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f)
return unpickler.load()
elif mode == 1:
# get num_slice and slice_size
total_size, num_slice, slice_size = _FederationBytesCoder.decode_split(buffers)

# pull the slice table with a special key
slice_table = federation.pull_table(_SplitTableUtil.get_split_table_key(name), tag, [party])[0]
# merge the bytes
with io.BytesIO() as f:
for i, b in slice_table.collect():
f.seek(i * slice_size)
f.write(b)
f.seek(0)
unpickler = _TableRemotePersistentUnpickler(ctx, federation, name, tag, party, f)
return unpickler.load()
else:
raise ValueError(f"invalid mode: {mode}")
10 changes: 10 additions & 0 deletions python/fate/arch/federation/api/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def __init__(self, session_id: str, party: PartyMeta, parties: List[PartyMeta]):
self._get_gc: GarbageCollector = GarbageCollector()
self._remote_gc: GarbageCollector = GarbageCollector()

def get_default_max_message_size(self):
    """Return ``federation.split_large_object.max_message_size`` from the global config."""
    # Imported at call time rather than at module import time
    # (presumably to avoid an import cycle; confirm).
    from fate.arch.config import cfg

    split_options = cfg.federation.split_large_object
    return split_options.max_message_size

def get_default_partition_num(self):
    """Return ``federation.split_large_object.partition_num`` from the global config."""
    # Imported at call time rather than at module import time
    # (presumably to avoid an import cycle; confirm).
    from fate.arch.config import cfg

    split_options = cfg.federation.split_large_object
    return split_options.partition_num

@property
def session_id(self) -> str:
return self._session_id
Expand Down
Loading

0 comments on commit da61b7f

Please sign in to comment.