-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: sagewe <wbwmat@gmail.com>
- Loading branch information
Showing
9 changed files
with
196 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from fate.arch.config import cfg | ||
|
||
|
||
def get_serdes_by_type(serdes_type: int): | ||
if serdes_type == 0: | ||
if cfg.safety.serdes.restricted_type == "unrestricted": | ||
from ._unrestricted_serdes import get_unrestricted_serdes | ||
|
||
return get_unrestricted_serdes() | ||
elif cfg.safety.serdes.restricted_type == "restricted": | ||
from ._restricted_serdes import get_restricted_serdes | ||
|
||
return get_restricted_serdes() | ||
elif cfg.safety.serdes.restricted_type == "restricted_catch_miss": | ||
from ._restricted_caught_miss_serdes import get_restricted_catch_miss_serdes | ||
|
||
return get_restricted_catch_miss_serdes() | ||
else: | ||
raise ValueError(f"restricted type `{cfg.safety.serdes.restricted_type}` not supported") | ||
elif serdes_type == 1: | ||
from ._integer_serdes import get_integer_serdes | ||
|
||
return get_integer_serdes() | ||
else: | ||
raise ValueError(f"serdes type `{serdes_type}` not supported") | ||
|
||
|
||
def dump_miss(path): | ||
from ._restricted_caught_miss_serdes import dump_miss | ||
|
||
dump_miss(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
def get_integer_serdes(): | ||
return IntegerSerdes() | ||
|
||
|
||
class IntegerSerdes: | ||
def __init__(self): | ||
... | ||
|
||
def serialize(self, obj) -> bytes: | ||
return obj.to_bytes(8, "big") | ||
|
||
def deserialize(self, bytes) -> object: | ||
return int.from_bytes(bytes, "big") |
44 changes: 44 additions & 0 deletions
44
python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import io | ||
import pickle | ||
|
||
from ruamel import yaml | ||
|
||
from ._restricted_serdes import RestrictedUnpickler | ||
|
||
|
||
def get_restricted_catch_miss_serdes(): | ||
return WhitelistCatchRestrictedSerdes | ||
|
||
|
||
class WhitelistCatchRestrictedSerdes: | ||
@classmethod | ||
def serialize(cls, obj) -> bytes: | ||
return pickle.dumps(obj) | ||
|
||
@classmethod | ||
def deserialize(cls, bytes) -> object: | ||
return RestrictedCatchUnpickler(io.BytesIO(bytes)).load() | ||
|
||
|
||
class RestrictedCatchUnpickler(RestrictedUnpickler): | ||
caught_miss = {} | ||
|
||
def find_class(self, module, name): | ||
try: | ||
return super().find_class(module, name) | ||
except pickle.UnpicklingError: | ||
if (module, name) not in self.caught_miss: | ||
if module not in self.caught_miss: | ||
self.caught_miss[module] = set() | ||
self.caught_miss[module].add(name) | ||
return self._load(module, name) | ||
|
||
@classmethod | ||
def dump_miss(cls, path): | ||
with open(path, "w") as f: | ||
yaml.dump({module: list(names) for module, names in cls.caught_miss.items()}, f) | ||
|
||
|
||
def dump_miss(path): | ||
RestrictedCatchUnpickler.dump_miss(path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import importlib | ||
import io | ||
import pickle | ||
|
||
from ruamel import yaml | ||
|
||
|
||
def get_restricted_serdes(): | ||
return WhitelistRestrictedSerdes | ||
|
||
|
||
class WhitelistRestrictedSerdes: | ||
@classmethod | ||
def serialize(cls, obj) -> bytes: | ||
return pickle.dumps(obj) | ||
|
||
@classmethod | ||
def deserialize(cls, bytes) -> object: | ||
return RestrictedUnpickler(io.BytesIO(bytes)).load() | ||
|
||
|
||
class RestrictedUnpickler(pickle.Unpickler): | ||
def _load(self, module, name): | ||
try: | ||
return super().find_class(module, name) | ||
except: | ||
return getattr(importlib.import_module(module), name) | ||
|
||
def find_class(self, module, name): | ||
if name in Whitelist.get_whitelist().get(module, set()): | ||
return self._load(module, name) | ||
else: | ||
for m in Whitelist.get_whitelist_glob(): | ||
if module.startswith(m): | ||
return self._load(module, name) | ||
raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}") | ||
|
||
|
||
class Whitelist: | ||
loaded = False | ||
deserialize_whitelist = {} | ||
deserialize_glob_whitelist = set() | ||
|
||
@classmethod | ||
def get_whitelist_glob(cls): | ||
if not cls.loaded: | ||
cls.load_deserialize_whitelist() | ||
return cls.deserialize_glob_whitelist | ||
|
||
@classmethod | ||
def get_whitelist(cls): | ||
if not cls.loaded: | ||
cls.load_deserialize_whitelist() | ||
return cls.deserialize_whitelist | ||
|
||
@classmethod | ||
def get_whitelist_path(cls): | ||
import os.path | ||
|
||
return os.path.abspath( | ||
os.path.join( | ||
__file__, | ||
os.path.pardir, | ||
os.path.pardir, | ||
os.path.pardir, | ||
os.path.pardir, | ||
os.path.pardir, | ||
os.path.pardir, | ||
"configs", | ||
"whitelist.yaml", | ||
) | ||
) | ||
|
||
@classmethod | ||
def load_deserialize_whitelist(cls): | ||
with open(cls.get_whitelist_path()) as f: | ||
for k, v in yaml.load(f, Loader=yaml.SafeLoader).items(): | ||
if k.endswith("*"): | ||
cls.deserialize_glob_whitelist.add(k[:-1]) | ||
else: | ||
cls.deserialize_whitelist[k] = set(v) | ||
cls.loaded = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
import pickle | ||
|
||
|
||
def get_unrestricted_serdes(): | ||
if True or os.environ.get("SERDES_DEBUG_MODE") == "1": | ||
return UnrestrictedSerdes | ||
else: | ||
raise PermissionError("UnsafeSerdes is not allowed in production mode") | ||
|
||
|
||
class UnrestrictedSerdes: | ||
@staticmethod | ||
def serialize(obj) -> bytes: | ||
return pickle.dumps(obj) | ||
|
||
@staticmethod | ||
def deserialize(bytes) -> object: | ||
return pickle.loads(bytes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.