Skip to content

Commit

Permalink
fix federation for standalone, but eggroll corrupt in temp
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Oct 18, 2023
1 parent 82240c5 commit 31e25eb
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 196 deletions.
191 changes: 64 additions & 127 deletions python/fate/arch/_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import logging.config
import os
from typing import Callable, Any, Iterable, Optional
import pickle as c_pickle
import shutil
import signal
import threading
Expand All @@ -47,9 +46,6 @@ class FederationDataType(object):
SPLIT_OBJECT = "split_obj"


serialize = c_pickle.dumps
deserialize = c_pickle.loads

# default message max size in bytes = 1MB
DEFAULT_MESSAGE_MAX_SIZE = 1048576

Expand Down Expand Up @@ -409,12 +405,6 @@ def map_reduce_partitions_with_index(
shutil.rmtree(path, ignore_errors=True)
return output

def save_as(self, name, namespace, partitions=None, need_cleanup=True):
if partitions is not None and partitions != self.num_partitions:
return self._repartition(partitions=partitions, need_cleanup=True).copy_as(name, namespace, need_cleanup)

return self.copy_as(name, namespace, need_cleanup)

def copy_as(self, name, namespace, need_cleanup=True):
return self.map_reduce_partitions_with_index(
map_partition_op=lambda i, x: x,
Expand All @@ -433,7 +423,7 @@ def copy_as(self, name, namespace, need_cleanup=True):
def _get_env_for_partition(self, p: int, write=False):
    """Open the storage environment backing partition ``p`` of this table.

    Args:
        p: partition index; it is passed to ``_get_env`` as a string.
        write: forwarded unchanged to ``_get_env``.

    Returns:
        Whatever ``_get_env`` returns for this table's namespace, name and
        partition.
    """
    partition_id = str(p)
    return _get_env(self._namespace, self._name, partition_id, write=write)

def put(self, k_bytes, v_bytes, partitioner: Callable[[bytes, int], int] = None):
def put(self, k_bytes: bytes, v_bytes: bytes, partitioner: Callable[[bytes, int], int] = None):
p = partitioner(k_bytes, self._partitions)
with self._get_env_for_partition(p, write=True) as env:
with env.begin(write=True) as txn:
Expand All @@ -459,20 +449,19 @@ def put_all(self, kv_list: Iterable[Tuple[bytes, bytes]], partitioner: Callable[
for p, (env, txn) in txn_map.items():
txn.commit()

def get(self, k_bytes: bytes, partitioner: Callable[[bytes, int], int]):
def get(self, k_bytes: bytes, partitioner: Callable[[bytes, int], int]) -> bytes:
p = partitioner(k_bytes, self._partitions)
with self._get_env_for_partition(p) as env:
with env.begin(write=True) as txn:
old_value_bytes = txn.get(k_bytes)
return None if old_value_bytes is None else deserialize(old_value_bytes)
return txn.get(k_bytes)

def delete(self, k_bytes: bytes, partitioner: Callable[[bytes, int], int]):
p = partitioner(k_bytes, self._partitions)
with self._get_env_for_partition(p, write=True) as env:
with env.begin(write=True) as txn:
old_value_bytes = txn.get(k_bytes)
if txn.delete(k_bytes):
return None if old_value_bytes is None else deserialize(old_value_bytes)
return old_value_bytes
return None


Expand Down Expand Up @@ -662,21 +651,9 @@ def _submit_process(self, do_func, process_infos):
return results


def _get_splits(obj, max_message_size):
obj_bytes = serialize(obj, protocol=4)
byte_size = len(obj_bytes)
num_slice = (byte_size - 1) // max_message_size + 1
if num_slice <= 1:
return obj, num_slice
else:
_max_size = max_message_size
kv = [(serialize(i), obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)]
return kv, num_slice


class Federation(object):
def _federation_object_key(self, name: str, tag: str, s_party: Tuple[str, str], d_party: Tuple[str, str]):
return f"{self._session_id}-{name}-{tag}-{s_party[0]}-{s_party[1]}-{d_party[0]}-{d_party[1]}"
def _federation_object_key(self, name: str, tag: str, s_party: Tuple[str, str], d_party: Tuple[str, str]) -> bytes:
return f"{self._session_id}-{name}-{tag}-{s_party[0]}-{s_party[1]}-{d_party[0]}-{d_party[1]}".encode("utf-8")

def __init__(self, session: Session, session_id: str, party: Tuple[str, str]):
self._session_id = session_id
Expand All @@ -693,91 +670,46 @@ def __init__(self, session: Session, session_id: str, party: Tuple[str, str]):
def destroy(self):
    """Ask the session to clean up everything in this federation's namespace.

    The session id doubles as the namespace, and ``"*"`` is passed as the
    name pattern.
    """
    session_namespace = self._session_id
    self._session.cleanup(namespace=session_namespace, name="*")

# noinspection PyUnusedLocal
def remote(self, v, name: str, tag: str, parties: List[PartyMeta]):
log_str = f"federation.standalone.remote.{name}.{tag}"

if v is None:
raise ValueError(f"[{log_str}]remote `None` to {parties}")

LOGGER.debug(f"[{log_str}]remote data, type={type(v)}")

if isinstance(v, Table):
dtype = FederationDataType.TABLE
LOGGER.debug(
f"[{log_str}]remote "
f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
)
else:
v_splits, num_slice = _get_splits(v, self._max_message_size)
if num_slice > 1:
v = _create_table(
session=self._session,
name=str(uuid.uuid1()),
namespace=self._session_id,
partitions=1,
need_cleanup=True,
error_if_exist=False,
)
v.put_all(kv_list=v_splits)
dtype = FederationDataType.SPLIT_OBJECT
LOGGER.debug(
f"[{log_str}]remote "
f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
)
else:
LOGGER.debug(f"[{log_str}]remote object with type: {type(v)}")
dtype = FederationDataType.OBJECT

def push_table(self, table, name: str, tag: str, parties: List[PartyMeta]):
for party in parties:
_tagged_key = self._federation_object_key(name, tag, self._party, party)
if isinstance(v, Table):
saved_name = str(uuid.uuid1())
LOGGER.debug(
f"[{log_str}]save Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}) as "
f"Table(namespace={v.namespace}, name={saved_name}, partitions={v.partitions})"
)
_v = v.copy_as(name=saved_name, namespace=v.namespace, need_cleanup=False)
self._meta.set_status(party, _tagged_key, (_v.name, _v.namespace, dtype))
else:
self._meta.set_object(party, _tagged_key, v)
self._meta.set_status(party, _tagged_key, _tagged_key)
saved_name = str(uuid.uuid1())
_table = table.copy_as(name=saved_name, namespace=table.namespace, need_cleanup=False)
self._meta.set_status(party, _tagged_key, _serialize_tuple_of_str(_table.name, _table.namespace))

# noinspection PyProtectedMember
def get(self, name: str, tag: str, parties: List[PartyMeta]) -> List:
log_str = f"federation.standalone.get.{name}.{tag}"
LOGGER.debug(f"[{log_str}]")
results = []
def push_bytes(self, v: bytes, name: str, tag: str, parties: List[PartyMeta]):
    """Deliver the raw payload ``v`` to each party in ``parties``.

    Args:
        v: payload bytes, stored as-is.
        name: transfer variable name, part of the federation key.
        tag: transfer tag, part of the federation key.
        parties: receiving parties.
    """
    for receiver in parties:
        object_key = self._federation_object_key(name, tag, self._party, receiver)
        # Store the payload before publishing the status entry: the receiver
        # waits on the status key and only then reads the object.
        self._meta.set_object(receiver, object_key, v)
        # The status value is the key itself.
        self._meta.set_status(receiver, object_key, object_key)

def pull_table(self, name: str, tag: str, parties: List[PartyMeta]) -> List[Table]:
results: List[bytes] = []
for party in parties:
_tagged_key = self._federation_object_key(name, tag, party, self._party)
results.append(self._meta.wait_status_set(_tagged_key))

rtn = []
for r in results:
if isinstance(r, tuple):
# noinspection PyTypeChecker
table: Table = _load_table(session=self._session, name=r[0], namespace=r[1], need_cleanup=True)

dtype = r[2]
LOGGER.debug(
f"[{log_str}] got "
f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}"
)
name, namespace = _deserialize_tuple_of_str(self._meta.get_status(r))
table: Table = _load_table(session=self._session, name=name, namespace=namespace, need_cleanup=True)
rtn.append(table)
self._meta.ack_status(r)
return rtn

if dtype == FederationDataType.SPLIT_OBJECT:
obj_bytes = b"".join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])))
obj = deserialize(obj_bytes)
rtn.append(obj)
else:
rtn.append(table)
else:
obj = self._meta.get_object(r)
if obj is None:
raise EnvironmentError(f"federation get None from {parties} with name {name}, tag {tag}")
rtn.append(obj)
self._meta.ack_object(r)
LOGGER.debug(f"[{log_str}] got object with type: {type(obj)}")
def pull_bytes(self, name: str, tag: str, parties: List[PartyMeta]) -> List[bytes]:
    """Receive one raw payload from each party in ``parties``.

    Args:
        name: transfer variable name, part of the federation key.
        tag: transfer tag, part of the federation key.
        parties: sending parties.

    Returns:
        The payloads, in the same order as ``parties``.

    Raises:
        EnvironmentError: if a status entry was published but no object is
            stored under the corresponding key.
    """
    # Phase 1: wait for every sender's status entry before reading anything.
    ready_keys = [
        self._meta.wait_status_set(self._federation_object_key(name, tag, sender, self._party))
        for sender in parties
    ]

    # Phase 2: read each payload, then acknowledge both the object and the
    # status entry.
    payloads = []
    for key in ready_keys:
        payload = self._meta.get_object(key)
        if payload is None:
            raise EnvironmentError(f"object not found: {key}")
        payloads.append(payload)
        self._meta.ack_object(key)
        self._meta.ack_status(key)
    return payloads

