Skip to content

Commit

Permalink
improve custom serdes
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Dec 19, 2023
1 parent bb33a45 commit 0eae941
Showing 1 changed file with 92 additions and 30 deletions.
122 changes: 92 additions & 30 deletions python/fate/arch/federation/api/_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,94 @@
if typing.TYPE_CHECKING:
from fate.arch.context import Context
from fate.arch.federation.api import Federation
from fate.arch.computing.api import KVTableContext
from fate.arch.computing.api import KVTableContext, KVTable
import torch
import numpy as np


class _TablePersistentId:
    """Pickle persistent-id placeholder that stands in for a ``KVTable``.

    The table is transferred separately through the pickler's ``_push_table``;
    the placeholder only carries the key it was pushed under and the metadata
    handed to ``pull_table`` on the receiving side.
    """

    def __init__(self, key, table_meta: "TableMeta") -> None:
        self.key = key
        self.table_meta = table_meta

    @staticmethod
    def dump(pickler: "TableRemotePersistentPickler", obj: "KVTable") -> Any:
        """Push ``obj`` via ``pickler`` and return the placeholder describing it.

        A fresh key is requested from the pickler, the table is pushed under
        that key, and the table's partitioning/serdes attributes are recorded
        in a ``TableMeta``.
        """
        from fate.arch.computing.api import KVTable

        assert isinstance(obj, KVTable)
        table_key = pickler._get_next_table_key()
        # Push first, then read the metadata attributes (same order as before).
        pickler._push_table(obj, table_key)
        meta = TableMeta(
            num_partitions=obj.num_partitions,
            key_serdes_type=obj.key_serdes_type,
            value_serdes_type=obj.value_serdes_type,
            partitioner_type=obj.partitioner_type,
        )
        return _TablePersistentId(key=table_key, table_meta=meta)

    def load(self, unpickler: "TableRemotePersistentUnpickler"):
        """Pull the table back through the unpickler's federation and return it."""
        pulled = unpickler._federation.pull_table(
            self.key,
            unpickler._tag,
            [unpickler._party],
            table_metas=[self.table_meta],
        )
        # A single party was requested, so only the first result is used.
        return pulled[0]


class _ContextPersistentId:
    """Pickle persistent-id placeholder that stands in for a ``Context``.

    The context itself is never serialized; on load the unpickler's own
    ``_ctx`` is returned in its place.
    """

    def __init__(self, key) -> None:
        self.key = key

    @staticmethod
    def dump(pickler: "TableRemotePersistentPickler", obj: "Context") -> Any:
        """Return a placeholder keyed by the pickler's name for ``obj``."""
        from fate.arch.context import Context

        assert isinstance(obj, Context)
        return _ContextPersistentId(f"{pickler._name}__context__")

    def load(self, unpickler: "TableRemotePersistentUnpickler"):
        """Return the context held by ``unpickler``; ``self.key`` is not consulted."""
        return unpickler._ctx


class _TorchSafeTensorPersistentId:
    """Pickle persistent-id placeholder carrying a ``torch.Tensor`` as safetensors bytes.

    The tensor is stored under the single entry name ``"t"``.
    """

    def __init__(self, bytes) -> None:
        # The parameter name shadows the builtin but is part of the existing signature.
        self.bytes = bytes

    @staticmethod
    def dump(_pickler: "TableRemotePersistentPickler", obj: "torch.Tensor") -> Any:
        """Encode ``obj`` with ``safetensors.torch.save`` and wrap the result."""
        import torch
        from safetensors.torch import save as save_tensors

        assert isinstance(obj, torch.Tensor)
        encoded = save_tensors({"t": obj})
        return _TorchSafeTensorPersistentId(encoded)

    def load(self, _unpickler: "TableRemotePersistentUnpickler"):
        """Decode the stored bytes and return the tensor saved under ``"t"``."""
        from safetensors.torch import load as load_tensors

        decoded = load_tensors(self.bytes)
        return decoded["t"]


