From 74b6edb739391b49c4958aadbb2f17dd152b8cd9 Mon Sep 17 00:00:00 2001
From: sagewe
Date: Thu, 7 Dec 2023 17:33:00 +0800
Subject: [PATCH] add serdes

Signed-off-by: sagewe
---
 python/fate/arch/computing/serdes/__init__.py | 42 ++++++++
 .../arch/computing/serdes/_integer_serdes.py  | 19 ++++
 .../serdes/_restricted_caught_miss_serdes.py  | 54 +++++++++++
 .../computing/serdes/_restricted_serdes.py    | 96 +++++++++++++++++++
 .../computing/serdes/_unrestricted_serdes.py  | 28 ++++++
 python/fate/arch/computing/table.py           |  2 +-
 python/fate/arch/config/_config.py            |  5 +
 python/fate/arch/context/_context.py          |  1 +
 python/fate/arch/unify/serdes.py              | 40 --------
 9 files changed, 246 insertions(+), 41 deletions(-)
 create mode 100644 python/fate/arch/computing/serdes/__init__.py
 create mode 100644 python/fate/arch/computing/serdes/_integer_serdes.py
 create mode 100644 python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
 create mode 100644 python/fate/arch/computing/serdes/_restricted_serdes.py
 create mode 100644 python/fate/arch/computing/serdes/_unrestricted_serdes.py
 delete mode 100644 python/fate/arch/unify/serdes.py