Expand Down Expand Up @@ -1179,30 +1111,29 @@ def __init__(self, session_id, party: Tuple[str, str]) -> None:
self.party = party
self._env = {}

def wait_status_set(self, key):
def wait_status_set(self, key: bytes) -> bytes:
value = self.get_status(key)
while value is None:
time.sleep(0.1)
value = self.get_status(key)
LOGGER.debug("[GET] Got {} type {}".format(key, "Table" if isinstance(value, tuple) else "Object"))
return value
return key

def get_status(self, key):
def get_status(self, key: bytes):
return self._get(self._get_status_table_name(self.party), key)

def set_status(self, party: Tuple[str, str], key: str, value):
def set_status(self, party: Tuple[str, str], key: bytes, value: bytes):
return self._set(self._get_status_table_name(party), key, value)

def ack_status(self, key):
def ack_status(self, key: bytes):
return self._ack(self._get_status_table_name(self.party), key)

def get_object(self, key):
def get_object(self, key: bytes):
return self._get(self._get_object_table_name(self.party), key)

def set_object(self, party: Tuple[str, str], key, value):
def set_object(self, party: Tuple[str, str], key: bytes, value: bytes):
return self._set(self._get_object_table_name(party), key, value)

def ack_object(self, key):
def ack_object(self, key: bytes):
return self._ack(self._get_object_table_name(self.party), key)