class _TorchTensorPersistentId:
class _NumpySafeTensorPersistentId:
    """Pickle persistent-id placeholder carrying an ``np.ndarray`` as safetensors bytes.

    The array is stored under the single entry name ``"n"``. Arrays with
    ``object`` dtype are not encoded by this class.
    """

    def __init__(self, bytes) -> None:
        # The parameter name shadows the builtin but is part of the existing signature.
        self.bytes = bytes

    @staticmethod
    def dump(_pickler: "TableRemotePersistentPickler", obj: "np.ndarray") -> Any:
        """Encode ``obj`` with ``safetensors.numpy.save`` and wrap the result.

        Returns:
            A ``_NumpySafeTensorPersistentId`` for non-object arrays, or
            ``None`` for ``object``-dtype arrays. Under the pickle
            ``persistent_id`` protocol a ``None`` result means the object is
            pickled as usual.
        """
        import numpy as np

        assert isinstance(obj, np.ndarray)
        if obj.dtype == np.dtype("object"):
            # Made explicit: previously this path had no return of its own.
            return None

        # Imported only once we know an encodable array has to be saved.
        import safetensors.numpy

        tensor_bytes = safetensors.numpy.save({"n": obj})
        return _NumpySafeTensorPersistentId(tensor_bytes)

    def load(self, _unpickler: "TableRemotePersistentUnpickler"):
        """Decode the stored bytes and return the array saved under ``"n"``."""
        import safetensors.numpy

        return safetensors.numpy.load(self.bytes)["n"]


class _FederationBytesCoder:
@staticmethod
Expand Down Expand Up @@ -123,31 +193,22 @@ def _get_next_table_key(self):
return f"{self._name}__table_persistent_{self._table_index}__"

def persistent_id(self, obj: Any) -> Any:
    """Return a persistent-id placeholder for objects that need custom transfer.

    ``KVTable``, ``Context``, ``torch.Tensor`` and non-object ``np.ndarray``
    values are delegated to the ``dump`` helper of their placeholder class.
    Anything else yields ``None``, which tells pickle to serialize the object
    in the usual way.
    """
    # TODO: use serdes method check instead of isinstance
    from fate.arch.context import Context
    from fate.arch.computing.api import KVTable
    import torch
    import numpy as np

    if isinstance(obj, KVTable):
        return _TablePersistentId.dump(self, obj)

    if isinstance(obj, Context):
        return _ContextPersistentId.dump(self, obj)

    if isinstance(obj, torch.Tensor):
        return _TorchSafeTensorPersistentId.dump(self, obj)

    # Object-dtype arrays are excluded here and fall through to regular pickling.
    if isinstance(obj, np.ndarray) and obj.dtype != np.dtype("object"):
        return _NumpySafeTensorPersistentId.dump(self, obj)

    return None

def _push_table(self, table, key):
    """Push ``table`` through the federation under the name ``key``.

    The pickler's own tag and party list are forwarded unchanged; nothing is
    returned.
    """
    federation = self._federation
    federation.push_table(
        table=table,
        name=key,
        tag=self._tag,
        parties=self._parties,
    )
Expand Down Expand Up @@ -206,10 +267,10 @@ def push(

class TableRemotePersistentUnpickler(pickle.Unpickler):
__ALLOW_CLASSES = {
"builtins": {"slice"},
"torch._utils": {"_rebuild_tensor_v2"},
"torch": {"device", "Size", "int64", "int32", "float64", "float32", "Tensor", "Storage", "dtype"},
"torch": {"device", "Size", "int64", "int32", "float64", "float32", "dtype"},
"collections": {"OrderedDict"},
# the following entries can be removed once we customize the serdes for `DataFrame`
"builtins": {"slice"},
"pandas.core.series": {"Series"},
"pandas.core.internals.managers": {"SingleBlockManager"},
"pandas.core.indexes.base": {"_new_Index", "Index"},
Expand All @@ -236,15 +297,16 @@ def __init__(
super().__init__(f)

def persistent_load(self, pid: Any) -> Any:
    """Turn a persistent id read from the pickle stream back into an object.

    Every known placeholder type implements ``load(unpickler)``, so the work
    is delegated to the placeholder itself.

    Raises:
        pickle.UnpicklingError: if ``pid`` is not one of the known placeholder
            types. Previously such a pid silently produced ``None``.
    """
    known_placeholders = (
        _TablePersistentId,
        _ContextPersistentId,
        _TorchSafeTensorPersistentId,
        _NumpySafeTensorPersistentId,
    )
    if isinstance(pid, known_placeholders):
        return pid.load(self)

    raise pickle.UnpicklingError(f"unsupported persistent id type: {type(pid).__name__}")

def find_class(self, module, name):
if cfg.safety.serdes.federation.restricted_type == "unrestricted":
Expand Down

0 comments on commit 0eae941

Please sign in to comment.