Skip to content

Commit

Permalink
use safetensor to serialize/deserialize torch tensor
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Dec 19, 2023
1 parent 4c2337b commit 682ec51
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion python/fate/arch/federation/api/_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def __init__(self, key) -> None:
self.key = key


class _TorchTensorPersistentId:
def __init__(self, bytes) -> None:
self.bytes = bytes


class _FederationBytesCoder:
@staticmethod
def encode_base(v: bytes) -> bytes:
Expand Down Expand Up @@ -120,6 +125,7 @@ def _get_next_table_key(self):
def persistent_id(self, obj: Any) -> Any:
from fate.arch.context import Context
from fate.arch.computing.api import KVTable
import torch

if isinstance(obj, KVTable):
key = self._get_next_table_key()
Expand All @@ -137,6 +143,12 @@ def persistent_id(self, obj: Any) -> Any:
key = f"{self._name}__context__"
return _ContextPersistentId(key)

if isinstance(obj, torch.Tensor):
import safetensors.torch

tensor_bytes = safetensors.torch.save({"t": obj})
return _TorchTensorPersistentId(tensor_bytes)

def _push_table(self, table, key):
self._federation.push_table(table=table, name=key, tag=self._tag, parties=self._parties)
self._table_index += 1
Expand Down Expand Up @@ -196,7 +208,7 @@ class TableRemotePersistentUnpickler(pickle.Unpickler):
__ALLOW_CLASSES = {
"builtins": {"slice"},
"torch._utils": {"_rebuild_tensor_v2"},
"torch.storage": {"_load_from_bytes"},
# "torch.storage": {"_load_from_bytes"},
"torch": {"device", "Size", "int64", "int32", "float64", "float32", "Tensor", "Storage", "dtype"},
"collections": {"OrderedDict"},
"pandas.core.series": {"Series"},
Expand Down Expand Up @@ -230,6 +242,10 @@ def persistent_load(self, pid: Any) -> Any:
return table
if isinstance(pid, _ContextPersistentId):
return self._ctx
if isinstance(pid, _TorchTensorPersistentId):
import safetensors.torch

return safetensors.torch.load(pid.bytes)["t"]

def find_class(self, module, name):
if cfg.safety.serdes.federation.restricted_type == "unrestricted":
Expand Down

0 comments on commit 682ec51

Please sign in to comment.