def _get_status_table_name(self, party: Tuple[str, str]):
Expand All @@ -1216,23 +1147,20 @@ def _get_env(self, name):
self._env[name] = _get_env(self.session_id, name, str(0), write=True)
return self._env[name]

def _get(self, name, key):
def _get(self, name: str, key: bytes) -> bytes:
env = self._get_env(name)
with env.begin(write=False) as txn:
old_value_bytes = txn.get(serialize(key))
if old_value_bytes is not None:
old_value_bytes = deserialize(old_value_bytes)
return old_value_bytes
return txn.get(key)

def _set(self, name, key, value):
def _set(self, name, key: bytes, value: bytes):
env = self._get_env(name)
with env.begin(write=True) as txn:
return txn.put(serialize(key), serialize(value))
return txn.put(key, value)

def _ack(self, name, key):
def _ack(self, name, key: bytes):
env = self._get_env(name)
with env.begin(write=True) as txn:
txn.delete(serialize(key))
txn.delete(key)


def _hash_namespace_name_to_partition(namespace: str, name: str, partitions: int) -> Tuple[bytes, int]:
Expand Down Expand Up @@ -1280,12 +1208,7 @@ def get_table_meta(cls, namespace: str, name: str) -> "_TableMeta":
with env.begin(write=False) as txn:
old_value_bytes = txn.get(k_bytes)
if old_value_bytes is not None:
try:
num_partitions = deserialize(old_value_bytes)
old_value_bytes = _TableMeta(num_partitions, 0, 0, 0)
except Exception:
old_value_bytes = _TableMeta.deserialize(old_value_bytes)