diff --git a/python/fate/arch/computing/serdes/__init__.py b/python/fate/arch/computing/serdes/__init__.py
new file mode 100644
index 0000000000..978a69a307
--- /dev/null
+++ b/python/fate/arch/computing/serdes/__init__.py
@@ -0,0 +1,42 @@
+from fate.arch.config import cfg
+
+
+def get_serdes_by_type(serdes_type: int):
+    """Return the serdes selected by `serdes_type`.
+
+    Type 0 is the pickle based serdes; its variant is chosen by
+    `cfg.safety.serdes.restricted_type`. Type 1 is the integer serdes.
+
+    Raises:
+        ValueError: if `serdes_type` or the configured restricted type is unknown.
+    """
+    if serdes_type == 0:
+        # read the setting once so the branches and the error message agree
+        restricted_type = cfg.safety.serdes.restricted_type
+        if restricted_type == "unrestricted":
+            from ._unrestricted_serdes import get_unrestricted_serdes
+
+            return get_unrestricted_serdes()
+        elif restricted_type == "restricted":
+            from ._restricted_serdes import get_restricted_serdes
+
+            return get_restricted_serdes()
+        elif restricted_type == "restricted_catch_miss":
+            from ._restricted_caught_miss_serdes import get_restricted_catch_miss_serdes
+
+            return get_restricted_catch_miss_serdes()
+        else:
+            raise ValueError(f"restricted type `{restricted_type}` not supported")
+    elif serdes_type == 1:
+        from ._integer_serdes import get_integer_serdes
+
+        return get_integer_serdes()
+    else:
+        raise ValueError(f"serdes type `{serdes_type}` not supported")
+
+
+def dump_miss(path):
+    """Write the classes recorded by the catch-miss unpickler to `path`."""
+    from ._restricted_caught_miss_serdes import dump_miss
+
+    dump_miss(path)
diff --git a/python/fate/arch/computing/serdes/_integer_serdes.py b/python/fate/arch/computing/serdes/_integer_serdes.py
new file mode 100644
index 0000000000..0df5880012
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_integer_serdes.py
@@ -0,0 +1,19 @@
+def get_integer_serdes():
+    """Return a new `IntegerSerdes` instance."""
+    return IntegerSerdes()
+
+
+class IntegerSerdes:
+    """Encode integers as 8 big-endian unsigned bytes.
+
+    `serialize` raises OverflowError for negative values or values >= 2**64.
+    """
+
+    def __init__(self):
+        ...
+
+    def serialize(self, obj) -> bytes:
+        return obj.to_bytes(8, "big")
+
+    def deserialize(self, bytes) -> object:
+        return int.from_bytes(bytes, "big")
diff --git a/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py b/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
new file mode 100644
index 0000000000..c3604399f4
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
@@ -0,0 +1,54 @@
+import io
+import pickle
+
+from ruamel import yaml
+
+from ._restricted_serdes import RestrictedUnpickler
+
+
+def get_restricted_catch_miss_serdes():
+    """Return the serdes class that records, but still loads, non-whitelisted classes."""
+    return WhitelistCatchRestrictedSerdes
+
+
+class WhitelistCatchRestrictedSerdes:
+    """Pickle serdes whose deserializer is `RestrictedCatchUnpickler`."""
+
+    @classmethod
+    def serialize(cls, obj) -> bytes:
+        return pickle.dumps(obj)
+
+    @classmethod
+    def deserialize(cls, bytes) -> object:
+        return RestrictedCatchUnpickler(io.BytesIO(bytes)).load()
+
+
+class RestrictedCatchUnpickler(RestrictedUnpickler):
+    """Unpickler that records whitelist misses instead of rejecting them.
+
+    NOTE: a class missing from the whitelist is still loaded, so this mode does
+    not restrict what can be unpickled; it only collects what to whitelist.
+    """
+
+    # module name -> names of the classes that missed the whitelist (shared by all instances)
+    caught_miss = {}
+
+    def find_class(self, module, name):
+        try:
+            return super().find_class(module, name)
+        except pickle.UnpicklingError:
+            # `caught_miss` is keyed by module name, so group the missed names per module
+            self.caught_miss.setdefault(module, set()).add(name)
+            return self._load(module, name)
+
+    @classmethod
+    def dump_miss(cls, path):
+        """Write the recorded misses to `path` as a YAML mapping of module -> class names."""
+        with open(path, "w") as f:
+            # sorted so the generated file is stable between runs
+            yaml.dump({module: sorted(names) for module, names in cls.caught_miss.items()}, f)
+
+
+def dump_miss(path):
+    """Module level shortcut for `RestrictedCatchUnpickler.dump_miss`."""
+    RestrictedCatchUnpickler.dump_miss(path)
diff --git a/python/fate/arch/computing/serdes/_restricted_serdes.py b/python/fate/arch/computing/serdes/_restricted_serdes.py
new file mode 100644
index 0000000000..0c1838fc40
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_restricted_serdes.py
@@ -0,0 +1,96 @@
+import importlib
+import io
+import pickle
+
+from ruamel import yaml
+
+
+def get_restricted_serdes():
+    """Return the serdes class that only unpickles whitelisted classes."""
+    return WhitelistRestrictedSerdes
+
+
+class WhitelistRestrictedSerdes:
+    """Pickle serdes whose deserializer is `RestrictedUnpickler`."""
+
+    @classmethod
+    def serialize(cls, obj) -> bytes:
+        return pickle.dumps(obj)
+
+    @classmethod
+    def deserialize(cls, bytes) -> object:
+        return RestrictedUnpickler(io.BytesIO(bytes)).load()
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    """Unpickler that refuses every class not allowed by `Whitelist`."""
+
+    def _load(self, module, name):
+        try:
+            return super().find_class(module, name)
+        except Exception:
+            # not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate
+            return getattr(importlib.import_module(module), name)
+
+    def find_class(self, module, name):
+        """Resolve `module.name` if whitelisted, else raise `pickle.UnpicklingError`."""
+        if name in Whitelist.get_whitelist().get(module, ()):
+            return self._load(module, name)
+        # an entry ending in `*` allows every module whose name starts with the prefix
+        if any(module.startswith(prefix) for prefix in Whitelist.get_whitelist_glob()):
+            return self._load(module, name)
+        raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")
+
+
+class Whitelist:
+    """Lazily loaded whitelist of classes that may be unpickled."""
+
+    loaded = False
+    # module name -> set of allowed class names
+    deserialize_whitelist = {}
+    # module name prefixes taken from entries that end with `*`
+    deserialize_glob_whitelist = set()
+
+    @classmethod
+    def get_whitelist_glob(cls):
+        if not cls.loaded:
+            cls.load_deserialize_whitelist()
+        return cls.deserialize_glob_whitelist
+
+    @classmethod
+    def get_whitelist(cls):
+        if not cls.loaded:
+            cls.load_deserialize_whitelist()
+        return cls.deserialize_whitelist
+
+    @classmethod
+    def get_whitelist_path(cls):
+        """Return the absolute path of `configs/whitelist.yaml`, resolved via six `..` from this file."""
+        import os.path
+
+        return os.path.abspath(
+            os.path.join(
+                __file__,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                "configs",
+                "whitelist.yaml",
+            )
+        )
+
+    @classmethod
+    def load_deserialize_whitelist(cls):
+        """Parse the whitelist file into the exact and prefix tables."""
+        with open(cls.get_whitelist_path()) as f:
+            # an empty YAML document parses to None, a key without a value maps to None
+            entries = yaml.load(f, Loader=yaml.SafeLoader) or {}
+        for k, v in entries.items():
+            if k.endswith("*"):
+                cls.deserialize_glob_whitelist.add(k[:-1])
+            else:
+                cls.deserialize_whitelist[k] = set(v or ())
+        cls.loaded = True
diff --git a/python/fate/arch/computing/serdes/_unrestricted_serdes.py b/python/fate/arch/computing/serdes/_unrestricted_serdes.py
new file mode 100644
index 0000000000..595e57cebb
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_unrestricted_serdes.py
@@ -0,0 +1,28 @@
+import os
+import pickle
+
+
+def get_unrestricted_serdes():
+    """Return the serdes class that unpickles without any restriction."""
+    # FIXME: `True or ...` disables the SERDES_DEBUG_MODE check, so the
+    # PermissionError branch below can never run. Kept as is because enforcing
+    # the check would break callers that do not set the variable.
+    if True or os.environ.get("SERDES_DEBUG_MODE") == "1":
+        return UnrestrictedSerdes
+    else:
+        raise PermissionError("UnsafeSerdes is not allowed in production mode")
+
+
+class UnrestrictedSerdes:
+    """Plain pickle serdes.
+
+    WARNING: `pickle.loads` can execute arbitrary code; only use on trusted data.
+    """
+
+    @staticmethod
+    def serialize(obj) -> bytes:
+        return pickle.dumps(obj)
+
+    @staticmethod
+    def deserialize(bytes) -> object:
+        return pickle.loads(bytes)
diff --git a/python/fate/arch/computing/table.py b/python/fate/arch/computing/table.py
index ad8b124c19..36e8f68628 100644
--- a/python/fate/arch/computing/table.py
+++ b/python/fate/arch/computing/table.py
@@ -5,7 +5,7 @@
 from typing import Any, Callable, Tuple, Iterable, Generic, TypeVar, Optional
 
 from fate.arch.unify.partitioner import get_partitioner_by_type
