Skip to content

Commit

Permalink
add serdes
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Dec 7, 2023
1 parent d7fd807 commit 74b6edb
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 41 deletions.
31 changes: 31 additions & 0 deletions python/fate/arch/computing/serdes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fate.arch.config import cfg


def get_serdes_by_type(serdes_type: int):
if serdes_type == 0:
if cfg.safety.serdes.restricted_type == "unrestricted":
from ._unrestricted_serdes import get_unrestricted_serdes

return get_unrestricted_serdes()
elif cfg.safety.serdes.restricted_type == "restricted":
from ._restricted_serdes import get_restricted_serdes

return get_restricted_serdes()
elif cfg.safety.serdes.restricted_type == "restricted_catch_miss":
from ._restricted_caught_miss_serdes import get_restricted_catch_miss_serdes

return get_restricted_catch_miss_serdes()
else:
raise ValueError(f"restricted type `{cfg.safety.serdes.restricted_type}` not supported")
elif serdes_type == 1:
from ._integer_serdes import get_integer_serdes

return get_integer_serdes()
else:
raise ValueError(f"serdes type `{serdes_type}` not supported")


def dump_miss(path):
from ._restricted_caught_miss_serdes import dump_miss

dump_miss(path)
13 changes: 13 additions & 0 deletions python/fate/arch/computing/serdes/_integer_serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
def get_integer_serdes():
return IntegerSerdes()


class IntegerSerdes:
def __init__(self):
...

def serialize(self, obj) -> bytes:
return obj.to_bytes(8, "big")

def deserialize(self, bytes) -> object:
return int.from_bytes(bytes, "big")
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import io
import pickle

from ruamel import yaml

from ._restricted_serdes import RestrictedUnpickler


def get_restricted_catch_miss_serdes():
return WhitelistCatchRestrictedSerdes


class WhitelistCatchRestrictedSerdes:
@classmethod
def serialize(cls, obj) -> bytes:
return pickle.dumps(obj)

@classmethod
def deserialize(cls, bytes) -> object:
return RestrictedCatchUnpickler(io.BytesIO(bytes)).load()


class RestrictedCatchUnpickler(RestrictedUnpickler):
caught_miss = {}

def find_class(self, module, name):
try:
return super().find_class(module, name)
except pickle.UnpicklingError:
if (module, name) not in self.caught_miss:
if module not in self.caught_miss:
self.caught_miss[module] = set()
self.caught_miss[module].add(name)
return self._load(module, name)

@classmethod
def dump_miss(cls, path):
with open(path, "w") as f:
yaml.dump({module: list(names) for module, names in cls.caught_miss.items()}, f)


def dump_miss(path):
RestrictedCatchUnpickler.dump_miss(path)

82 changes: 82 additions & 0 deletions python/fate/arch/computing/serdes/_restricted_serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import importlib
import io
import pickle

from ruamel import yaml


def get_restricted_serdes():
return WhitelistRestrictedSerdes


class WhitelistRestrictedSerdes:
@classmethod
def serialize(cls, obj) -> bytes:
return pickle.dumps(obj)

@classmethod
def deserialize(cls, bytes) -> object:
return RestrictedUnpickler(io.BytesIO(bytes)).load()


class RestrictedUnpickler(pickle.Unpickler):
def _load(self, module, name):
try:
return super().find_class(module, name)
except:
return getattr(importlib.import_module(module), name)

def find_class(self, module, name):
if name in Whitelist.get_whitelist().get(module, set()):
return self._load(module, name)
else:
for m in Whitelist.get_whitelist_glob():
if module.startswith(m):
return self._load(module, name)
raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")


class Whitelist:
loaded = False
deserialize_whitelist = {}
deserialize_glob_whitelist = set()

@classmethod
def get_whitelist_glob(cls):
if not cls.loaded:
cls.load_deserialize_whitelist()
return cls.deserialize_glob_whitelist

@classmethod
def get_whitelist(cls):
if not cls.loaded:
cls.load_deserialize_whitelist()
return cls.deserialize_whitelist

@classmethod
def get_whitelist_path(cls):
import os.path

return os.path.abspath(
os.path.join(
__file__,
os.path.pardir,
os.path.pardir,
os.path.pardir,
os.path.pardir,
os.path.pardir,
os.path.pardir,
"configs",
"whitelist.yaml",
)
)

@classmethod
def load_deserialize_whitelist(cls):
with open(cls.get_whitelist_path()) as f:
for k, v in yaml.load(f, Loader=yaml.SafeLoader).items():
if k.endswith("*"):
cls.deserialize_glob_whitelist.add(k[:-1])
else:
cls.deserialize_whitelist[k] = set(v)
cls.loaded = True
19 changes: 19 additions & 0 deletions python/fate/arch/computing/serdes/_unrestricted_serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import pickle


def get_unrestricted_serdes():
if True or os.environ.get("SERDES_DEBUG_MODE") == "1":
return UnrestrictedSerdes
else:
raise PermissionError("UnsafeSerdes is not allowed in production mode")


class UnrestrictedSerdes:
@staticmethod
def serialize(obj) -> bytes:
return pickle.dumps(obj)

@staticmethod
def deserialize(bytes) -> object:
return pickle.loads(bytes)
2 changes: 1 addition & 1 deletion python/fate/arch/computing/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable, Tuple, Iterable, Generic, TypeVar, Optional

from fate.arch.unify.partitioner import get_partitioner_by_type
from fate.arch.unify.serdes import get_serdes_by_type
from fate.arch.computing.serdes import get_serdes_by_type
from fate.arch.utils.trace import auto_trace
from ..unify import URI
import functools
Expand Down
5 changes: 5 additions & 0 deletions python/fate/arch/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@ def get_option(self, options, key, default=...):
raise ValueError(f"{key} not in {options} or {self.config}")
else:
return default

@property
def safety(self):
return self.config.safety

1 change: 1 addition & 0 deletions python/fate/arch/context/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._namespace import NS, default_ns
from ..unify import device
from fate.arch.utils.trace import auto_trace
from fate.arch.config import cfg

logger = logging.getLogger(__name__)

Expand Down
40 changes: 0 additions & 40 deletions python/fate/arch/unify/serdes.py

This file was deleted.

0 comments on commit 74b6edb

Please sign in to comment.