old_value_bytes = _TableMeta.deserialize(old_value_bytes)
return old_value_bytes

@classmethod
Expand Down Expand Up @@ -1318,3 +1241,17 @@ def deserialize(cls, serialized_bytes: bytes) -> "_TableMeta":
value_serdes_type = int.from_bytes(serialized_bytes[8:12], "big")
partitioner_type = int.from_bytes(serialized_bytes[12:16], "big")
return cls(num_partitions, key_serdes_type, value_serdes_type, partitioner_type)


def _serialize_tuple_of_str(name: str, namespace: str):
name_bytes = name.encode("utf-8")
namespace_bytes = namespace.encode("utf-8")
split_index_bytes = len(name_bytes).to_bytes(4, "big")
return split_index_bytes + name_bytes + namespace_bytes


def _deserialize_tuple_of_str(serialized_bytes: bytes):
split_index = int.from_bytes(serialized_bytes[:4], "big")
name = serialized_bytes[4 : 4 + split_index].decode("utf-8")
namespace = serialized_bytes[4 + split_index :].decode("utf-8")
return name, namespace
4 changes: 1 addition & 3 deletions python/fate/arch/computing/standalone/_csession.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,9 @@ def _load(
raise ValueError(f"uri `{uri}` not valid, demo format: standalone://database_path/namespace/name") from e

raw_table = self._session.load(name=name, namespace=namespace)
partitions = raw_table.partitions
raw_table = raw_table.save_as(
raw_table = raw_table.copy_as(
name=f"{name}_{uuid()}",
namespace=namespace,
partitions=partitions,
need_cleanup=True,
)
table = Table(raw_table)
Expand Down
20 changes: 10 additions & 10 deletions python/fate/arch/computing/standalone/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def __init__(self, table: StandaloneTable):
num_partitions=table.partitions,
)

@property
def table(self):
return self._table

@property
def partitions(self):
return self._table.partitions

@property
def engine(self):
return self._engine
Expand Down Expand Up @@ -116,25 +124,17 @@ def _count(self):
def _reduce(self, func, **kwargs):
return self._table.reduce(func)

@property
def partitions(self):
return self._table.partitions

@computing_profile
def save(self, uri: URI, schema, options: dict = None):
if options is None:
options = {}

def _save(self, uri: URI, schema, options: dict = None):
if uri.scheme != "standalone":
raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend")
try:
*database, namespace, name = uri.path_splits()
except Exception as e:
raise ValueError(f"uri `{uri}` not supported with standalone backend") from e
self._table.save_as(
self._table.copy_as(
name=name,
namespace=namespace,
partitions=options.get("partitions", self.partitions),
need_cleanup=False,
)
# TODO: self.schema is a bit confusing here, it set by property assignment directly, not by constructor
Expand Down
10 changes: 8 additions & 2 deletions python/fate/arch/computing/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,14 @@ def repartition_with(self, other: "KVTable") -> Tuple["KVTable", "KVTable"]:
else:
return self.repartition(other.num_partitions, other.partitioner_type), other

# def save_as(self, name, namespace, partition=None, options=None):
# return self.rp.save_as(name=name, namespace=namespace, partition=partition, options=options)
def save(self, uri: URI, schema, options: dict = None):
    """Persist this table to ``uri``, repartitioning first when requested.

    Args:
        uri: destination, passed through to ``_save``.
        schema: passed through to ``_save``.
        options: optional settings. If ``options["partition"]`` is set and
            differs from ``self.num_partitions``, the table is repartitioned
            to that count before saving.

    Returns:
        Whatever ``_save`` returns for the table that was written.
    """
    options = options or {}
    if (partition := options.get("partition")) is not None and partition != self.num_partitions:
        # Return here. Falling through would save the un-repartitioned table
        # to the same uri and overwrite the repartitioned copy.
        return self.repartition(partition)._save(uri, schema, options)
    return self._save(uri, schema, options)

def _save(self, uri: URI, schema, options: dict = None):
raise NotImplementedError(f"{self.__class__.__name__}._save")


def _serdes_wrapped_generator(_iter, key_serdes, value_serdes):
Expand Down
Loading

0 comments on commit 31e25eb

Please sign in to comment.