-from fate.arch.unify.serdes import get_serdes_by_type
+from fate.arch.computing.serdes import get_serdes_by_type
 from fate.arch.utils.trace import auto_trace
 from ..unify import URI
 import functools
diff --git a/python/fate/arch/config/_config.py b/python/fate/arch/config/_config.py
index 4381ba8fa2..aa5aa63502 100644
--- a/python/fate/arch/config/_config.py
+++ b/python/fate/arch/config/_config.py
@@ -66,3 +66,8 @@ def get_option(self, options, key, default=...):
             raise ValueError(f"{key} not in {options} or {self.config}")
         else:
             return default
+
+    @property
+    def safety(self):
+        return self.config.safety
+
diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py
index 5e5c16b191..644dc78073 100644
--- a/python/fate/arch/context/_context.py
+++ b/python/fate/arch/context/_context.py
@@ -23,6 +23,7 @@
 from ._namespace import NS, default_ns
 from ..unify import device
 from fate.arch.utils.trace import auto_trace
+from fate.arch.config import cfg
 
 logger = logging.getLogger(__name__)
 
diff --git a/python/fate/arch/unify/serdes.py b/python/fate/arch/unify/serdes.py
deleted file mode 100644
index 9833946ab4..0000000000
--- a/python/fate/arch/unify/serdes.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import pickle
-import os
-
-
-class UnsafeSerdes:
-    def __init__(self):
-        ...
-
-    def serialize(self, obj) -> bytes:
-        return pickle.dumps(obj)
-
-    def deserialize(self, bytes) -> object:
-        return pickle.loads(bytes)
-
-
-class IntegerSerdes:
-    def __init__(self):
-        ...
-
-    def serialize(self, obj) -> bytes:
-        return obj.to_bytes(8, "big")
-
-    def deserialize(self, bytes) -> object:
-        return int.from_bytes(bytes, "big")
-
-
-def get_unsafe_serdes():
-    if True or os.environ.get("SERDES_DEBUG_MODE") == "1":
-        return UnsafeSerdes()
-    else:
-        raise PermissionError("UnsafeSerdes is not allowed in production mode")
-
-
-def get_serdes_by_type(serdes_type: int):
-    if serdes_type == 0:
-        return get_unsafe_serdes()
-    elif serdes_type == 1:
-        return IntegerSerdes()
-    else:
-        raise ValueError(f"serdes type `{serdes_type}` not supported")