From 35c091db90845280cc05beaa1716a427557aa0be Mon Sep 17 00:00:00 2001 From: sagewe Date: Thu, 7 Dec 2023 21:31:18 +0800 Subject: [PATCH] add serdes Signed-off-by: sagewe --- configs/default.yaml | 5 ++ configs/whitelist.yaml | 1 + .../arch/computing/serdes/_safe_serdes.py | 83 +++++++++++++++++++ python/fate/arch/context/_federation.py | 7 ++ 4 files changed, 96 insertions(+) create mode 100644 configs/whitelist.yaml create mode 100644 python/fate/arch/computing/serdes/_safe_serdes.py diff --git a/configs/default.yaml b/configs/default.yaml index 3f8b37f881..438828fef2 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -45,3 +45,8 @@ nn: protocol: "layer_estimation" skip_loss_forward: True cache_pred_size: True + +safety: + serdes: + # supported types: unrestricted, restricted, restricted_catch_miss + restricted_type: "unrestricted" \ No newline at end of file diff --git a/configs/whitelist.yaml b/configs/whitelist.yaml new file mode 100644 index 0000000000..e5873408d4 --- /dev/null +++ b/configs/whitelist.yaml @@ -0,0 +1 @@ +fate: "*" \ No newline at end of file diff --git a/python/fate/arch/computing/serdes/_safe_serdes.py b/python/fate/arch/computing/serdes/_safe_serdes.py new file mode 100644 index 0000000000..fd1059e9c9 --- /dev/null +++ b/python/fate/arch/computing/serdes/_safe_serdes.py @@ -0,0 +1,83 @@ +import enum +import struct +from functools import singledispatch + + +class SerdeObjectTypes(enum.IntEnum): + INT = 0 + FLOAT = 1 + STRING = 2 + BYTES = 3 + LIST = 4 + DICT = 5 + TUPLE = 6 + + +_deserializer_registry = {} + + +def _register_deserializer(obj_type_enum): + def _register(deserializer_func): + _deserializer_registry[obj_type_enum] = deserializer_func + return deserializer_func + + return _register + + +def _dispatch_deserializer(obj_type_enum): + return _deserializer_registry[obj_type_enum] + + +class SafeSerdes(object): + @staticmethod + def serialize(obj): + obj_type, obj_bytes = serialize_obj(obj) + return struct.pack("!h", obj_type) + obj_bytes + + @staticmethod + def deserialize(raw_bytes): + (obj_type,) = struct.unpack("!h", raw_bytes[:2]) + return _dispatch_deserializer(obj_type)(raw_bytes[2:]) + + +@singledispatch +def serialize_obj(obj): + raise NotImplementedError("Unsupported type: {}".format(type(obj))) + + +@serialize_obj.register(int) +def _(obj): + return SerdeObjectTypes.INT, struct.pack("!q", obj) + + +@_register_deserializer(SerdeObjectTypes.INT) +def _(raw_bytes): + return struct.unpack("!q", raw_bytes)[0] + + +@serialize_obj.register(float) +def _(obj): + return SerdeObjectTypes.FLOAT, struct.pack("!d", obj) + + +@_register_deserializer(SerdeObjectTypes.FLOAT) +def _(raw_bytes): + return struct.unpack("!d", raw_bytes)[0] + + +@serialize_obj.register(str) +def _(obj): + utf8_str = obj.encode("utf-8") + return SerdeObjectTypes.STRING, struct.pack("!I", len(utf8_str)) + utf8_str + + +@_register_deserializer(SerdeObjectTypes.STRING) +def _(raw_bytes): + length = struct.unpack("!I", raw_bytes[:4])[0] + return raw_bytes[4 : 4 + length].decode("utf-8") + + +if __name__ == "__main__": + print(SafeSerdes.deserialize(SafeSerdes.serialize(1))) + print(SafeSerdes.deserialize(SafeSerdes.serialize(1.0))) + print(SafeSerdes.deserialize(SafeSerdes.serialize("hello"))) diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index d4cd2706be..c2e3189859 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -14,6 +14,7 @@ # limitations under the License. import io import pickle +import logging import struct import typing from typing import Any, List, Tuple, TypeVar, Union @@ -23,6 +24,7 @@ from ..computing import is_table from ..federation._gc import IterationGC +logger = logging.getLogger(__name__) T = TypeVar("T") if typing.TYPE_CHECKING: @@ -308,6 +310,11 @@ def persistent_load(self, pid: Any) -> Any: if isinstance(pid, _ContextPersistentId): return self._ctx + # def load(self): + # out = super().load() + # logger.error(f"unpickled: {out.__class__.__module__}.{out.__class__.__name__}") + # return out + @classmethod def pull( cls,