add whitelist to federation
Signed-off-by: sagewe <wbwmat@gmail.com>
sagewe committed Dec 18, 2023
1 parent fcc4dd9 commit d328f8e
Showing 12 changed files with 59 additions and 77 deletions.
7 changes: 5 additions & 2 deletions configs/default.yaml
@@ -1,8 +1,11 @@
 safety:
   serdes:
-    # supported types: unrestricted, restricted, restricted_catch_miss
+    # supported types: unrestricted, restricted
     restricted_type: "unrestricted"
-
+  federation:
+    # supported types: unrestricted, restricted, restricted_catch_miss
+    restricted_type: "restricted"
+    restricted_catch_miss_path:
   phe:
     paillier:
       allow: True
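The new federation block is read at deserialization time by the federation unpickler added later in this diff. A minimal sketch, not part of the commit, of how the three modes are consumed; the attribute paths are the ones that appear in _serdes.py below, everything else is illustrative:

from fate.arch.config import cfg

mode = cfg.safety.serdes.federation.restricted_type
# "unrestricted"          -> no check on deserialized classes
# "restricted"            -> only allowlisted classes; anything else raises
# "restricted_catch_miss" -> everything is allowed through, but missed classes
#                            are appended to a file derived from the path below
miss_prefix = cfg.safety.serdes.federation.restricted_catch_miss_path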
1 change: 0 additions & 1 deletion configs/whitelist.yaml

This file was deleted.

11 changes: 5 additions & 6 deletions python/fate/arch/computing/api/_table.py
@@ -79,7 +79,7 @@ def load(self, uri: URI, schema: dict, options: dict = None):
 
     @_compute_info
     def parallelize(
-            self, data, include_key=True, partition=None, key_serdes_type=0, value_serdes_type=0, partitioner_type=0
+        self, data, include_key=True, partition=None, key_serdes_type=0, value_serdes_type=0, partitioner_type=0
     ) -> "KVTable":
         key_serdes = get_serdes_by_type(key_serdes_type)
         value_serdes = get_serdes_by_type(value_serdes_type)
@@ -299,11 +299,10 @@ def _map_reduce_partitions_with_index(
     def mapPartitionsWithIndexNoSerdes(
         self,
         map_partition_op: Callable[[int, Iterable[Tuple[bytes, bytes]]], Iterable[Tuple[bytes, bytes]]],
-            shuffle=False,
-            output_key_serdes_type=None,
-            output_value_serdes_type=None,
-            output_partitioner_type=None,
-
+        shuffle=False,
+        output_key_serdes_type=None,
+        output_value_serdes_type=None,
+        output_partitioner_type=None,
     ):
         """
         caller should guarantee that the output of map_partition_op is a generator of (bytes, bytes)
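The docstring above states the contract for mapPartitionsWithIndexNoSerdes: the callable receives the partition index plus an iterable of raw (bytes, bytes) pairs and must yield raw (bytes, bytes) pairs back, with no serdes applied by the table on either side. A hedged sketch of a conforming callable; the name and the transformation are illustrative only:

from typing import Iterable, Tuple


def tag_with_partition(index: int, kv_pairs: Iterable[Tuple[bytes, bytes]]) -> Iterable[Tuple[bytes, bytes]]:
    # Keys and values stay raw bytes end to end, as the docstring requires.
    for key, value in kv_pairs:
        yield key, value + f"@{index}".encode()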
18 changes: 9 additions & 9 deletions python/fate/arch/computing/backends/spark/_csession.py
@@ -80,15 +80,15 @@ def _load(self, uri: URI, schema, options: dict = None) -> "Table":
         raise NotImplementedError(f"uri type {uri} not supported with spark backend")
 
     def _parallelize(
-            self,
-            data: Iterable,
-            total_partitions,
-            key_serdes,
-            key_serdes_type,
-            value_serdes,
-            value_serdes_type,
-            partitioner,
-            partitioner_type,
+        self,
+        data: Iterable,
+        total_partitions,
+        key_serdes,
+        key_serdes_type,
+        value_serdes,
+        value_serdes_type,
+        partitioner,
+        partitioner_type,
     ):
         # noinspection PyPackageRequirements
         from pyspark import SparkContext
1 change: 0 additions & 1 deletion python/fate/arch/computing/backends/spark/_table.py
@@ -64,7 +64,6 @@ def __init__(self, rdd: pyspark.RDD, key_serdes_type, value_serdes_type, partiti
             num_partitions=rdd.getNumPartitions(),
         )
 
-
     @property
     def engine(self):
         return self._engine
@@ -107,7 +107,6 @@ def _parallelize(
         return Table(table)
 
     def _info(self, level=0):
-
         if level == 0:
             return f"Standalone<session_id={self.session_id}, max_workers={self._session.max_workers}, data_dir={self._session.data_dir}>"
@@ -1231,7 +1231,6 @@ def __init__(self, data_dir: str, session_id, party: Tuple[str, str]) -> None:
     def wait_status_set(self, key: bytes) -> bytes:
         value = self.get_status(key)
         while value is None:
-
             time.sleep(0.1)
             value = self.get_status(key)
         return key
1 change: 0 additions & 1 deletion python/fate/arch/computing/backends/standalone/_table.py
@@ -23,7 +23,6 @@
 LOGGER = logging.getLogger(__name__)
 
 
-
 class Table(KVTable):
     def __init__(self, table: StandaloneTable):
         self._table = table
@@ -2,4 +2,3 @@ def mmh3_partitioner(key: bytes, total_partitions):
     import mmh3
 
     return mmh3.hash(key) % total_partitions
-
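For context, the partitioner buckets a raw key by its MurmurHash3 value modulo the partition count, so identical keys always land on the same partition. A quick illustration; it requires the mmh3 package and the keys are made up:

import mmh3

total_partitions = 4
for key in (b"guest:0", b"host:1", b"arbiter:2"):
    # mmh3.hash returns a signed 32-bit int; Python's % keeps the bucket in [0, total_partitions).
    print(key, mmh3.hash(key) % total_partitions)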

10 changes: 0 additions & 10 deletions python/fate/arch/computing/serdes/__init__.py
@@ -11,10 +11,6 @@ def get_serdes_by_type(serdes_type: int):
             from ._restricted_serdes import get_restricted_serdes
 
             return get_restricted_serdes()
-        elif cfg.safety.serdes.restricted_type == "restricted_catch_miss":
-            from ._restricted_caught_miss_serdes import get_restricted_catch_miss_serdes
-
-            return get_restricted_catch_miss_serdes()
         else:
             raise ValueError(f"restricted type `{cfg.safety.serdes.restricted_type}` not supported")
     elif serdes_type == 1:
@@ -23,9 +19,3 @@ def get_serdes_by_type(serdes_type: int):
         return get_integer_serdes()
     else:
         raise ValueError(f"serdes type `{serdes_type}` not supported")
-
-
-def dump_miss(path):
-    from ._restricted_caught_miss_serdes import dump_miss
-
-    dump_miss(path)

python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py

This file was deleted.

40 changes: 40 additions & 0 deletions python/fate/arch/federation/api/_serdes.py
@@ -19,6 +19,7 @@
 import typing
 from typing import Any, List, Tuple, TypeVar
 
+from fate.arch.config import cfg
 from ._table_meta import TableMeta
 from ._type import PartyMeta
@@ -192,6 +193,21 @@ def push(
 
 
 class TableRemotePersistentUnpickler(pickle.Unpickler):
+    __ALLOW_CLASSES = {
+        "builtins": {"slice"},
+        "torch._utils": {"_rebuild_tensor_v2"},
+        "torch.storage": {"_load_from_bytes"},
+        "torch": {"device", "Size", "int64", "int32", "float64", "float32", "Tensor", "Storage", "dtype"},
+        "collections": {"OrderedDict"},
+        "pandas.core.series": {"Series"},
+        "pandas.core.internals.managers": {"SingleBlockManager"},
+        "pandas.core.indexes.base": {"_new_Index", "Index"},
+        "numpy.core.multiarray": {"_reconstruct"},
+        "numpy": {"ndarray", "dtype"},
+    }
+    __BUILDIN_MODULES = {"fate.", "fate_utils."}
+    __ALLOW_MODULES = {}
+
     def __init__(
         self,
         ctx: "Context",
@@ -215,6 +231,30 @@ def persistent_load(self, pid: Any) -> Any:
         if isinstance(pid, _ContextPersistentId):
             return self._ctx
 
+    def find_class(self, module, name):
+        if cfg.safety.serdes.federation.restricted_type == "unrestricted":
+            return super().find_class(module, name)
+        for m in self.__BUILDIN_MODULES:
+            if module.startswith(m):
+                return super().find_class(module, name)
+        if module in self.__ALLOW_MODULES:
+            return super().find_class(module, name)
+        elif module in self.__ALLOW_CLASSES and name in self.__ALLOW_CLASSES[module]:
+            return super().find_class(module, name)
+        else:
+            if cfg.safety.serdes.federation.restricted_type == "restricted_catch_miss":
+                self.__ALLOW_CLASSES.setdefault(module, set()).add(name)
+                path_to_write = f"{cfg.safety.serdes.federation.restricted_catch_miss_path}_{self._federation.local_party[0]}_{self._federation.local_party[1]}"
+                with open(path_to_write, "a") as f:
+                    f.write(f"{module}.{name}\n")
+                return super().find_class(module, name)
+            elif cfg.safety.serdes.federation.restricted_type == "restricted":
+                raise ValueError(
+                    f"Deserialization is restricted for class `{module}`.`{name}`, allowlist: {self.__ALLOW_CLASSES}"
+                )
+            else:
+                raise ValueError(f"invalid restricted_type: {cfg.safety.serdes.federation.restricted_type}")
+
     @classmethod
     def pull(
         cls,
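The find_class override above is the standard hook for constraining pickle: every global that a payload references is resolved through it, so a restricted mode can reject anything outside the allowlist before its constructor or __reduce__ callable ever runs. A self-contained sketch of the same pattern, simplified and not the FATE implementation:

import io
import pickle
from collections import OrderedDict

ALLOWED = {"collections": {"OrderedDict"}, "builtins": {"slice"}}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Resolve only explicitly allowlisted globals; reject everything else.
        if module in ALLOWED and name in ALLOWED[module]:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"deserialization of {module}.{name} is not allowed")


payload = pickle.dumps(OrderedDict(a=1))
print(AllowlistUnpickler(io.BytesIO(payload)).load())  # allowed: prints the OrderedDict
# Any global outside ALLOWED raises UnpicklingError at this point instead of being imported.

The restricted_catch_miss branch in the real code performs the same check but, on a miss, records module.name to a per-party file and lets the call through, which makes it a practical way to build up the allowlist from real federation traffic before switching to restricted.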
