Skip to content

Commit

Permalink
Merge pull request #5311 from FederatedAI/feature-2.0.0-rc-fix-config
Browse files Browse the repository at this point in the history
feature 2.0.0 rc fix config
  • Loading branch information
mgqa34 authored Dec 8, 2023
2 parents ac0bef9 + 35c091d commit a72427a
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 0 deletions.
5 changes: 5 additions & 0 deletions configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ nn:
protocol: "layer_estimation"
skip_loss_forward: True
cache_pred_size: True

safety:
serdes:
# supported types: unrestricted, restricted, restricted_catch_miss
restricted_type: "unrestricted"
1 change: 1 addition & 0 deletions configs/whitelist.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fate: "*"
83 changes: 83 additions & 0 deletions python/fate/arch/computing/serdes/_safe_serdes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import enum
import struct
from functools import singledispatch


class SerdeObjectTypes(enum.IntEnum):
INT = 0
FLOAT = 1
STRING = 2
BYTES = 3
LIST = 4
DICT = 5
TUPLE = 6


_deserializer_registry = {}


def _register_deserializer(obj_type_enum):
def _register(deserializer_func):
_deserializer_registry[obj_type_enum] = deserializer_func
return deserializer_func

return _register


def _dispatch_deserializer(obj_type_enum):
return _deserializer_registry[obj_type_enum]


class SafeSerdes(object):
@staticmethod
def serialize(obj):
obj_type, obj_bytes = serialize_obj(obj)
return struct.pack("!h", obj_type) + obj_bytes

@staticmethod
def deserialize(raw_bytes):
(obj_type,) = struct.unpack("!h", raw_bytes[:2])
return _dispatch_deserializer(obj_type)(raw_bytes[2:])


@singledispatch
def serialize_obj(obj):
raise NotImplementedError("Unsupported type: {}".format(type(obj)))


@serialize_obj.register(int)
def _(obj):
return SerdeObjectTypes.INT, struct.pack("!q", obj)


@_register_deserializer(SerdeObjectTypes.INT)
def _(raw_bytes):
return struct.unpack("!q", raw_bytes)[0]


@serialize_obj.register(float)
def _(obj):
return SerdeObjectTypes.FLOAT, struct.pack("!d", obj)


@_register_deserializer(SerdeObjectTypes.FLOAT)
def _(raw_bytes):
return struct.unpack("!d", raw_bytes)[0]


@serialize_obj.register(str)
def _(obj):
utf8_str = obj.encode("utf-8")
return SerdeObjectTypes.STRING, struct.pack("!I", len(utf8_str)) + utf8_str


@_register_deserializer(SerdeObjectTypes.STRING)
def _(raw_bytes):
length = struct.unpack("!I", raw_bytes[:4])[0]
return raw_bytes[4 : 4 + length].decode("utf-8")


if __name__ == "__main__":
print(SafeSerdes.deserialize(SafeSerdes.serialize(1)))
print(SafeSerdes.deserialize(SafeSerdes.serialize(1.0)))
print(SafeSerdes.deserialize(SafeSerdes.serialize("hello")))
7 changes: 7 additions & 0 deletions python/fate/arch/context/_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import io
import pickle
import logging
import struct
import typing
from typing import Any, List, Tuple, TypeVar, Union
Expand All @@ -23,6 +24,7 @@
from ..computing import is_table
from ..federation._gc import IterationGC

logger = logging.getLogger(__name__)
T = TypeVar("T")

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -308,6 +310,11 @@ def persistent_load(self, pid: Any) -> Any:
if isinstance(pid, _ContextPersistentId):
return self._ctx

# def load(self):
# out = super().load()
# logger.error(f"unpickled: {out.__class__.__module__}.{out.__class__.__name__}")
# return out

@classmethod
def pull(
cls,
Expand Down

0 comments on commit a72427a

Please sign in to comment.