diff --git a/python/fate/arch/_standalone.py b/python/fate/arch/_standalone.py
index 9fdcb585ae..a2f1467673 100644
--- a/python/fate/arch/_standalone.py
+++ b/python/fate/arch/_standalone.py
@@ -34,10 +34,9 @@
 import lmdb
 import numpy as np
 
-from fate_arch.common import Party, file_utils
-from fate_arch.common.log import getLogger
-from fate_arch.federation import FederationDataType
-
+from .common import Party, file_utils
+from .common.log import getLogger
+from .federation import FederationDataType
 
 LOGGER = getLogger()
 
@@ -341,9 +340,7 @@ def get(self, k):
         with self._get_env_for_partition(p) as env:
             with env.begin(write=True) as txn:
                 old_value_bytes = txn.get(k_bytes)
-                return (
-                    None if old_value_bytes is None else deserialize(old_value_bytes)
-                )
+                return None if old_value_bytes is None else deserialize(old_value_bytes)
 
     def delete(self, k):
         k_bytes = _k_to_bytes(k=k)
@@ -487,7 +484,10 @@ def _get_splits(obj, max_message_size):
         return obj, num_slice
     else:
         _max_size = max_message_size
-        kv = [(i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)]
+        kv = [
+            (i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)])
+            for i in range(num_slice)
+        ]
         return kv, num_slice
 
 
@@ -667,10 +667,13 @@ def get(self, name: str, tag: str, parties: typing.List[Party]) -> typing.List:
                 dtype = r[2]
                 LOGGER.debug(
                     f"[{log_str}] got "
-                    f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}")
+                    f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}"
+                )
 
                 if dtype == FederationDataType.SPLIT_OBJECT:
-                    obj_bytes = b''.join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])))
+                    obj_bytes = b"".join(
+                        map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0]))
+                    )
                     obj = deserialize(obj_bytes)
                     rtn.append(obj)
                 else:
diff --git a/python/fate/arch/abc/__init__.py b/python/fate/arch/abc/__init__.py
index 404b6035b9..bd826e72ab 100644
--- a/python/fate/arch/abc/__init__.py
+++ b/python/fate/arch/abc/__init__.py
@@ -1,7 +1,6 @@
-
-from fate_arch.abc._gc import GarbageCollectionABC
-from fate_arch.abc._address import AddressABC
-from fate_arch.abc._computing import CTableABC, CSessionABC
-from fate_arch.abc._storage import StorageTableABC, StorageSessionABC, StorageTableMetaABC
-from fate_arch.abc._federation import FederationABC
-from fate_arch.abc._components import Components, ComponentMeta
+from ._address import AddressABC
+from ._components import ComponentMeta, Components
+from ._computing import CSessionABC, CTableABC
+from ._federation import FederationABC
+from ._gc import GarbageCollectionABC
+from ._storage import StorageSessionABC, StorageTableABC, StorageTableMetaABC
diff --git a/python/fate/arch/abc/_computing.py b/python/fate/arch/abc/_computing.py
index 25ca929fc5..4f9047f0ce 100644
--- a/python/fate/arch/abc/_computing.py
+++ b/python/fate/arch/abc/_computing.py
@@ -23,8 +23,8 @@
 from abc import ABCMeta
 from collections import Iterable
 
-from fate_arch.abc._address import AddressABC
-from fate_arch.abc._path import PathABC
+from ._address import AddressABC
+from ._path import PathABC
 
 __all__ = ["CTableABC", "CSessionABC"]
 
@@ -148,7 +148,7 @@ def count(self) -> int:
         ...
 
     @abc.abstractmethod
-    def map(self, func) -> 'CTableABC':
+    def map(self, func) -> "CTableABC":
         """
         apply `func` to each data
 
@@ -164,7 +164,7 @@ def map(self, func) -> 'CTableABC':
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([('k1', 1), ('k2', 2), ('k3', 3)], include_key=True, partition=2)
         >>> b = a.map(lambda k, v: (k, v**2))
         >>> list(b.collect())
@@ -189,7 +189,7 @@ def mapValues(self, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([('a', ['apple', 'banana', 'lemon']), ('b', ['grapes'])], include_key=True, partition=2)
         >>> b = a.mapValues(lambda x: len(x))
         >>> list(b.collect())
@@ -198,7 +198,9 @@ def mapValues(self, func):
         ...
 
     @abc.abstractmethod
-    def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False):
+    def mapPartitions(
+        self, func, use_previous_behavior=True, preserves_partitioning=False
+    ):
         """
         apply ``func`` to each partition of table
 
@@ -218,7 +220,7 @@ def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([1, 2, 3, 4, 5], include_key=False, partition=2)
         >>> def f(iterator):
         ...     s = 0
@@ -251,7 +253,7 @@ def mapReducePartitions(self, mapper, reducer, **kwargs):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> table = computing_session.parallelize([(1, 2), (2, 3), (3, 4), (4, 5)], include_key=False, partition=2)
         >>> def _mapper(it):
         ...     r = []
@@ -286,7 +288,7 @@ def applyPartitions(self, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([1, 2, 3], partition=3, include_key=False)
         >>> def f(it):
         ...     r = []
@@ -319,7 +321,7 @@ def flatMap(self, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([(1, 1), (2, 2)], include_key=True, partition=2)
         >>> b = a.flatMap(lambda x, y: [(x, y), (x + 10, y ** 2)])
         >>> c = list(b.collect())
@@ -347,7 +349,7 @@ def reduce(self, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize(range(100), include_key=False, partition=4)
         >>> assert a.reduce(lambda x, y: x + y) == sum(range(100))
 
@@ -370,7 +372,7 @@ def glom(self):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize(range(5), include_key=False, partition=3).glom().collect()
         >>> list(a)
         [(2, [(2, 2)]), (3, [(0, 0), (3, 3)]), (4, [(1, 1), (4, 4)])]
@@ -378,7 +380,13 @@ def glom(self):
         ...
 
     @abc.abstractmethod
-    def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None):
+    def sample(
+        self,
+        *,
+        fraction: typing.Optional[float] = None,
+        num: typing.Optional[int] = None,
+        seed=None
+    ):
         """
         return a sampled subset of this Table.
         Parameters
@@ -399,7 +407,7 @@ def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optiona
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> x = computing_session.parallelize(range(100), include_key=False, partition=4)
         >>> 6 <= x.sample(fraction=0.1, seed=81).count() <= 14
         True
@@ -428,7 +436,7 @@ def filter(self, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([0, 1, 2], include_key=False, partition=2)
         >>> b = a.filter(lambda k, v : k % 2 == 0)
         >>> list(b.collect())
@@ -461,7 +469,7 @@ def join(self, other, func):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([1, 2, 3], include_key=False, partition=2) # [(0, 1), (1, 2), (2, 3)]
         >>> b = computing_session.parallelize([(1, 1), (2, 2), (3, 3)], include_key=True, partition=2)
         >>> c = a.join(b, lambda v1, v2 : v1 + v2)
@@ -492,7 +500,7 @@ def union(self, other, func=lambda v1, v2: v1):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize([1, 2, 3], include_key=False, partition=2) # [(0, 1), (1, 2), (2, 3)]
         >>> b = computing_session.parallelize([(1, 1), (2, 2), (3, 3)], include_key=True, partition=2)
         >>> c = a.union(b, lambda v1, v2 : v1 + v2)
@@ -518,7 +526,7 @@ def subtractByKey(self, other):
 
         Examples
         --------
-        >>> from fate_arch.session import computing_session
+        >>> from fate.arch.session import computing_session
         >>> a = computing_session.parallelize(range(10), include_key=False, partition=2)
         >>> b = computing_session.parallelize(range(5), include_key=False, partition=2)
         >>> c = a.subtractByKey(b)
@@ -544,7 +552,9 @@ class CSessionABC(metaclass=ABCMeta):
     """
 
     @abc.abstractmethod
-    def load(self, address: AddressABC, partitions, schema: dict, **kwargs) -> typing.Union[PathABC, CTableABC]:
+    def load(
+        self, address: AddressABC, partitions, schema: dict, **kwargs
+    ) -> typing.Union[PathABC, CTableABC]:
         """
         load a table from given address
 
@@ -565,7 +575,9 @@ def load(self, address: AddressABC, partitions, schema: dict, **kwargs) -> typin
         ...
 
     @abc.abstractmethod
-    def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs) -> CTableABC:
+    def parallelize(
+        self, data: Iterable, partition: int, include_key: bool, **kwargs
+    ) -> CTableABC:
         """
         create table from iterable data
 
diff --git a/python/fate/arch/abc/_federation.py b/python/fate/arch/abc/_federation.py
index b3e8fd21a9..b0bf22eee0 100644
--- a/python/fate/arch/abc/_federation.py
+++ b/python/fate/arch/abc/_federation.py
@@ -2,8 +2,8 @@
 import typing
 from abc import ABCMeta
 
-from fate_arch.abc._gc import GarbageCollectionABC
-from fate_arch.common import Party
+from ..common import Party
+from ._gc import GarbageCollectionABC
 
 __all__ = ["FederationABC"]
 
@@ -19,10 +19,9 @@ def session_id(self) -> str:
         ...
 
     @abc.abstractmethod
-    def get(self, name: str,
-            tag: str,
-            parties: typing.List[Party],
-            gc: GarbageCollectionABC) -> typing.List:
+    def get(
+        self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC
+    ) -> typing.List:
         """
         get objects/tables from ``parties``
 
@@ -46,11 +45,14 @@ def get(self, name: str,
         ...
 
     @abc.abstractmethod
-    def remote(self, v,
-               name: str,
-               tag: str,
-               parties: typing.List[Party],
-               gc: GarbageCollectionABC):
+    def remote(
+        self,
+        v,
+        name: str,
+        tag: str,
+        parties: typing.List[Party],
+        gc: GarbageCollectionABC,
+    ):
         """
         remote object/table to ``parties``
 
diff --git a/python/fate/arch/abc/_gc.py b/python/fate/arch/abc/_gc.py
index bc08850f55..1dfabed890 100644
--- a/python/fate/arch/abc/_gc.py
+++ b/python/fate/arch/abc/_gc.py
@@ -2,6 +2,5 @@
 
 
 class GarbageCollectionABC(metaclass=abc.ABCMeta):
-
     def add_gc_action(self, tag: str, obj, method, args_dict):
         ...
diff --git a/python/fate/arch/abc/_storage.py b/python/fate/arch/abc/_storage.py
index 32fd0c4da0..5cb55fa331 100644
--- a/python/fate/arch/abc/_storage.py
+++ b/python/fate/arch/abc/_storage.py
@@ -18,10 +18,6 @@
 import abc
 from typing import Iterable
 
-from fate_arch.common.log import getLogger
-
-LOGGER = getLogger()
-
 
 class StorageTableMetaABC(metaclass=abc.ABCMeta):
     @abc.abstractmethod
@@ -37,7 +33,15 @@ def query_table_meta(self, filter_fields, query_fields=None):
         ...
 
     @abc.abstractmethod
-    def update_metas(self, schema=None, count=None, part_of_data=None, description=None, partitions=None, **kwargs):
+    def update_metas(
+        self,
+        schema=None,
+        count=None,
+        part_of_data=None,
+        description=None,
+        partitions=None,
+        **kwargs
+    ):
         ...
 
     @abc.abstractmethod
@@ -169,13 +173,15 @@ def meta(self, meta: StorageTableMetaABC):
         ...
 
     @abc.abstractmethod
-    def update_meta(self,
-                    schema=None,
-                    count=None,
-                    part_of_data=None,
-                    description=None,
-                    partitions=None,
-                    **kwargs) -> StorageTableMetaABC:
+    def update_meta(
+        self,
+        schema=None,
+        count=None,
+        part_of_data=None,
+        description=None,
+        partitions=None,
+        **kwargs
+    ) -> StorageTableMetaABC:
         ...
 
     @abc.abstractmethod
@@ -209,8 +215,16 @@ def check_address(self):
 
 class StorageSessionABC(metaclass=abc.ABCMeta):
     @abc.abstractmethod
-    def create_table(self, address, name, namespace, partitions, storage_type=None, options=None,
-                     **kwargs) -> StorageTableABC:
+    def create_table(
+        self,
+        address,
+        name,
+        namespace,
+        partitions,
+        storage_type=None,
+        options=None,
+        **kwargs
+    ) -> StorageTableABC:
         ...
@abc.abstractmethod diff --git a/python/fate/arch/common/__init__.py b/python/fate/arch/common/__init__.py index 2dc2853eb5..cbd5efb30d 100644 --- a/python/fate/arch/common/__init__.py +++ b/python/fate/arch/common/__init__.py @@ -1,3 +1,10 @@ -from fate_arch.common._types import FederatedMode, FederatedCommunicationType, EngineType, CoordinationProxyService, \ - CoordinationCommunicationProtocol -from fate_arch.common._types import BaseType, Party, DTable +from ._types import ( + BaseType, + CoordinationCommunicationProtocol, + CoordinationProxyService, + DTable, + EngineType, + FederatedCommunicationType, + FederatedMode, + Party, +) diff --git a/python/fate/arch/common/_parties.py b/python/fate/arch/common/_parties.py index 24518947fe..5627c0f392 100644 --- a/python/fate/arch/common/_parties.py +++ b/python/fate/arch/common/_parties.py @@ -17,7 +17,7 @@ import typing -from fate_arch.common import Party +from ._types import Party class Role: @@ -95,9 +95,9 @@ def from_conf(conf: typing.MutableMapping[str, dict]): return PartiesInfo(local, role_to_parties) def __init__( - self, - local: Party, - role_to_parties: typing.MutableMapping[str, typing.List[Party]], + self, + local: Party, + role_to_parties: typing.MutableMapping[str, typing.List[Party]], ): self._local = local self._role_to_parties = role_to_parties diff --git a/python/fate/arch/common/_types.py b/python/fate/arch/common/_types.py index dc1c400e07..21f95cf3fe 100644 --- a/python/fate/arch/common/_types.py +++ b/python/fate/arch/common/_types.py @@ -1,5 +1,3 @@ - - class EngineType(object): COMPUTING = "computing" STORAGE = "storage" @@ -58,6 +56,7 @@ def _dict(obj): else: data = obj return {"type": obj.__class__.__name__, "data": data, "module": module} + return _dict(self) @@ -85,6 +84,9 @@ def __lt__(self, other): def __eq__(self, other): return self.party_id == other.party_id and self.role == other.role + def as_tuple(self): + return (self.role, self.party_id) + class DTable(BaseType): def __init__(self, namespace, name, partitions=None): diff --git a/python/fate/arch/common/address.py b/python/fate/arch/common/address.py index f290f610f7..f7567f0fe3 100644 --- a/python/fate/arch/common/address.py +++ b/python/fate/arch/common/address.py @@ -1,5 +1,5 @@ -from fate_arch.abc import AddressABC -from fate_arch.metastore.db_utils import StorageConnector +from ..abc import AddressABC +from ..metastore.db_utils import StorageConnector class AddressBase(AddressABC): @@ -22,7 +22,14 @@ def storage_engine(self): class StandaloneAddress(AddressBase): - def __init__(self, home=None, name=None, namespace=None, storage_type=None, connector_name=None): + def __init__( + self, + home=None, + name=None, + namespace=None, + storage_type=None, + connector_name=None, + ): self.home = home self.name = name self.namespace = namespace @@ -100,7 +107,9 @@ def __repr__(self): class ApiAddress(AddressBase): - def __init__(self, method="POST", url=None, header=None, body=None, connector_name=None): + def __init__( + self, method="POST", url=None, header=None, body=None, connector_name=None + ): self.method = method self.url = url self.header = header if header else {} @@ -118,7 +127,16 @@ def __repr__(self): class MysqlAddress(AddressBase): - def __init__(self, user=None, passwd=None, host=None, port=None, db=None, name=None, connector_name=None): + def __init__( + self, + user=None, + passwd=None, + host=None, + port=None, + db=None, + name=None, + connector_name=None, + ): self.user = user self.passwd = passwd self.host = host @@ -139,12 +157,27 @@ 
def __repr__(self): @property def connector(self): - return {"user": self.user, "passwd": self.passwd, "host": self.host, "port": self.port, "db": self.db} + return { + "user": self.user, + "passwd": self.passwd, + "host": self.host, + "port": self.port, + "db": self.db, + } class HiveAddress(AddressBase): - def __init__(self, host=None, name=None, port=10000, username=None, database='default', auth_mechanism='PLAIN', - password=None, connector_name=None): + def __init__( + self, + host=None, + name=None, + port=10000, + username=None, + database="default", + auth_mechanism="PLAIN", + password=None, + connector_name=None, + ): self.host = host self.username = username self.port = port @@ -171,12 +204,24 @@ def connector(self): "username": self.username, "password": self.password, "auth_mechanism": self.auth_mechanism, - "database": self.database} + "database": self.database, + } class LinkisHiveAddress(AddressBase): - def __init__(self, host="127.0.0.1", port=9001, username='', database='', name='', run_type='hql', - execute_application_name='hive', source={}, params={}, connector_name=None): + def __init__( + self, + host="127.0.0.1", + port=9001, + username="", + database="", + name="", + run_type="hql", + execute_application_name="hive", + source={}, + params={}, + connector_name=None, + ): self.host = host self.port = port self.username = username diff --git a/python/fate/arch/common/base_utils.py b/python/fate/arch/common/base_utils.py index d8aa6606ef..a5d3775d74 100644 --- a/python/fate/arch/common/base_utils.py +++ b/python/fate/arch/common/base_utils.py @@ -24,11 +24,10 @@ import uuid from enum import Enum, IntEnum -from fate_arch.common.conf_utils import get_base_config -from fate_arch.common import BaseType +from ._types import BaseType +from .conf_utils import get_base_config - -use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False) +use_deserialize_safe_module = get_base_config("use_deserialize_safe_module", False) class CustomJSONEncoder(json.JSONEncoder): @@ -38,9 +37,9 @@ def __init__(self, **kwargs): def default(self, obj): if isinstance(obj, datetime.datetime): - return obj.strftime('%Y-%m-%d %H:%M:%S') + return obj.strftime("%Y-%m-%d %H:%M:%S") elif isinstance(obj, datetime.date): - return obj.strftime('%Y-%m-%d') + return obj.strftime("%Y-%m-%d") elif isinstance(obj, datetime.timedelta): return str(obj) elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): @@ -117,22 +116,18 @@ def deserialize_b64(src): return pickle.loads(src) -safe_module = { - 'federatedml', - 'numpy', - 'fate_flow' -} +safe_module = {"federatedml", "numpy", "fate_flow"} class RestrictedUnpickler(pickle.Unpickler): def find_class(self, module, name): import importlib - if module.split('.')[0] in safe_module: + + if module.split(".")[0] in safe_module: _module = importlib.import_module(module) return getattr(_module, name) # Forbid everything else. 
- raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) + raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name)) def restricted_loads(src): @@ -148,7 +143,12 @@ def get_lan_ip(): def get_interface_ip(ifname): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) return socket.inet_ntoa( - fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24]) + fcntl.ioctl( + s.fileno(), + 0x8915, + struct.pack("256s", string_to_bytes(ifname[:15])), + )[20:24] + ) ip = socket.gethostbyname(socket.getfqdn()) if ip.startswith("127.") and os.name != "nt": @@ -170,4 +170,4 @@ def get_interface_ip(ifname): break except IOError as e: pass - return ip or '' + return ip or "" diff --git a/python/fate/arch/common/conf_utils.py b/python/fate/arch/common/conf_utils.py index 916f4666e6..9894a8a9e9 100644 --- a/python/fate/arch/common/conf_utils.py +++ b/python/fate/arch/common/conf_utils.py @@ -14,10 +14,11 @@ # limitations under the License. # import os -from filelock import FileLock from importlib import import_module -from fate_arch.common import file_utils +from filelock import FileLock + +from .file_utils import get_project_base_directory, load_yaml_conf, rewrite_yaml_conf SERVICE_CONF = "service_conf.yaml" TRANSFER_CONF = "transfer_conf.yaml" @@ -25,15 +26,15 @@ def conf_realpath(conf_name): conf_path = f"conf/{conf_name}" - return os.path.join(file_utils.get_project_base_directory(), conf_path) + return os.path.join(get_project_base_directory(), conf_path) def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: local_config = {} - local_path = conf_realpath(f'local.{conf_name}') + local_path = conf_realpath(f"local.{conf_name}") if os.path.exists(local_path): - local_config = file_utils.load_yaml_conf(local_path) + local_config = load_yaml_conf(local_path) if not isinstance(local_config, dict): raise ValueError(f'Invalid config file: "{local_path}".') @@ -41,7 +42,7 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: return local_config[key] config_path = conf_realpath(conf_name) - config = file_utils.load_yaml_conf(config_path) + config = load_yaml_conf(config_path) if not isinstance(config, dict): raise ValueError(f'Invalid config file: "{config_path}".') @@ -78,9 +79,9 @@ def decrypt_database_config(database=None, passwd_key="passwd"): def update_config(key, value, conf_name=SERVICE_CONF): conf_path = conf_realpath(conf_name=conf_name) if not os.path.isabs(conf_path): - conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path) + conf_path = os.path.join(get_project_base_directory(), conf_path) with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): - config = file_utils.load_yaml_conf(conf_path=conf_path) or {} + config = load_yaml_conf(conf_path=conf_path) or {} config[key] = value - file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) + rewrite_yaml_conf(conf_path=conf_path, config=config) diff --git a/python/fate/arch/common/data_utils.py b/python/fate/arch/common/data_utils.py index 64c173362c..a5748d2989 100644 --- a/python/fate/arch/common/data_utils.py +++ b/python/fate/arch/common/data_utils.py @@ -1,30 +1,40 @@ import os import uuid -from fate_arch.common import file_utils -from fate_arch.storage import StorageEngine +from ..storage import StorageEngine +from .file_utils import get_project_base_directory def default_output_info(task_id, task_version, output_type): return f"output_{output_type}_{task_id}_{task_version}", uuid.uuid1().hex 
-def default_input_fs_path(name, namespace, prefix=None, storage_engine=StorageEngine.HDFS): +def default_input_fs_path( + name, namespace, prefix=None, storage_engine=StorageEngine.HDFS +): if storage_engine == StorageEngine.HDFS: - return default_hdfs_path(data_type="input", name=name, namespace=namespace, prefix=prefix) + return default_hdfs_path( + data_type="input", name=name, namespace=namespace, prefix=prefix + ) elif storage_engine == StorageEngine.LOCALFS: return default_localfs_path(data_type="input", name=name, namespace=namespace) -def default_output_fs_path(name, namespace, prefix=None, storage_engine=StorageEngine.HDFS): +def default_output_fs_path( + name, namespace, prefix=None, storage_engine=StorageEngine.HDFS +): if storage_engine == StorageEngine.HDFS: - return default_hdfs_path(data_type="output", name=name, namespace=namespace, prefix=prefix) + return default_hdfs_path( + data_type="output", name=name, namespace=namespace, prefix=prefix + ) elif storage_engine == StorageEngine.LOCALFS: return default_localfs_path(data_type="output", name=name, namespace=namespace) def default_localfs_path(name, namespace, data_type): - return os.path.join(file_utils.get_project_base_directory(), 'localfs', data_type, namespace, name) + return os.path.join( + get_project_base_directory(), "localfs", data_type, namespace, name + ) def default_hdfs_path(data_type, name, namespace, prefix=None): diff --git a/python/fate/arch/common/engine_utils.py b/python/fate/arch/common/engine_utils.py index 08d6932dd6..ff846f4301 100644 --- a/python/fate/arch/common/engine_utils.py +++ b/python/fate/arch/common/engine_utils.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import typing - -from fate_arch.common import FederatedMode, conf_utils -from fate_arch.computing import ComputingEngine -from fate_arch.federation import FederationEngine -from fate_arch.storage import StorageEngine -from fate_arch.relation_ship import Relationship -from fate_arch.common import EngineType +from ..common import EngineType, FederatedMode, conf_utils +from ..computing import ComputingEngine +from ..federation import FederationEngine +from ..relation_ship import Relationship +from ..storage import StorageEngine def get_engine_class_members(engine_class) -> list: @@ -49,8 +46,10 @@ def get_engines(): # computing engine if default_engines.get(EngineType.COMPUTING) is None: - raise RuntimeError(f"{EngineType.COMPUTING} is None," - f"Please check default_engines on conf/service_conf.yaml") + raise RuntimeError( + f"{EngineType.COMPUTING} is None," + f"Please check default_engines on conf/service_conf.yaml" + ) engines[EngineType.COMPUTING] = default_engines[EngineType.COMPUTING].upper() if engines[EngineType.COMPUTING] not in get_engine_class_members(ComputingEngine): raise RuntimeError(f"{engines[EngineType.COMPUTING]} is illegal") @@ -67,7 +66,9 @@ def get_engines(): for t in (EngineType.STORAGE, EngineType.FEDERATION): if engines.get(t) is None: # use default relation engine - engines[t] = Relationship.Computing[engines[EngineType.COMPUTING]][t]["default"] + engines[t] = Relationship.Computing[engines[EngineType.COMPUTING]][t][ + "default" + ] # set default federated mode by federation engine if engines[EngineType.FEDERATION] == FederationEngine.STANDALONE: @@ -82,26 +83,39 @@ def get_engines(): raise RuntimeError(f"{engines[EngineType.FEDERATION]} is illegal") for t in [EngineType.FEDERATION]: - if engines[t] not in 
Relationship.Computing[engines[EngineType.COMPUTING]][t]["support"]: - raise RuntimeError(f"{engines[t]} is not supported in {engines[EngineType.COMPUTING]}") + if ( + engines[t] + not in Relationship.Computing[engines[EngineType.COMPUTING]][t]["support"] + ): + raise RuntimeError( + f"{engines[t]} is not supported in {engines[EngineType.COMPUTING]}" + ) return engines def is_standalone(): - return get_engines().get(EngineType.FEDERATION).upper() == FederationEngine.STANDALONE + return ( + get_engines().get(EngineType.FEDERATION).upper() == FederationEngine.STANDALONE + ) def get_engines_config_from_conf(group_map=False): engines_config = {} engine_group_map = {} - for engine_type in {EngineType.COMPUTING, EngineType.FEDERATION, EngineType.STORAGE}: + for engine_type in { + EngineType.COMPUTING, + EngineType.FEDERATION, + EngineType.STORAGE, + }: engines_config[engine_type] = {} engine_group_map[engine_type] = {} for group_name, engine_map in Relationship.EngineConfMap.items(): for engine_type, name_maps in engine_map.items(): for name_map in name_maps: - single_engine_config = conf_utils.get_base_config(group_name, {}).get(name_map[1], {}) + single_engine_config = conf_utils.get_base_config(group_name, {}).get( + name_map[1], {} + ) if single_engine_config: engine_name = name_map[0] engines_config[engine_type][engine_name] = single_engine_config diff --git a/python/fate/arch/common/file_utils.py b/python/fate/arch/common/file_utils.py index 64ab0d20de..0c63e9c1ac 100644 --- a/python/fate/arch/common/file_utils.py +++ b/python/fate/arch/common/file_utils.py @@ -70,7 +70,9 @@ def get_fate_python_directory(*args): def get_federatedml_setting_conf_directory(): - return os.path.join(get_fate_python_directory(), 'federatedml', 'conf', 'setting_conf') + return os.path.join( + get_fate_python_directory(), "federatedml", "conf", "setting_conf" + ) @cached(cache=LRUCache(maxsize=10)) diff --git a/python/fate/arch/common/hdfs_utils.py b/python/fate/arch/common/hdfs_utils.py index e2e49711b5..1eb1d5cb43 100644 --- a/python/fate/arch/common/hdfs_utils.py +++ b/python/fate/arch/common/hdfs_utils.py @@ -16,8 +16,8 @@ import pickle -_DELIMITER = '\t' -NEWLINE = '\n' +_DELIMITER = "\t" +NEWLINE = "\n" def deserialize(m): diff --git a/python/fate/arch/common/hive_utils.py b/python/fate/arch/common/hive_utils.py index fd204e3e97..22aededcf7 100644 --- a/python/fate/arch/common/hive_utils.py +++ b/python/fate/arch/common/hive_utils.py @@ -15,10 +15,11 @@ # import pickle + from pyspark.sql import Row -_DELIMITER = ',' -NEWLINE = '\n' +_DELIMITER = "," +NEWLINE = "\n" def deserialize_line(line): @@ -26,12 +27,12 @@ def deserialize_line(line): def serialize_line(k, v): - return f'{_DELIMITER}'.join([k, pickle.dumps(v).hex()]) + f"{NEWLINE}" + return f"{_DELIMITER}".join([k, pickle.dumps(v).hex()]) + f"{NEWLINE}" def read_line(line_data): line = [str(i) for i in line_data] - return f'{_DELIMITER}'.join(line) + f"{NEWLINE}" + return f"{_DELIMITER}".join(line) + f"{NEWLINE}" def from_row(r): diff --git a/python/fate/arch/common/log.py b/python/fate/arch/common/log.py index 3e2a20c58c..76a4d7cb24 100644 --- a/python/fate/arch/common/log.py +++ b/python/fate/arch/common/log.py @@ -15,13 +15,13 @@ # import inspect -import traceback import logging import os +import traceback from logging.handlers import TimedRotatingFileHandler from threading import RLock -from fate_arch.common import file_utils +from .file_utils import get_project_base_directory class LoggerFactory(object): @@ -50,14 +50,16 @@ class 
LoggerFactory(object): schedule_logger_dict = {} @staticmethod - def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): + def set_directory( + directory=None, parent_log_dir=None, append_to_parent_log=None, force=False + ): if parent_log_dir: LoggerFactory.PARENT_LOG_DIR = parent_log_dir if append_to_parent_log: LoggerFactory.append_to_parent_log = append_to_parent_log with LoggerFactory.lock: if not directory: - directory = file_utils.get_project_base_directory("logs") + directory = get_project_base_directory("logs") if not LoggerFactory.LOG_DIR or force: LoggerFactory.LOG_DIR = directory if LoggerFactory.log_share: @@ -123,30 +125,36 @@ def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None return logging.StreamHandler() if not log_dir: - log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) + log_file = os.path.join( + LoggerFactory.LOG_DIR, "{}.log".format(class_name) + ) else: log_file = os.path.join(log_dir, "{}.log".format(class_name)) else: - log_file = os.path.join(log_dir, "fate_flow_{}.log".format( - log_type) if level == LoggerFactory.LEVEL else 'fate_flow_{}_error.log'.format(log_type)) + log_file = os.path.join( + log_dir, + "fate_flow_{}.log".format(log_type) + if level == LoggerFactory.LEVEL + else "fate_flow_{}_error.log".format(log_type), + ) job_id = job_id or os.getenv("FATE_JOB_ID") if job_id: - formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", job_id)) + formatter = logging.Formatter( + LoggerFactory.LOG_FORMAT.replace("jobId", job_id) + ) else: - formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", "Server")) + formatter = logging.Formatter( + LoggerFactory.LOG_FORMAT.replace("jobId", "Server") + ) os.makedirs(os.path.dirname(log_file), exist_ok=True) if LoggerFactory.log_share: - handler = ROpenHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) + handler = ROpenHandler( + log_file, when="D", interval=1, backupCount=14, delay=True + ) else: - handler = TimedRotatingFileHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) + handler = TimedRotatingFileHandler( + log_file, when="D", interval=1, backupCount=14, delay=True + ) if level: handler.level = level @@ -176,13 +184,18 @@ def assemble_global_handler(logger): for level in LoggerFactory.levels: if level >= LoggerFactory.LEVEL: level_logger_name = logging._levelToName[level] - logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) + logger.addHandler( + LoggerFactory.get_global_handler(level_logger_name, level) + ) if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: for level in LoggerFactory.levels: if level >= LoggerFactory.LEVEL: level_logger_name = logging._levelToName[level] logger.addHandler( - LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR)) + LoggerFactory.get_global_handler( + level_logger_name, level, LoggerFactory.PARENT_LOG_DIR + ) + ) def setDirectory(directory=None): @@ -197,7 +210,7 @@ def getLogger(className=None, useLevelFile=False): if className is None: frame = inspect.stack()[1] module = inspect.getmodule(frame[0]) - className = 'stat' + className = "stat" return LoggerFactory.get_logger(className) diff --git a/python/fate/arch/common/path_utils.py b/python/fate/arch/common/path_utils.py index 78cfd7cd16..9d80164a11 100644 --- a/python/fate/arch/common/path_utils.py +++ b/python/fate/arch/common/path_utils.py @@ -16,12 +16,12 @@ import os 
-from fate_arch.common import file_utils +from .file_utils import load_yaml_conf def get_data_table_count(path): config_path = os.path.join(path, "config.yaml") - config = file_utils.load_yaml_conf(conf_path=config_path) + config = load_yaml_conf(conf_path=config_path) count = 0 if config: if config.get("type") != "vision": diff --git a/python/fate/arch/common/profile.py b/python/fate/arch/common/profile.py index 845b364cd7..e9081c71ca 100644 --- a/python/fate/arch/common/profile.py +++ b/python/fate/arch/common/profile.py @@ -14,15 +14,15 @@ # limitations under the License. # import hashlib +import inspect import time import typing +from functools import wraps import beautifultable -from fate_arch.common.log import getLogger -import inspect -from functools import wraps -from fate_arch.abc import CTableABC +from ..abc import CTableABC +from .log import getLogger profile_logger = getLogger("PROFILING") _PROFILE_LOG_ENABLED = False @@ -36,7 +36,7 @@ def __init__(self): self.total_time = 0.0 self.max_time = 0.0 - def union(self, other: '_TimerItem'): + def union(self, other: "_TimerItem"): self.count += other.count self.total_time += other.total_time if self.max_time < other.max_time: @@ -78,12 +78,16 @@ def __init__(self, function_name: str, function_stack_list): self._start = time.time() function_stack = "\n".join(function_stack_list) - self._hash = hashlib.blake2b(function_stack.encode('utf-8'), digest_size=5).hexdigest() + self._hash = hashlib.blake2b( + function_stack.encode("utf-8"), digest_size=5 + ).hexdigest() if self._hash not in self._STATS: self._STATS[self._hash] = _ComputingTimerItem(function_name, function_stack) if _PROFILE_LOG_ENABLED: - profile_logger.debug(f"[computing#{self._hash}]function_stack: {' <-'.join(function_stack_list)}") + profile_logger.debug( + f"[computing#{self._hash}]function_stack: {' <-'.join(function_stack_list)}" + ) if _PROFILE_LOG_ENABLED: profile_logger.debug(f"[computing#{self._hash}]start") @@ -92,18 +96,30 @@ def done(self, function_string): elapse = time.time() - self._start self._STATS[self._hash].item.add(elapse) if _PROFILE_LOG_ENABLED: - profile_logger.debug(f"[computing#{self._hash}]done, elapse: {elapse}, function: {function_string}") + profile_logger.debug( + f"[computing#{self._hash}]done, elapse: {elapse}, function: {function_string}" + ) @classmethod def computing_statistics_table(cls, timer_aggregator: _TimerItem = None): - stack_table = beautifultable.BeautifulTable(110, precision=4, detect_numerics=False) - stack_table.columns.header = ["function", "n", "sum(s)", "mean(s)", "max(s)", "stack_hash", "stack"] + stack_table = beautifultable.BeautifulTable( + 110, precision=4, detect_numerics=False + ) + stack_table.columns.header = [ + "function", + "n", + "sum(s)", + "mean(s)", + "max(s)", + "stack_hash", + "stack", + ] stack_table.columns.alignment["stack"] = beautifultable.ALIGN_LEFT stack_table.columns.header.alignment = beautifultable.ALIGN_CENTER - stack_table.border.left = '' - stack_table.border.right = '' - stack_table.border.bottom = '' - stack_table.border.top = '' + stack_table.border.left = "" + stack_table.border.right = "" + stack_table.border.bottom = "" + stack_table.border.top = "" function_table = beautifultable.BeautifulTable(110) function_table.set_style(beautifultable.STYLE_COMPACT) @@ -112,7 +128,14 @@ def computing_statistics_table(cls, timer_aggregator: _TimerItem = None): aggregate = {} total = _TimerItem() for hash_id, timer in cls._STATS.items(): - stack_table.rows.append([timer.function_name, 
*timer.item.as_list(), hash_id, timer.function_stack]) + stack_table.rows.append( + [ + timer.function_name, + *timer.item.as_list(), + hash_id, + timer.function_stack, + ] + ) aggregate.setdefault(timer.function_name, _TimerItem()).union(timer.item) total.union(timer.item) @@ -148,20 +171,20 @@ def federation_statistics_table(cls, timer_aggregator: _TimerItem = None): get_table.rows.append([name, *item.as_list()]) total.union(item) get_table.rows.sort("sum(s)", reverse=True) - get_table.border.left = '' - get_table.border.right = '' - get_table.border.bottom = '' - get_table.border.top = '' + get_table.border.left = "" + get_table.border.right = "" + get_table.border.bottom = "" + get_table.border.top = "" remote_table = beautifultable.BeautifulTable(110) remote_table.columns.header = ["name", "n", "sum(s)", "mean(s)", "max(s)"] for name, item in cls._REMOTE_STATS.items(): remote_table.rows.append([name, *item.as_list()]) total.union(item) remote_table.rows.sort("sum(s)", reverse=True) - remote_table.border.left = '' - remote_table.border.right = '' - remote_table.border.bottom = '' - remote_table.border.top = '' + remote_table.border.left = "" + remote_table.border.right = "" + remote_table.border.bottom = "" + remote_table.border.top = "" base_table = beautifultable.BeautifulTable(120) base_table.rows.append(["get", get_table]) @@ -189,15 +212,19 @@ def __init__(self, name, full_name, tag, local, parties): def done(self, federation): self._end_time = time.time() self._REMOTE_STATS[self._full_name].add(self.elapse) - profile_logger.debug(f"[federation.remote.{self._full_name}.{self._tag}]" - f"{self._local_party}->{self._parties} done") + profile_logger.debug( + f"[federation.remote.{self._full_name}.{self._tag}]" + f"{self._local_party}->{self._parties} done" + ) if is_profile_remote_enable(): - federation.remote(v={"start_time": self._start_time, "end_time": self._end_time}, - name=self._name, - tag=profile_remote_tag(self._tag), - parties=self._parties, - gc=None) + federation.remote( + v={"start_time": self._start_time, "end_time": self._end_time}, + name=self._name, + tag=profile_remote_tag(self._tag), + parties=self._parties, + gc=None, + ) @property def elapse(self): @@ -220,15 +247,23 @@ def __init__(self, name, full_name, tag, local, parties): def done(self, federation): self._end_time = time.time() self._GET_STATS[self._full_name].add(self.elapse) - profile_logger.debug(f"[federation.get.{self._full_name}.{self._tag}]" - f"{self._local_party}<-{self._parties} done") + profile_logger.debug( + f"[federation.get.{self._full_name}.{self._tag}]" + f"{self._local_party}<-{self._parties} done" + ) if is_profile_remote_enable(): - remote_meta = federation.get(name=self._name, tag=profile_remote_tag(self._tag), parties=self._parties, - gc=None) + remote_meta = federation.get( + name=self._name, + tag=profile_remote_tag(self._tag), + parties=self._parties, + gc=None, + ) for party, meta in zip(self._parties, remote_meta): - profile_logger.debug(f"[federation.meta.{self._full_name}.{self._tag}]{self._local_party}<-{party}]" - f"meta={meta}") + profile_logger.debug( + f"[federation.meta.{self._full_name}.{self._tag}]{self._local_party}<-{party}]" + f"meta={meta}" + ) @property def elapse(self): @@ -236,7 +271,9 @@ def elapse(self): def federation_remote_timer(name, full_name, tag, local, parties): - profile_logger.debug(f"[federation.remote.{full_name}.{tag}]{local}->{parties} start") + profile_logger.debug( + f"[federation.remote.{full_name}.{tag}]{local}->{parties} start" + ) return 
_FederationRemoteTimer(name, full_name, tag, local, parties) @@ -262,9 +299,15 @@ def profile_ends(): timer_aggregator = _TimerItem() computing_timer_aggregator = _TimerItem() federation_timer_aggregator = _TimerItem() - computing_base_table, computing_detailed_table = _ComputingTimer.computing_statistics_table( - timer_aggregator=computing_timer_aggregator) - federation_base_table = _FederationTimer.federation_statistics_table(timer_aggregator=federation_timer_aggregator) + ( + computing_base_table, + computing_detailed_table, + ) = _ComputingTimer.computing_statistics_table( + timer_aggregator=computing_timer_aggregator + ) + federation_base_table = _FederationTimer.federation_statistics_table( + timer_aggregator=federation_timer_aggregator + ) timer_aggregator.union(computing_timer_aggregator) timer_aggregator.union(federation_timer_aggregator) @@ -278,10 +321,12 @@ def profile_ends(): federation_timer_aggregator.total_time, federation_timer_aggregator.total_time / profile_total_time, computing_timer_aggregator.total_time, - computing_timer_aggregator.total_time / profile_total_time + computing_timer_aggregator.total_time / profile_total_time, ) ) - profile_logger.info(f"\nComputing:\n{computing_base_table}\n\nFederation:\n{federation_base_table}\n") + profile_logger.info( + f"\nComputing:\n{computing_base_table}\n\nFederation:\n{federation_base_table}\n" + ) profile_logger.debug(f"\nDetailed Computing:\n{computing_detailed_table}\n") global _PROFILE_LOG_ENABLED @@ -306,7 +351,9 @@ def _call_stack_strings(): call_stack_strings = [] frames = inspect.getouterframes(inspect.currentframe(), 10)[2:-2] for frame in frames: - call_stack_strings.append(f"[{frame.filename.split('/')[-1]}:{frame.lineno}]{frame.function}") + call_stack_strings.append( + f"[{frame.filename.split('/')[-1]}:{frame.lineno}]{frame.function}" + ) return call_stack_strings diff --git a/python/fate/arch/common/remote_status.py b/python/fate/arch/common/remote_status.py index 413b8f3fbb..5bacbda110 100644 --- a/python/fate/arch/common/remote_status.py +++ b/python/fate/arch/common/remote_status.py @@ -18,7 +18,7 @@ import concurrent.futures import typing -from fate_arch.common.log import getLogger +from .log import getLogger LOGGER = getLogger() diff --git a/python/fate/arch/common/string_utils.py b/python/fate/arch/common/string_utils.py index 049e3b4ee8..e1bdbd236b 100644 --- a/python/fate/arch/common/string_utils.py +++ b/python/fate/arch/common/string_utils.py @@ -20,9 +20,9 @@ def random_string(string_length=6): letters = string.ascii_lowercase - return ''.join(random.choice(letters) for _ in range(string_length)) + return "".join(random.choice(letters) for _ in range(string_length)) def random_number_string(string_length=6): letters = string.octdigits - return ''.join(random.choice(letters) for _ in range(string_length)) + return "".join(random.choice(letters) for _ in range(string_length)) diff --git a/python/fate/arch/common/versions.py b/python/fate/arch/common/versions.py index 8730cc2904..709650d30d 100644 --- a/python/fate/arch/common/versions.py +++ b/python/fate/arch/common/versions.py @@ -14,11 +14,11 @@ # limitations under the License. 
# import os +import typing import dotenv -import typing -from fate_arch.common.file_utils import get_project_base_directory +from .file_utils import get_project_base_directory def get_versions() -> typing.Mapping[str, typing.Any]: diff --git a/python/fate/arch/computing/__init__.py b/python/fate/arch/computing/__init__.py index 8f217933c0..ac0b30f9f9 100644 --- a/python/fate/arch/computing/__init__.py +++ b/python/fate/arch/computing/__init__.py @@ -15,7 +15,7 @@ # -from fate_arch.computing._type import ComputingEngine -from fate_arch.computing._util import is_table +from ._type import ComputingEngine +from ._util import is_table -__all__ = ['is_table', 'ComputingEngine'] +__all__ = ["is_table", "ComputingEngine"] diff --git a/python/fate/arch/computing/_type.py b/python/fate/arch/computing/_type.py index 31e0f2143a..16d48ac831 100644 --- a/python/fate/arch/computing/_type.py +++ b/python/fate/arch/computing/_type.py @@ -16,7 +16,7 @@ class ComputingEngine(object): - EGGROLL = 'EGGROLL' - SPARK = 'SPARK' - LINKIS_SPARK = 'LINKIS_SPARK' - STANDALONE = 'STANDALONE' + EGGROLL = "EGGROLL" + SPARK = "SPARK" + LINKIS_SPARK = "LINKIS_SPARK" + STANDALONE = "STANDALONE" diff --git a/python/fate/arch/computing/_util.py b/python/fate/arch/computing/_util.py index 7955f7f8b2..f315073b0a 100644 --- a/python/fate/arch/computing/_util.py +++ b/python/fate/arch/computing/_util.py @@ -15,7 +15,7 @@ # -from fate_arch.abc import CTableABC +from ..abc import CTableABC def is_table(v): diff --git a/python/fate/arch/computing/eggroll/__init__.py b/python/fate/arch/computing/eggroll/__init__.py index c65fbbe1a9..b37a08c0a0 100644 --- a/python/fate/arch/computing/eggroll/__init__.py +++ b/python/fate/arch/computing/eggroll/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. 
# -from fate_arch.computing.eggroll._table import Table -from fate_arch.computing.eggroll._csession import CSession +from ._csession import CSession +from ._table import Table -__all__ = ['Table', 'CSession'] +__all__ = ["Table", "CSession"] diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/eggroll/_csession.py index 22b698ad37..776a31405e 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/eggroll/_csession.py @@ -18,11 +18,11 @@ from eggroll.core.session import session_init from eggroll.roll_pair.roll_pair import runtime_init -from fate_arch.abc import AddressABC, CSessionABC -from fate_arch.common.base_utils import fate_uuid -from fate_arch.common.log import getLogger -from fate_arch.common.profile import computing_profile -from fate_arch.computing.eggroll import Table +from ...abc import AddressABC, CSessionABC +from ...common.base_utils import fate_uuid +from ...common.log import getLogger +from ...common.profile import computing_profile +from ...computing.eggroll import Table LOGGER = getLogger() @@ -49,21 +49,21 @@ def session_id(self): @computing_profile def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs): - from fate_arch.common.address import EggRollAddress - from fate_arch.storage import EggRollStoreType + from ...common.address import EggRollAddress + from ...storage import EggRollStoreType if isinstance(address, EggRollAddress): options = kwargs.get("option", {}) options["total_partitions"] = partitions - options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB) + options["store_type"] = kwargs.get( + "store_type", EggRollStoreType.ROLLPAIR_LMDB + ) options["create_if_missing"] = False rp = self._rpc.load( namespace=address.namespace, name=address.name, options=options ) if rp is None or rp.get_partitions() == 0: - raise RuntimeError( - f"no exists: {address.name}, {address.namespace}" - ) + raise RuntimeError(f"no exists: {address.name}, {address.namespace}") if options["store_type"] != EggRollStoreType.ROLLPAIR_IN_MEMORY: rp = rp.save_as( @@ -77,11 +77,12 @@ def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs): table.schema = schema return table - from fate_arch.common.address import PathAddress + from ...common.address import PathAddress if isinstance(address, PathAddress): - from fate_arch.computing.non_distributed import LocalData - from fate_arch.computing import ComputingEngine + from ...computing import ComputingEngine + from ...computing.non_distributed import LocalData + return LocalData(address.path, engine=ComputingEngine.EGGROLL) raise NotImplementedError( @@ -115,5 +116,7 @@ def destroy(self): try: self.stop() except Exception as e: - LOGGER.warning(f"stop storage session {self.session_id} failed, try to kill", e) + LOGGER.warning( + f"stop storage session {self.session_id} failed, try to kill", e + ) self.kill() diff --git a/python/fate/arch/computing/eggroll/_table.py b/python/fate/arch/computing/eggroll/_table.py index a33af09302..36c03cd675 100644 --- a/python/fate/arch/computing/eggroll/_table.py +++ b/python/fate/arch/computing/eggroll/_table.py @@ -17,16 +17,15 @@ import typing -from fate_arch.abc import CTableABC -from fate_arch.common import log -from fate_arch.common.profile import computing_profile -from fate_arch.computing._type import ComputingEngine +from ...abc import CTableABC +from ...common import log +from ...common.profile import computing_profile +from .._type import ComputingEngine LOGGER = 
log.getLogger() class Table(CTableABC): - def __init__(self, rp): self._rp = rp self._engine = ComputingEngine.EGGROLL @@ -47,20 +46,32 @@ def copy(self): @computing_profile def save(self, address, partitions, schema: dict, **kwargs): options = kwargs.get("options", {}) - from fate_arch.common.address import EggRollAddress - from fate_arch.storage import EggRollStoreType + from ...common.address import EggRollAddress + from ...storage import EggRollStoreType + if isinstance(address, EggRollAddress): - options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB) - self._rp.save_as(name=address.name, namespace=address.namespace, partition=partitions, options=options) + options["store_type"] = kwargs.get( + "store_type", EggRollStoreType.ROLLPAIR_LMDB + ) + self._rp.save_as( + name=address.name, + namespace=address.namespace, + partition=partitions, + options=options, + ) schema.update(self.schema) return - from fate_arch.common.address import PathAddress + from ...common.address import PathAddress + if isinstance(address, PathAddress): - from fate_arch.computing.non_distributed import LocalData + from ...computing.non_distributed import LocalData + return LocalData(address.path) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend") + raise NotImplementedError( + f"address type {type(address)} not supported with eggroll backend" + ) @computing_profile def collect(self, **kwargs) -> list: @@ -95,14 +106,22 @@ def applyPartitions(self, func): return Table(self._rp.collapse_partitions(func)) @computing_profile - def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs): + def mapPartitions( + self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs + ): if use_previous_behavior is True: - LOGGER.warning(f"please use `applyPartitions` instead of `mapPartitions` " - f"if the previous behavior was expected. " - f"The previous behavior will not work in future") + LOGGER.warning( + f"please use `applyPartitions` instead of `mapPartitions` " + f"if the previous behavior was expected. 
" + f"The previous behavior will not work in future" + ) return self.applyPartitions(func) - return Table(self._rp.map_partitions(func, options={"shuffle": not preserves_partitioning})) + return Table( + self._rp.map_partitions( + func, options={"shuffle": not preserves_partitioning} + ) + ) @computing_profile def mapReducePartitions(self, mapper, reducer, **kwargs): @@ -110,14 +129,18 @@ def mapReducePartitions(self, mapper, reducer, **kwargs): @computing_profile def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs): - return Table(self._rp.map_partitions_with_index(func, options={"shuffle": not preserves_partitioning})) + return Table( + self._rp.map_partitions_with_index( + func, options={"shuffle": not preserves_partitioning} + ) + ) @computing_profile def reduce(self, func, **kwargs): return self._rp.reduce(func) @computing_profile - def join(self, other: 'Table', func, **kwargs): + def join(self, other: "Table", func, **kwargs): return Table(self._rp.join(other._rp, func=func)) @computing_profile @@ -125,14 +148,22 @@ def glom(self, **kwargs): return Table(self._rp.glom()) @computing_profile - def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None): + def sample( + self, + *, + fraction: typing.Optional[float] = None, + num: typing.Optional[int] = None, + seed=None, + ): if fraction is not None: return Table(self._rp.sample(fraction=fraction, seed=seed)) if num is not None: total = self._rp.count() if num > total: - raise ValueError(f"not enough data to sample, own {total} but required {num}") + raise ValueError( + f"not enough data to sample, own {total} but required {num}" + ) frac = num / float(total) while True: @@ -150,10 +181,12 @@ def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optiona return Table(sampled_table) - raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}") + raise ValueError( + f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}" + ) @computing_profile - def subtractByKey(self, other: 'Table', **kwargs): + def subtractByKey(self, other: "Table", **kwargs): return Table(self._rp.subtract_by_key(other._rp)) @computing_profile @@ -161,7 +194,7 @@ def filter(self, func, **kwargs): return Table(self._rp.filter(func)) @computing_profile - def union(self, other: 'Table', func=lambda v1, v2: v1, **kwargs): + def union(self, other: "Table", func=lambda v1, v2: v1, **kwargs): return Table(self._rp.union(other._rp, func=func)) @computing_profile diff --git a/python/fate/arch/computing/eggroll/_table.pyi b/python/fate/arch/computing/eggroll/_table.pyi index d0b8752e36..46dc13eef3 100644 --- a/python/fate/arch/computing/eggroll/_table.pyi +++ b/python/fate/arch/computing/eggroll/_table.pyi @@ -14,14 +14,12 @@ # limitations under the License. # from eggroll.roll_pair.roll_pair import RollPair -from fate_arch.abc import AddressABC, CTableABC +from ...abc import AddressABC, CTableABC # noinspection PyAbstractClass class Table(CTableABC): - def __init__(self, rp: RollPair): self._rp: RollPair = ... ... - def save(self, address: AddressABC, partitions: int, schema: dict, **kwargs): ... 
diff --git a/python/fate/arch/computing/non_distributed.py b/python/fate/arch/computing/non_distributed.py index ecd8913bf2..28ee93c6f6 100644 --- a/python/fate/arch/computing/non_distributed.py +++ b/python/fate/arch/computing/non_distributed.py @@ -15,7 +15,7 @@ # -class LocalData(): +class LocalData: def __init__(self, path, engine=None): self.path = path self.schema = {"header": [], "sid_name": "id"} diff --git a/python/fate/arch/computing/spark/__init__.py b/python/fate/arch/computing/spark/__init__.py index 830a8cc6f7..0936508c78 100644 --- a/python/fate/arch/computing/spark/__init__.py +++ b/python/fate/arch/computing/spark/__init__.py @@ -14,9 +14,17 @@ # limitations under the License. # -from fate_arch.computing.spark._csession import CSession -from fate_arch.computing.spark._table import Table, from_hdfs, from_rdd, from_hive, from_localfs -from fate_arch.computing.spark._materialize import get_storage_level, materialize +from ._csession import CSession +from ._materialize import get_storage_level, materialize +from ._table import Table, from_hdfs, from_hive, from_localfs, from_rdd -__all__ = ['Table', 'CSession', 'from_hdfs', 'from_hive', 'from_localfs', 'from_rdd', - 'get_storage_level', 'materialize'] +__all__ = [ + "Table", + "CSession", + "from_hdfs", + "from_hive", + "from_localfs", + "from_rdd", + "get_storage_level", + "materialize", +] diff --git a/python/fate/arch/computing/spark/_csession.py b/python/fate/arch/computing/spark/_csession.py index 800847c552..91a301a3b0 100644 --- a/python/fate/arch/computing/spark/_csession.py +++ b/python/fate/arch/computing/spark/_csession.py @@ -16,11 +16,10 @@ from typing import Iterable -from fate_arch.abc import AddressABC -from fate_arch.abc import CSessionABC -from fate_arch.common.address import LocalFSAddress -from fate_arch.computing.spark._table import from_hdfs, from_rdd, from_hive, from_localfs -from fate_arch.common import log +from ...abc import AddressABC, CSessionABC +from ...common import log +from ...common.address import LocalFSAddress +from ._table import from_hdfs, from_hive, from_localfs, from_rdd LOGGER = log.getLogger() @@ -34,27 +33,27 @@ def __init__(self, session_id): self._session_id = session_id def load(self, address: AddressABC, partitions, schema, **kwargs): - from fate_arch.common.address import HDFSAddress + from ...common.address import HDFSAddress + if isinstance(address, HDFSAddress): table = from_hdfs( paths=f"{address.name_node}/{address.path}", partitions=partitions, - in_serialized=kwargs.get( - "in_serialized", - True), - id_delimiter=kwargs.get( - "id_delimiter", - ',')) + in_serialized=kwargs.get("in_serialized", True), + id_delimiter=kwargs.get("id_delimiter", ","), + ) table.schema = schema return table - from fate_arch.common.address import PathAddress + from ...common.address import PathAddress + if isinstance(address, PathAddress): - from fate_arch.computing.non_distributed import LocalData - from fate_arch.computing import ComputingEngine + from ...computing import ComputingEngine + from ...computing.non_distributed import LocalData + return LocalData(address.path, engine=ComputingEngine.SPARK) - from fate_arch.common.address import HiveAddress, LinkisHiveAddress + from ...common.address import HiveAddress, LinkisHiveAddress if isinstance(address, (HiveAddress, LinkisHiveAddress)): table = from_hive( @@ -67,9 +66,11 @@ def load(self, address: AddressABC, partitions, schema, **kwargs): if isinstance(address, LocalFSAddress): table = from_localfs( - paths=address.path, 
partitions=partitions, in_serialized=kwargs.get( - "in_serialized", True), id_delimiter=kwargs.get( - "id_delimiter", ',')) + paths=address.path, + partitions=partitions, + in_serialized=kwargs.get("in_serialized", True), + id_delimiter=kwargs.get("id_delimiter", ","), + ) table.schema = schema return table @@ -80,6 +81,7 @@ def load(self, address: AddressABC, partitions, schema, **kwargs): def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs): # noinspection PyPackageRequirements from pyspark import SparkContext + _iter = data if include_key else enumerate(data) rdd = SparkContext.getOrCreate().parallelize(_iter, partition) return from_rdd(rdd) diff --git a/python/fate/arch/computing/spark/_table.py b/python/fate/arch/computing/spark/_table.py index 0a25c34431..cf90e2e43b 100644 --- a/python/fate/arch/computing/spark/_table.py +++ b/python/fate/arch/computing/spark/_table.py @@ -14,20 +14,19 @@ # limitations under the License. # +import typing import uuid from itertools import chain -import typing import pyspark - from pyspark.rddsampler import RDDSamplerBase - -from fate_arch.abc import CTableABC -from fate_arch.common import log, hdfs_utils, hive_utils -from fate_arch.common.profile import computing_profile -from fate_arch.computing.spark._materialize import materialize, unmaterialize from scipy.stats import hypergeom -from fate_arch.computing._type import ComputingEngine + +from ...abc import CTableABC +from ...common import hdfs_utils, hive_utils, log +from ...common.profile import computing_profile +from .._type import ComputingEngine +from ._materialize import materialize, unmaterialize LOGGER = log.getLogger() @@ -59,7 +58,7 @@ def copy(self): @computing_profile def save(self, address, partitions, schema, **kwargs): - from fate_arch.common.address import HDFSAddress + from ...common.address import HDFSAddress if isinstance(address, HDFSAddress): self._rdd.map(lambda x: hdfs_utils.serialize(x[0], x[1])).repartition( @@ -68,7 +67,7 @@ def save(self, address, partitions, schema, **kwargs): schema.update(self.schema) return - from fate_arch.common.address import HiveAddress, LinkisHiveAddress + from ...common.address import HiveAddress, LinkisHiveAddress if isinstance(address, (HiveAddress, LinkisHiveAddress)): # df = ( @@ -77,12 +76,14 @@ def save(self, address, partitions, schema, **kwargs): # .toDF() # ) LOGGER.debug(f"partitions: {partitions}") - _repartition = self._rdd.map(lambda x: hive_utils.to_row(x[0], x[1])).repartition(partitions) + _repartition = self._rdd.map( + lambda x: hive_utils.to_row(x[0], x[1]) + ).repartition(partitions) _repartition.toDF().write.saveAsTable(f"{address.database}.{address.name}") schema.update(self.schema) return - from fate_arch.common.address import LocalFSAddress + from ...common.address import LocalFSAddress if isinstance(address, LocalFSAddress): self._rdd.map(lambda x: hdfs_utils.serialize(x[0], x[1])).repartition( @@ -133,7 +134,9 @@ def applyPartitions(self, func, **kwargs): @computing_profile def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs): return from_rdd( - self._rdd.mapPartitionsWithIndex(func, preservesPartitioning=preserves_partitioning) + self._rdd.mapPartitionsWithIndex( + func, preservesPartitioning=preserves_partitioning + ) ) @computing_profile @@ -212,13 +215,12 @@ def from_hdfs(paths: str, partitions, in_serialized=True, id_delimiter=None): from pyspark import SparkContext sc = SparkContext.getOrCreate() - fun = hdfs_utils.deserialize if in_serialized else lambda x: 
(x.partition(id_delimiter)[0], - x.partition(id_delimiter)[2]) - rdd = materialize( - sc.textFile(paths, partitions) - .map(fun) - .repartition(partitions) + fun = ( + hdfs_utils.deserialize + if in_serialized + else lambda x: (x.partition(id_delimiter)[0], x.partition(id_delimiter)[2]) ) + rdd = materialize(sc.textFile(paths, partitions).map(fun).repartition(partitions)) return Table(rdd=rdd) @@ -227,13 +229,12 @@ def from_localfs(paths: str, partitions, in_serialized=True, id_delimiter=None): from pyspark import SparkContext sc = SparkContext.getOrCreate() - fun = hdfs_utils.deserialize if in_serialized else lambda x: (x.partition(id_delimiter)[0], - x.partition(id_delimiter)[2]) - rdd = materialize( - sc.textFile(paths, partitions) - .map(fun) - .repartition(partitions) + fun = ( + hdfs_utils.deserialize + if in_serialized + else lambda x: (x.partition(id_delimiter)[0], x.partition(id_delimiter)[2]) ) + rdd = materialize(sc.textFile(paths, partitions).map(fun).repartition(partitions)) return Table(rdd=rdd) diff --git a/python/fate/arch/computing/spark/_table.pyi b/python/fate/arch/computing/spark/_table.pyi index ae3add6738..4f8d4ea56b 100644 --- a/python/fate/arch/computing/spark/_table.pyi +++ b/python/fate/arch/computing/spark/_table.pyi @@ -15,20 +15,16 @@ # from pyspark import RDD -from fate_arch.abc import AddressABC, CTableABC - +from ...abc import AddressABC, CTableABC # noinspection PyAbstractClass class Table(CTableABC): - def __init__(self, rdd: RDD): self._rdd: RDD = ... ... - def save(self, address: AddressABC, partitions: int, schema: dict, **kwargs): ... - def from_hdfs(paths: str, partitions, in_serialized, id_delimiter) -> Table: ... def from_hive(tb_name: str, db_name: str, partitions: int) -> Table: ... def from_rdd(rdd) -> Table: ... -def from_localfs(paths: str, partitions, in_serialized, id_delimiter) -> Table: ... \ No newline at end of file +def from_localfs(paths: str, partitions, in_serialized, id_delimiter) -> Table: ... diff --git a/python/fate/arch/computing/standalone/__init__.py b/python/fate/arch/computing/standalone/__init__.py index 27e7b781b0..b37a08c0a0 100644 --- a/python/fate/arch/computing/standalone/__init__.py +++ b/python/fate/arch/computing/standalone/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. 
# -from fate_arch.computing.standalone._csession import CSession -from fate_arch.computing.standalone._table import Table +from ._csession import CSession +from ._table import Table -__all__ = ['Table', 'CSession'] +__all__ = ["Table", "CSession"] diff --git a/python/fate/arch/computing/standalone/_csession.py b/python/fate/arch/computing/standalone/_csession.py index deb0c21501..1f9310d9de 100644 --- a/python/fate/arch/computing/standalone/_csession.py +++ b/python/fate/arch/computing/standalone/_csession.py @@ -15,11 +15,11 @@ # from collections import Iterable -from fate_arch._standalone import Session -from fate_arch.abc import AddressABC, CSessionABC -from fate_arch.common.base_utils import fate_uuid -from fate_arch.common.log import getLogger -from fate_arch.computing.standalone._table import Table +from ..._standalone import Session +from ...abc import AddressABC, CSessionABC +from ...common.base_utils import fate_uuid +from ...common.log import getLogger +from ._table import Table LOGGER = getLogger() @@ -38,8 +38,8 @@ def session_id(self): return self._session.session_id def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs): - from fate_arch.common.address import StandaloneAddress - from fate_arch.storage import StandaloneStoreType + from ...common.address import StandaloneAddress + from ...storage import StandaloneStoreType if isinstance(address, StandaloneAddress): raw_table = self._session.load(address.name, address.namespace) @@ -54,11 +54,12 @@ def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs): table.schema = schema return table - from fate_arch.common.address import PathAddress + from ...common.address import PathAddress if isinstance(address, PathAddress): - from fate_arch.computing.non_distributed import LocalData - from fate_arch.computing import ComputingEngine + from ...computing import ComputingEngine + from ...computing.non_distributed import LocalData + return LocalData(address.path, engine=ComputingEngine.STANDALONE) raise NotImplementedError( f"address type {type(address)} not supported with standalone backend" @@ -89,5 +90,7 @@ def destroy(self): try: self.stop() except Exception as e: - LOGGER.warning(f"stop storage session {self.session_id} failed, try to kill", e) + LOGGER.warning( + f"stop storage session {self.session_id} failed, try to kill", e + ) self.kill() diff --git a/python/fate/arch/computing/standalone/_table.py b/python/fate/arch/computing/standalone/_table.py index 77f5967be1..4f36330e5d 100644 --- a/python/fate/arch/computing/standalone/_table.py +++ b/python/fate/arch/computing/standalone/_table.py @@ -17,10 +17,10 @@ import itertools import typing -from fate_arch.abc import CTableABC -from fate_arch.common import log -from fate_arch.common.profile import computing_profile -from fate_arch.computing._type import ComputingEngine +from ...abc import CTableABC +from ...common import log +from ...common.profile import computing_profile +from .._type import ComputingEngine LOGGER = log.getLogger() @@ -48,7 +48,7 @@ def copy(self): @computing_profile def save(self, address, partitions, schema, **kwargs): - from fate_arch.common.address import StandaloneAddress + from ...common.address import StandaloneAddress if isinstance(address, StandaloneAddress): self._table.save_as( @@ -60,10 +60,10 @@ def save(self, address, partitions, schema, **kwargs): schema.update(self.schema) return - from fate_arch.common.address import PathAddress + from ...common.address import PathAddress if isinstance(address, 
PathAddress): - from fate_arch.computing.non_distributed import LocalData + from ...computing.non_distributed import LocalData return LocalData(address.path) raise NotImplementedError( diff --git a/python/fate/arch/computing/standalone/_table.pyi b/python/fate/arch/computing/standalone/_table.pyi index 2f46b30694..3335ccbed7 100644 --- a/python/fate/arch/computing/standalone/_table.pyi +++ b/python/fate/arch/computing/standalone/_table.pyi @@ -14,15 +14,12 @@ # limitations under the License. # - -from fate_arch._standalone import Table as StandaloneTable -from fate_arch.abc import AddressABC, CTableABC - +from ..._standalone import Table as StandaloneTable +from ...abc import AddressABC, CTableABC # noinspection PyAbstractClass class Table(CTableABC): def __init__(self, table: StandaloneTable): self._table = table ... - def save(self, address: AddressABC, partitions: int, schema: dict, **kwargs): ... diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index 6c07206141..0f6a322059 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -1,7 +1,3 @@ -from fate_arch.federation._type import FederationEngine -from fate_arch.federation._type import FederationDataType +from ._type import FederationDataType, FederationEngine -__all__ = [ - "FederationEngine", - "FederationDataType" -] +__all__ = ["FederationEngine", "FederationDataType"] diff --git a/python/fate/arch/federation/_datastream.py b/python/fate/arch/federation/_datastream.py index 97ff301a80..f3861b119d 100644 --- a/python/fate/arch/federation/_datastream.py +++ b/python/fate/arch/federation/_datastream.py @@ -16,8 +16,8 @@ import io -import sys import json +import sys # Datastream is a wraper of StringIO, it receives kv pairs and dump it to json string diff --git a/python/fate/arch/federation/_federation.py b/python/fate/arch/federation/_federation.py index 833a32520f..67f824b120 100644 --- a/python/fate/arch/federation/_federation.py +++ b/python/fate/arch/federation/_federation.py @@ -18,15 +18,15 @@ import json import sys import typing -from pickle import dumps as p_dumps, loads as p_loads +from pickle import dumps as p_dumps +from pickle import loads as p_loads -from fate_arch.abc import CTableABC -from fate_arch.abc import FederationABC, GarbageCollectionABC -from fate_arch.common import Party -from fate_arch.common.log import getLogger -from fate_arch.federation import FederationDataType -from fate_arch.federation._datastream import Datastream -from fate_arch.session import computing_session +from ..abc import CTableABC, FederationABC, GarbageCollectionABC +from ..common import Party +from ..common.log import getLogger +from ..federation import FederationDataType +from ..federation._datastream import Datastream +from ..session import computing_session LOGGER = getLogger() @@ -42,28 +42,21 @@ def _get_splits(obj, max_message_size): return obj, num_slice else: _max_size = max_message_size - kv = [(i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)] + kv = [ + (i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) + for i in range(num_slice) + ] return kv, num_slice class FederationBase(FederationABC): @staticmethod def from_conf( - federation_session_id: str, - party: Party, - runtime_conf: dict, - **kwargs + federation_session_id: str, party: Party, runtime_conf: dict, **kwargs ): raise NotImplementedError() - def __init__( - self, - session_id, - party: Party, - mq, - max_message_size, - conf=None - ): + 
def __init__(self, session_id, party: Party, mq, max_message_size, conf=None): self._session_id = session_id self._party = party self._mq = mq @@ -85,7 +78,7 @@ def destroy(self, parties): raise NotImplementedError() def get( - self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC + self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC ) -> typing.List: log_str = f"[federation.get](name={name}, tag={tag}, parties={parties})" LOGGER.debug(f"[{log_str}]start to get") @@ -96,7 +89,9 @@ def get( ] if _name_dtype_keys[0] not in self._name_dtype_map: - party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG) + party_topic_infos = self._get_party_topic_infos( + parties, dtype=NAME_DTYPE_TAG + ) channel_infos = self._get_channels(party_topic_infos=party_topic_infos) rtn_dtype = [] for i, info in enumerate(channel_infos): @@ -118,8 +113,13 @@ def get( dtype = rtn_dtype.get("dtype", None) partitions = rtn_dtype.get("partitions", None) - if dtype == FederationDataType.TABLE or dtype == FederationDataType.SPLIT_OBJECT: - party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions) + if ( + dtype == FederationDataType.TABLE + or dtype == FederationDataType.SPLIT_OBJECT + ): + party_topic_infos = self._get_party_topic_infos( + parties, name, partitions=partitions + ) for i in range(len(party_topic_infos)): party = parties[i] role = party.role @@ -134,10 +134,12 @@ def get( dst_role=role, topic_infos=topic_infos, mq=self._mq, - conf=self._conf + conf=self._conf, ) - table = computing_session.parallelize(range(partitions), partitions, include_key=False) + table = computing_session.parallelize( + range(partitions), partitions, include_key=False + ) table = table.mapPartitionsWithIndex(receive_func) # add gc @@ -149,7 +151,9 @@ def get( if dtype == FederationDataType.TABLE: rtn.append(table) else: - obj_bytes = b''.join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0]))) + obj_bytes = b"".join( + map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])) + ) obj = p_loads(obj_bytes) rtn.append(obj) else: @@ -166,12 +170,12 @@ def get( return rtn def remote( - self, - v, - name: str, - tag: str, - parties: typing.List[Party], - gc: GarbageCollectionABC, + self, + v, + name: str, + tag: str, + parties: typing.List[Party], + gc: GarbageCollectionABC, ) -> typing.NoReturn: log_str = f"[federation.remote](name={name}, tag={tag}, parties={parties})" @@ -181,14 +185,21 @@ def remote( ] if _name_dtype_keys[0] not in self._name_dtype_map: - party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG) + party_topic_infos = self._get_party_topic_infos( + parties, dtype=NAME_DTYPE_TAG + ) channel_infos = self._get_channels(party_topic_infos=party_topic_infos) if not isinstance(v, CTableABC): v, num_slice = _get_splits(v, self._max_message_size) if num_slice > 1: - v = computing_session.parallelize(data=v, partition=1, include_key=True) - body = {"dtype": FederationDataType.SPLIT_OBJECT, "partitions": v.partitions} + v = computing_session.parallelize( + data=v, partition=1, include_key=True + ) + body = { + "dtype": FederationDataType.SPLIT_OBJECT, + "partitions": v.partitions, + } else: body = {"dtype": FederationDataType.OBJECT} @@ -216,7 +227,9 @@ def remote( f"[{log_str}]start to remote table, total_size={total_size}, partitions={partitions}" ) - party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions) + party_topic_infos = 
self._get_party_topic_infos( + parties, name, partitions=partitions + ) # add gc gc.add_gc_action(tag, v, "__del__", {}) @@ -229,7 +242,7 @@ def remote( src_role=self._party.role, mq=self._mq, max_message_size=self._max_message_size, - conf=self._conf + conf=self._conf, ) # noinspection PyProtectedMember v.mapPartitionsWithIndex(send_func) @@ -244,7 +257,7 @@ def remote( LOGGER.debug(f"[{log_str}]finish to remote") def _get_party_topic_infos( - self, parties: typing.List[Party], name=None, partitions=None, dtype=None + self, parties: typing.List[Party], name=None, partitions=None, dtype=None ) -> typing.List: topic_infos = [ self._get_or_create_topic(party, name, partitions, dtype) @@ -257,20 +270,18 @@ def _maybe_create_topic_and_replication(self, party, topic_suffix): raise NotImplementedError() def _get_or_create_topic( - self, party: Party, name=None, partitions=None, dtype=None + self, party: Party, name=None, partitions=None, dtype=None ) -> typing.Tuple: topic_key_list = [] topic_infos = [] if dtype is not None: - topic_key = _SPLIT_.join( - [party.role, party.party_id, dtype, dtype]) + topic_key = _SPLIT_.join([party.role, party.party_id, dtype, dtype]) topic_key_list.append(topic_key) else: if partitions is not None: for i in range(partitions): - topic_key = _SPLIT_.join( - [party.role, party.party_id, name, str(i)]) + topic_key = _SPLIT_.join([party.role, party.party_id, name, str(i)]) topic_key_list.append(topic_key) elif name is not None: topic_key = _SPLIT_.join([party.role, party.party_id, name]) @@ -283,7 +294,9 @@ def _get_or_create_topic( if topic_key not in self._topic_map: topic_key_splits = topic_key.split(_SPLIT_) topic_suffix = "-".join(topic_key_splits[2:]) - topic_pair = self._maybe_create_topic_and_replication(party, topic_suffix) + topic_pair = self._maybe_create_topic_and_replication( + party, topic_suffix + ) self._topic_map[topic_key] = topic_pair topic_pair = self._topic_map[topic_key] @@ -292,7 +305,15 @@ def _get_or_create_topic( return topic_infos def _get_channel( - self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None): + self, + topic_pair, + src_party_id, + src_role, + dst_party_id, + dst_role, + mq=None, + conf: dict = None, + ): raise NotImplementedError() def _get_channels(self, party_topic_infos): @@ -312,14 +333,22 @@ def _get_channels(self, party_topic_infos): dst_party_id=party_id, dst_role=role, mq=self._mq, - conf=self._conf + conf=self._conf, ) self._channels_map[topic_key] = info channel_infos.append(info) return channel_infos - def _get_channels_index(self, index, party_topic_infos, src_party_id, src_role, mq=None, conf: dict = None): + def _get_channels_index( + self, + index, + party_topic_infos, + src_party_id, + src_role, + mq=None, + conf: dict = None, + ): channel_infos = [] for e in party_topic_infos: # select specified topic_info for a party @@ -334,7 +363,7 @@ def _get_channels_index(self, index, party_topic_infos, src_party_id, src_role, dst_party_id=party_id, dst_role=role, mq=mq, - conf=conf + conf=conf, ) channel_infos.append(info) return channel_infos @@ -345,19 +374,19 @@ def _send_obj(self, name, tag, data, channel_infos): "content_type": "text/plain", "app_id": info._dst_party_id, "message_id": name, - "correlation_id": tag + "correlation_id": tag, } LOGGER.debug(f"[federation._send_obj]properties:{properties}.") info.produce(body=data, properties=properties) def _send_kv( - self, name, tag, data, channel_infos, partition_size, partitions, message_key + self, name, tag, data, 
channel_infos, partition_size, partitions, message_key ): headers = json.dumps( { "partition_size": partition_size, "partitions": partitions, - "message_key": message_key + "message_key": message_key, } ) for info in channel_infos: @@ -366,22 +395,22 @@ def _send_kv( "app_id": info._dst_party_id, "message_id": name, "correlation_id": tag, - "headers": headers + "headers": headers, } print(f"[federation._send_kv]info: {info}, properties: {properties}.") info.produce(body=data, properties=properties) def _get_partition_send_func( - self, - name, - tag, - partitions, - party_topic_infos, - src_party_id, - src_role, - mq, - max_message_size, - conf: dict, + self, + name, + tag, + partitions, + party_topic_infos, + src_party_id, + src_role, + mq, + max_message_size, + conf: dict, ): def _fn(index, kvs): return self._partition_send( @@ -401,22 +430,26 @@ def _fn(index, kvs): return _fn def _partition_send( - self, - index, - kvs, - name, - tag, - partitions, - party_topic_infos, - src_party_id, - src_role, - mq, - max_message_size, - conf: dict, + self, + index, + kvs, + name, + tag, + partitions, + party_topic_infos, + src_party_id, + src_role, + mq, + max_message_size, + conf: dict, ): channel_infos = self._get_channels_index( - index=index, party_topic_infos=party_topic_infos, src_party_id=src_party_id, src_role=src_role, mq=mq, - conf=conf + index=index, + party_topic_infos=party_topic_infos, + src_party_id=src_party_id, + src_role=src_role, + mq=mq, + conf=conf, ) datastream = Datastream() @@ -429,8 +462,8 @@ def _partition_send( el = {"k": p_dumps(k).hex(), "v": p_dumps(v).hex()} # roughly caculate the size of package to avoid serialization ;) if ( - datastream.get_size() + sys.getsizeof(el["k"]) + sys.getsizeof(el["v"]) - >= max_message_size + datastream.get_size() + sys.getsizeof(el["k"]) + sys.getsizeof(el["v"]) + >= max_message_size ): print( f"[federation._partition_send]The size of message is: {datastream.get_size()}" @@ -491,9 +524,7 @@ def _receive_obj(self, channel_info, name, tag): channel_info = self._query_receive_topic(channel_info) for id, properties, body in self._get_consume_message(channel_info): - LOGGER.debug( - f"[federation._receive_obj] properties: {properties}" - ) + LOGGER.debug(f"[federation._receive_obj] properties: {properties}") if properties["message_id"] != name or properties["correlation_id"] != tag: # todo: fix this LOGGER.warning( @@ -521,7 +552,16 @@ def _receive_obj(self, channel_info, name, tag): ) def _get_partition_receive_func( - self, name, tag, src_party_id, src_role, dst_party_id, dst_role, topic_infos, mq, conf: dict + self, + name, + tag, + src_party_id, + src_role, + dst_party_id, + dst_role, + topic_infos, + mq, + conf: dict, ): def _fn(index, kvs): return self._partition_receive( @@ -541,27 +581,29 @@ def _fn(index, kvs): return _fn def _partition_receive( - self, - index, - kvs, - name, - tag, - src_party_id, - src_role, - dst_party_id, - dst_role, - topic_infos, - mq, - conf: dict, + self, + index, + kvs, + name, + tag, + src_party_id, + src_role, + dst_party_id, + dst_role, + topic_infos, + mq, + conf: dict, ): topic_pair = topic_infos[index][1] - channel_info = self._get_channel(topic_pair=topic_pair, - src_party_id=src_party_id, - src_role=src_role, - dst_party_id=dst_party_id, - dst_role=dst_role, - mq=mq, - conf=conf) + channel_info = self._get_channel( + topic_pair=topic_pair, + src_party_id=src_party_id, + src_role=src_role, + dst_party_id=dst_party_id, + dst_role=dst_role, + mq=mq, + conf=conf, + ) message_key_cache = set() count = 0 
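The `_partition_send` / `_partition_receive` hunks above batch pickled, hex-encoded kv pairs into JSON payloads kept roughly under `max_message_size`, and the receiving side reverses the encoding. A minimal round-trip sketch of that packing scheme in pure Python, without the Datastream wrapper or the MQ channel (`batch_kv_pairs` and `unpack_payload` are illustrative names, not functions from this patch):

```python
import json
import sys
from pickle import dumps as p_dumps
from pickle import loads as p_loads


def batch_kv_pairs(kvs, max_message_size):
    """Pack (k, v) pairs into JSON payloads roughly bounded by max_message_size."""
    batch, size, payloads = [], 0, []
    for k, v in kvs:
        el = {"k": p_dumps(k).hex(), "v": p_dumps(v).hex()}
        el_size = sys.getsizeof(el["k"]) + sys.getsizeof(el["v"])
        # flush the current batch before it grows past the size threshold
        if batch and size + el_size >= max_message_size:
            payloads.append(json.dumps(batch))
            batch, size = [], 0
        batch.append(el)
        size += el_size
    if batch:
        payloads.append(json.dumps(batch))
    return payloads


def unpack_payload(payload):
    """Reverse the encoding: JSON -> hex -> pickle -> (k, v) pairs."""
    return [
        (p_loads(bytes.fromhex(el["k"])), p_loads(bytes.fromhex(el["v"])))
        for el in json.loads(payload)
    ]


if __name__ == "__main__":
    data = [(f"key-{i}", list(range(i))) for i in range(50)]
    messages = batch_kv_pairs(data, max_message_size=1024)
    restored = [kv for m in messages for kv in unpack_payload(m)]
    assert restored == data
```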
@@ -573,10 +615,11 @@ def _partition_receive( while True: try: for id, properties, body in self._get_consume_message(channel_info): - print( - f"[federation._partition_receive] properties: {properties}." - ) - if properties["message_id"] != name or properties["correlation_id"] != tag: + print(f"[federation._partition_receive] properties: {properties}.") + if ( + properties["message_id"] != name + or properties["correlation_id"] != tag + ): # todo: fix this self._consume_ack(channel_info, id) print( @@ -601,7 +644,10 @@ def _partition_receive( data = json.loads(body.decode()) data_iter = ( - (p_loads(bytes.fromhex(el["k"])), p_loads(bytes.fromhex(el["v"]))) + ( + p_loads(bytes.fromhex(el["k"])), + p_loads(bytes.fromhex(el["v"])), + ) for el in data ) count += len(data) diff --git a/python/fate/arch/federation/_gc.py b/python/fate/arch/federation/_gc.py index 718bef24ae..74b175fb0f 100644 --- a/python/fate/arch/federation/_gc.py +++ b/python/fate/arch/federation/_gc.py @@ -17,8 +17,8 @@ import typing from collections import deque -from fate_arch.abc import GarbageCollectionABC -from fate_arch.common.log import getLogger +from ..abc import GarbageCollectionABC +from ..common.log import getLogger LOGGER = getLogger() diff --git a/python/fate/arch/federation/_nretry.py b/python/fate/arch/federation/_nretry.py index e9babd727a..066d99cf2e 100644 --- a/python/fate/arch/federation/_nretry.py +++ b/python/fate/arch/federation/_nretry.py @@ -16,18 +16,16 @@ import time -from fate_arch.common.log import getLogger +from ..common.log import getLogger LOGGER = getLogger() def nretry(func): - """retry connection - """ + """retry connection""" def wrapper(self, *args, **kwargs): - """wrapper - """ + """wrapper""" res = None exception = None for ntry in range(10): @@ -41,9 +39,7 @@ def wrapper(self, *args, **kwargs): time.sleep(1) if exception is not None: - LOGGER.debug( - f"failed", - exc_info=exception) + LOGGER.debug(f"failed", exc_info=exception) raise exception return res diff --git a/python/fate/arch/federation/_type.py b/python/fate/arch/federation/_type.py index 61e7f03dd2..d4e7239195 100644 --- a/python/fate/arch/federation/_type.py +++ b/python/fate/arch/federation/_type.py @@ -16,10 +16,10 @@ class FederationEngine(object): - EGGROLL = 'EGGROLL' - RABBITMQ = 'RABBITMQ' - STANDALONE = 'STANDALONE' - PULSAR = 'PULSAR' + EGGROLL = "EGGROLL" + RABBITMQ = "RABBITMQ" + STANDALONE = "STANDALONE" + PULSAR = "PULSAR" class FederationDataType(object): diff --git a/python/fate/arch/federation/eggroll/__init__.py b/python/fate/arch/federation/eggroll/__init__.py index 1d0c5dcd46..4982a8a2db 100644 --- a/python/fate/arch/federation/eggroll/__init__.py +++ b/python/fate/arch/federation/eggroll/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. 
# -from fate_arch.federation.eggroll._federation import Federation +from ._federation import Federation -__all__ = ['Federation'] +__all__ = ["Federation"] diff --git a/python/fate/arch/federation/eggroll/_federation.py b/python/fate/arch/federation/eggroll/_federation.py index 54c3ca3774..d15778374f 100644 --- a/python/fate/arch/federation/eggroll/_federation.py +++ b/python/fate/arch/federation/eggroll/_federation.py @@ -21,10 +21,11 @@ from eggroll.roll_pair.roll_pair import RollPair from eggroll.roll_site.roll_site import RollSiteContext -from fate_arch.abc import FederationABC -from fate_arch.common.log import getLogger -from fate_arch.computing.eggroll import Table -from fate_arch.common import remote_status + +from ...abc import FederationABC +from ...common import remote_status +from ...common.log import getLogger +from ...computing.eggroll import Table LOGGER = getLogger() diff --git a/python/fate/arch/federation/eggroll/_federation.pyi b/python/fate/arch/federation/eggroll/_federation.pyi index 8e6168555e..9ee1b51e70 100644 --- a/python/fate/arch/federation/eggroll/_federation.pyi +++ b/python/fate/arch/federation/eggroll/_federation.pyi @@ -17,46 +17,67 @@ import typing from eggroll.roll_pair.roll_pair import RollPairContext from eggroll.roll_site.roll_site import RollSiteContext -from fate_arch.abc import GarbageCollectionABC -from fate_arch.common import Party +from ...abc import GarbageCollectionABC +from ...common import Party class Federation(object): - - def __init__(self, rp_ctx: RollPairContext, rs_session_id: str, party: Party, proxy_endpoint: str): + def __init__( + self, + rp_ctx: RollPairContext, + rs_session_id: str, + party: Party, + proxy_endpoint: str, + ): self._rsc: RollSiteContext = ... ... - - def get(self: Federation, name: str, tag: str, parties: typing.List[Party], - gc: GarbageCollectionABC) -> typing.List: ... - - def remote(self, v, name: str, tag: str, parties: typing.List[Party], - gc: GarbageCollectionABC) -> typing.NoReturn: ... - - -def _remote(v, - name: str, - tag: str, - parties: typing.List[typing.Tuple[str, str]], - rsc: RollSiteContext, - gc: GarbageCollectionABC) -> typing.NoReturn: ... - - -def _get(name: str, - tag: str, - parties: typing.List[typing.Tuple[str, str]], - rsc: RollSiteContext, - gc: GarbageCollectionABC) -> typing.List: ... - - -def _remote_tag_not_duplicate(name: str, tag: str, parties: typing.List[typing.Tuple[str, str]]): ... - - -def _push_with_exception_handle(rsc: RollSiteContext, v, name: str, tag: str, parties: typing.List[typing.Tuple[str, str]]): ... - - + def get( + self: Federation, + name: str, + tag: str, + parties: typing.List[Party], + gc: GarbageCollectionABC, + ) -> typing.List: ... + def remote( + self, + v, + name: str, + tag: str, + parties: typing.List[Party], + gc: GarbageCollectionABC, + ) -> typing.NoReturn: ... + +def _remote( + v, + name: str, + tag: str, + parties: typing.List[typing.Tuple[str, str]], + rsc: RollSiteContext, + gc: GarbageCollectionABC, +) -> typing.NoReturn: ... +def _get( + name: str, + tag: str, + parties: typing.List[typing.Tuple[str, str]], + rsc: RollSiteContext, + gc: GarbageCollectionABC, +) -> typing.List: ... +def _remote_tag_not_duplicate( + name: str, tag: str, parties: typing.List[typing.Tuple[str, str]] +): ... +def _push_with_exception_handle( + rsc: RollSiteContext, + v, + name: str, + tag: str, + parties: typing.List[typing.Tuple[str, str]], +): ... def _get_tag_not_duplicate(name: str, tag: str, party: typing.Tuple[str, str]): ... 
- - -def _get_value_post_process(v, name: str, tag: str, party: typing.Tuple[str, str], rsc: RollSiteContext, - gc: GarbageCollectionABC): ... +def _get_value_post_process( + v, + name: str, + tag: str, + party: typing.Tuple[str, str], + rsc: RollSiteContext, + gc: GarbageCollectionABC, +): ... diff --git a/python/fate/arch/federation/pulsar/__init__.py b/python/fate/arch/federation/pulsar/__init__.py index 0b1cee68ec..6bc3c951bc 100644 --- a/python/fate/arch/federation/pulsar/__init__.py +++ b/python/fate/arch/federation/pulsar/__init__.py @@ -1,4 +1,3 @@ +from ._federation import MQ, Federation, PulsarManager -from fate_arch.federation.pulsar._federation import Federation, MQ, PulsarManager - -__all__ = ['Federation', 'MQ', 'PulsarManager'] +__all__ = ["Federation", "MQ", "PulsarManager"] diff --git a/python/fate/arch/federation/pulsar/_federation.py b/python/fate/arch/federation/pulsar/_federation.py index 98195b25ae..d1dc9c029a 100644 --- a/python/fate/arch/federation/pulsar/_federation.py +++ b/python/fate/arch/federation/pulsar/_federation.py @@ -14,17 +14,16 @@ # limitations under the License. # -from fate_arch.common import Party -from fate_arch.common import file_utils -from fate_arch.common.log import getLogger -from fate_arch.federation._federation import FederationBase -from fate_arch.federation.pulsar._mq_channel import ( - MQChannel, - DEFAULT_TENANT, +from ...common import Party, file_utils +from ...common.log import getLogger +from ._federation import FederationBase +from ._mq_channel import ( DEFAULT_CLUSTER, DEFAULT_SUBSCRIPTION_NAME, + DEFAULT_TENANT, + MQChannel, ) -from fate_arch.federation.pulsar._pulsar_manager import PulsarManager +from ._pulsar_manager import PulsarManager LOGGER = getLogger() # default message max size in bytes = 1MB @@ -59,10 +58,7 @@ def __init__(self, tenant, namespace, send, receive): class Federation(FederationBase): @staticmethod def from_conf( - federation_session_id: str, - party: Party, - runtime_conf: dict, - **kwargs + federation_session_id: str, party: Party, runtime_conf: dict, **kwargs ): pulsar_config = kwargs["pulsar_config"] LOGGER.debug(f"pulsar_config: {pulsar_config}") @@ -75,14 +71,14 @@ def from_conf( tenant = pulsar_config.get("tenant", DEFAULT_TENANT) # max_message_sizeï¼› - max_message_size = int(pulsar_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE)) + max_message_size = int( + pulsar_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE) + ) - pulsar_run = runtime_conf.get( - "job_parameters", {}).get("pulsar_run", {}) + pulsar_run = runtime_conf.get("job_parameters", {}).get("pulsar_run", {}) LOGGER.debug(f"pulsar_run: {pulsar_run}") - max_message_size = int(pulsar_run.get( - "max_message_size", max_message_size)) + max_message_size = int(pulsar_run.get("max_message_size", max_message_size)) LOGGER.debug(f"set max message size to {max_message_size} Bytes") @@ -102,8 +98,7 @@ def from_conf( # init tenant tenant_info = pulsar_manager.get_tenant(tenant=tenant).json() if tenant_info.get("allowedClusters") is None: - pulsar_manager.create_tenant( - tenant=tenant, admins=[], clusters=[cluster]) + pulsar_manager.create_tenant(tenant=tenant, admins=[], clusters=[cluster]) route_table_path = pulsar_config.get("route_table") if route_table_path is None: @@ -111,9 +106,7 @@ def from_conf( route_table = file_utils.load_yaml_conf(conf_path=route_table_path) mq = MQ(host, port, route_table) - conf = pulsar_manager.runtime_config.get( - "connection", {} - ) + conf = pulsar_manager.runtime_config.get("connection", {}) 
LOGGER.debug(f"federation mode={mode}") @@ -127,12 +120,29 @@ def from_conf( cluster, tenant, conf, - mode + mode, ) - def __init__(self, session_id, party: Party, mq: MQ, pulsar_manager: PulsarManager, max_message_size, topic_ttl, - cluster, tenant, conf, mode): - super().__init__(session_id=session_id, party=party, mq=mq, max_message_size=max_message_size, conf=conf) + def __init__( + self, + session_id, + party: Party, + mq: MQ, + pulsar_manager: PulsarManager, + max_message_size, + topic_ttl, + cluster, + tenant, + conf, + mode, + ): + super().__init__( + session_id=session_id, + party=party, + mq=mq, + max_message_size=max_message_size, + conf=conf, + ) self._pulsar_manager = pulsar_manager self._topic_ttl = topic_ttl @@ -212,8 +222,7 @@ def _create_topic_by_client_mode(self, party, topic_suffix): ) # init pulsar namespace - namespaces = self._pulsar_manager.get_namespace( - self._tenant).json() + namespaces = self._pulsar_manager.get_namespace(self._tenant).json() # create namespace if f"{self._tenant}/{self._session_id}" not in namespaces: # append target cluster to the pulsar namespace @@ -268,44 +277,33 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): ) if party.party_id == self._party.party_id: - LOGGER.debug( - "connecting to local broker, skipping cluster creation" - ) + LOGGER.debug("connecting to local broker, skipping cluster creation") else: # init pulsar cluster - cluster = self._pulsar_manager.get_cluster( - party.party_id).json() + cluster = self._pulsar_manager.get_cluster(party.party_id).json() if ( - cluster.get("brokerServiceUrl", "") == "" - and cluster.get("brokerServiceUrlTls", "") == "" + cluster.get("brokerServiceUrl", "") == "" + and cluster.get("brokerServiceUrlTls", "") == "" ): LOGGER.debug( "pulsar cluster with name %s does not exist or broker url is empty, creating...", party.party_id, ) - remote_party = self._mq.route_table.get( - int(party.party_id), None - ) + remote_party = self._mq.route_table.get(int(party.party_id), None) # handle party does not exist in route table first if remote_party is None: - domain = self._mq.route_table.get( - "default").get("domain") + domain = self._mq.route_table.get("default").get("domain") host = f"{party.party_id}.{domain}" - port = self._mq.route_table.get("default").get( - "brokerPort", "6650" - ) + port = self._mq.route_table.get("default").get("brokerPort", "6650") sslPort = self._mq.route_table.get("default").get( "brokerSslPort", "6651" ) - proxy = self._mq.route_table.get( - "default").get("proxy", "") + proxy = self._mq.route_table.get("default").get("proxy", "") # fetch party info from the route table else: - host = self._mq.route_table.get(int(party.party_id)).get( - "host" - ) + host = self._mq.route_table.get(int(party.party_id)).get("host") port = self._mq.route_table.get(int(party.party_id)).get( "port", "6650" ) @@ -322,10 +320,10 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): proxy = f"pulsar+ssl://{proxy}" if self._pulsar_manager.create_cluster( - cluster_name=party.party_id, - broker_url=broker_url, - broker_url_tls=broker_url_tls, - proxy_url=proxy, + cluster_name=party.party_id, + broker_url=broker_url, + broker_url_tls=broker_url_tls, + proxy_url=proxy, ).ok: LOGGER.debug( "pulsar cluster with name: %s, broker_url: %s created", @@ -333,10 +331,10 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): broker_url, ) elif self._pulsar_manager.update_cluster( - cluster_name=party.party_id, - broker_url=broker_url, - broker_url_tls=broker_url_tls, 
- proxy_url=proxy, + cluster_name=party.party_id, + broker_url=broker_url, + broker_url_tls=broker_url_tls, + proxy_url=proxy, ).ok: LOGGER.debug( "pulsar cluster with name: %s, broker_url: %s updated", @@ -344,26 +342,23 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): broker_url, ) else: - error_message = ( - "unable to create pulsar cluster: %s".format( - party.party_id - ) + error_message = "unable to create pulsar cluster: %s".format( + party.party_id ) LOGGER.error(error_message) # just leave this alone. raise Exception(error_message) # update tenant - tenant_info = self._pulsar_manager.get_tenant( - self._tenant).json() + tenant_info = self._pulsar_manager.get_tenant(self._tenant).json() if party.party_id not in tenant_info["allowedClusters"]: tenant_info["allowedClusters"].append(party.party_id) if self._pulsar_manager.update_tenant( - self._tenant, - tenant_info.get("admins", []), - tenant_info.get( - "allowedClusters", - ), + self._tenant, + tenant_info.get("admins", []), + tenant_info.get( + "allowedClusters", + ), ).ok: LOGGER.debug( "successfully update tenant with cluster: %s", @@ -374,15 +369,14 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): # TODO: remove this for the loop # init pulsar namespace - namespaces = self._pulsar_manager.get_namespace( - self._tenant).json() + namespaces = self._pulsar_manager.get_namespace(self._tenant).json() # create namespace if f"{self._tenant}/{self._session_id}" not in namespaces: # append target cluster to the pulsar namespace clusters = [self._cluster] if ( - party.party_id != self._party.party_id - and party.party_id not in clusters + party.party_id != self._party.party_id + and party.party_id not in clusters ): clusters.append(party.party_id) @@ -431,7 +425,7 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): if party.party_id not in clusters: clusters.append(party.party_id) if self._pulsar_manager.set_clusters_to_namespace( - self._tenant, self._session_id, clusters + self._tenant, self._session_id, clusters ).ok: LOGGER.debug( "successfully set clusters: {} to pulsar namespace: {}".format( @@ -447,8 +441,16 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): return topic_pair - def _get_channel(self, topic_pair: _TopicPair, src_party_id, src_role, dst_party_id, dst_role, mq=None, - conf: dict = None): + def _get_channel( + self, + topic_pair: _TopicPair, + src_party_id, + src_role, + dst_party_id, + dst_role, + mq=None, + conf: dict = None, + ): return MQChannel( host=mq.host, port=mq.port, diff --git a/python/fate/arch/federation/pulsar/_mq_channel.py b/python/fate/arch/federation/pulsar/_mq_channel.py index dc540e6760..9fbc630bf2 100644 --- a/python/fate/arch/federation/pulsar/_mq_channel.py +++ b/python/fate/arch/federation/pulsar/_mq_channel.py @@ -17,8 +17,8 @@ import pulsar -from fate_arch.common import log -from fate_arch.federation._nretry import nretry +from ...common import log +from .._nretry import nretry LOGGER = log.getLogger() CHANNEL_TYPE_PRODUCER = "producer" @@ -37,19 +37,19 @@ class MQChannel(object): # TODO add credential to secure pulsar cluster def __init__( - self, - host, - port, - tenant, - namespace, - send_topic, - receive_topic, - src_party_id, - src_role, - dst_party_id, - dst_role, - credential=None, - extra_args: dict = None, + self, + host, + port, + tenant, + namespace, + send_topic, + receive_topic, + src_party_id, + src_role, + dst_party_id, + dst_role, + credential=None, + extra_args: dict = None, ): # "host:port" is used 
to connect the pulsar broker self._host = host @@ -97,8 +97,7 @@ def produce(self, body, properties): LOGGER.debug("send queue: {}".format(self._producer_send.topic())) LOGGER.debug("send data size: {}".format(len(body))) - message_id = self._producer_send.send( - content=body, properties=properties) + message_id = self._producer_send.send(content=body, properties=properties) if message_id is None: raise Exception("publish failed") @@ -109,15 +108,11 @@ def consume(self): self._get_or_create_consumer() try: - LOGGER.debug("receive topic: {}".format( - self._consumer_receive.topic())) - receive_timeout = self._consumer_config.get( - 'receive_timeout_millis', None) + LOGGER.debug("receive topic: {}".format(self._consumer_receive.topic())) + receive_timeout = self._consumer_config.get("receive_timeout_millis", None) if receive_timeout is not None: - LOGGER.debug( - f"receive timeout millis {receive_timeout}") - message = self._consumer_receive.receive( - timeout_millis=receive_timeout) + LOGGER.debug(f"receive timeout millis {receive_timeout}") + message = self._consumer_receive.receive(timeout_millis=receive_timeout) return message except Exception: self._consumer_receive.seek(pulsar.MessageId.earliest) @@ -169,8 +164,7 @@ def _get_or_create_producer(self): # if self._producer_conn is None: try: self._producer_conn = pulsar.Client( - service_url="pulsar://{}:{}".format( - self._host, self._port), + service_url="pulsar://{}:{}".format(self._host, self._port), operation_timeout_seconds=30, ) except Exception as e: @@ -189,16 +183,14 @@ def _get_or_create_producer(self): **self._producer_config, ) except Exception as e: - LOGGER.debug( - f"catch exception {e} in creating pulsar producer") + LOGGER.debug(f"catch exception {e} in creating pulsar producer") self._producer_conn = None def _get_or_create_consumer(self): if not self._check_consumer_alive(): try: self._consumer_conn = pulsar.Client( - service_url="pulsar://{}:{}".format( - self._host, self._port), + service_url="pulsar://{}:{}".format(self._host, self._port), operation_timeout_seconds=30, ) except Exception: @@ -221,8 +213,7 @@ def _get_or_create_consumer(self): self._consumer_receive.seek(self._latest_confirmed) except Exception as e: - LOGGER.debug( - f"catch exception {e} in creating pulsar consumer") + LOGGER.debug(f"catch exception {e} in creating pulsar consumer") self._consumer_conn.close() self._consumer_conn = None diff --git a/python/fate/arch/federation/pulsar/_pulsar_manager.py b/python/fate/arch/federation/pulsar/_pulsar_manager.py index d57d2ff377..8915df03cd 100644 --- a/python/fate/arch/federation/pulsar/_pulsar_manager.py +++ b/python/fate/arch/federation/pulsar/_pulsar_manager.py @@ -14,13 +14,14 @@ # limitations under the License. 
# -import logging import json -import requests +import logging +import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from fate_arch.common.log import getLogger + +from ...common.log import getLogger logger = getLogger() @@ -30,14 +31,14 @@ # sleep time equips to {BACKOFF_FACTOR} * (2 ** ({NUMBER_OF_TOTALRETRIES} - 1)) -CLUSTER = 'clusters/{}' -TENANT = 'tenants/{}' +CLUSTER = "clusters/{}" +TENANT = "tenants/{}" # APIs are refer to https://pulsar.apache.org/admin-rest-api/?version=2.7.0&apiversion=v2 -class PulsarManager(): +class PulsarManager: def __init__(self, host: str, port: str, runtime_config: dict = {}): self.service_url = "http://{}:{}/admin/v2/".format(host, port) self.runtime_config = runtime_config @@ -46,83 +47,99 @@ def __init__(self, host: str, port: str, runtime_config: dict = {}): def _create_session(self): # retry mechanism refers to # https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.Retry - retry = Retry(total=MAX_RETRIES, redirect=MAX_REDIRECT, - backoff_factor=BACKOFF_FACTOR) + retry = Retry( + total=MAX_RETRIES, redirect=MAX_REDIRECT, backoff_factor=BACKOFF_FACTOR + ) s = requests.Session() # initialize headers - s.headers.update({'Content-Type': 'application/json'}) + s.headers.update({"Content-Type": "application/json"}) http_adapter = HTTPAdapter(max_retries=retry) - s.mount('http://', http_adapter) - s.mount('https://', http_adapter) + s.mount("http://", http_adapter) + s.mount("https://", http_adapter) return s # allocator - def get_allocator(self, allocator: str = 'default'): + def get_allocator(self, allocator: str = "default"): session = self._create_session() response = session.get( - self.service_url + 'broker-stats/allocator-stats/{}'.format(allocator)) + self.service_url + "broker-stats/allocator-stats/{}".format(allocator) + ) return response # cluster - def get_cluster(self, cluster_name: str = ''): + def get_cluster(self, cluster_name: str = ""): session = self._create_session() - response = session.get( - self.service_url + CLUSTER.format(cluster_name)) + response = session.get(self.service_url + CLUSTER.format(cluster_name)) return response - def delete_cluster(self, cluster_name: str = ''): + def delete_cluster(self, cluster_name: str = ""): session = self._create_session() - response = session.delete( - self.service_url + CLUSTER.format(cluster_name)) + response = session.delete(self.service_url + CLUSTER.format(cluster_name)) return response # service_url need to provide "http://" prefix - def create_cluster(self, cluster_name: str, broker_url: str, service_url: str = '', - service_url_tls: str = '', broker_url_tls: str = '', - proxy_url: str = '', proxy_protocol: str = "SNI", peer_cluster_names: list = [], - ): + def create_cluster( + self, + cluster_name: str, + broker_url: str, + service_url: str = "", + service_url_tls: str = "", + broker_url_tls: str = "", + proxy_url: str = "", + proxy_protocol: str = "SNI", + peer_cluster_names: list = [], + ): # initialize data data = { - 'serviceUrl': service_url, - 'serviceUrlTls': service_url_tls, - 'brokerServiceUrl': broker_url, - 'brokerServiceUrlTls': broker_url_tls, - 'peerClusterNames': peer_cluster_names, - 'proxyServiceUrl': proxy_url, - 'proxyProtocol': proxy_protocol + "serviceUrl": service_url, + "serviceUrlTls": service_url_tls, + "brokerServiceUrl": broker_url, + "brokerServiceUrlTls": broker_url_tls, + "peerClusterNames": peer_cluster_names, + "proxyServiceUrl": proxy_url, + "proxyProtocol": proxy_protocol, } 
session = self._create_session() response = session.put( - self.service_url + CLUSTER.format(cluster_name), data=json.dumps(data)) + self.service_url + CLUSTER.format(cluster_name), data=json.dumps(data) + ) return response - def update_cluster(self, cluster_name: str, broker_url: str, service_url: str = '', - service_url_tls: str = '', broker_url_tls: str = '', - proxy_url: str = '', proxy_protocol: str = "SNI", peer_cluster_names: list = [], - ): + def update_cluster( + self, + cluster_name: str, + broker_url: str, + service_url: str = "", + service_url_tls: str = "", + broker_url_tls: str = "", + proxy_url: str = "", + proxy_protocol: str = "SNI", + peer_cluster_names: list = [], + ): # initialize data data = { - 'serviceUrl': service_url, - 'serviceUrlTls': service_url_tls, - 'brokerServiceUrl': broker_url, - 'brokerServiceUrlTls': broker_url_tls, - 'peerClusterNames': peer_cluster_names, - 'proxyServiceUrl': proxy_url, - 'proxyProtocol': proxy_protocol + "serviceUrl": service_url, + "serviceUrlTls": service_url_tls, + "brokerServiceUrl": broker_url, + "brokerServiceUrlTls": broker_url_tls, + "peerClusterNames": peer_cluster_names, + "proxyServiceUrl": proxy_url, + "proxyProtocol": proxy_protocol, } session = self._create_session() response = session.post( - self.service_url + CLUSTER.format(cluster_name), data=json.dumps(data)) + self.service_url + CLUSTER.format(cluster_name), data=json.dumps(data) + ) return response # tenants - def get_tenant(self, tenant: str = ''): + def get_tenant(self, tenant: str = ""): session = self._create_session() response = session.get(self.service_url + TENANT.format(tenant)) return response @@ -130,59 +147,57 @@ def get_tenant(self, tenant: str = ''): def create_tenant(self, tenant: str, admins: list, clusters: list): session = self._create_session() - data = {'adminRoles': admins, - 'allowedClusters': clusters} + data = {"adminRoles": admins, "allowedClusters": clusters} response = session.put( - self.service_url + TENANT.format(tenant), data=json.dumps(data)) + self.service_url + TENANT.format(tenant), data=json.dumps(data) + ) return response def delete_tenant(self, tenant: str): session = self._create_session() - response = session.delete( - self.service_url + TENANT.format(tenant)) + response = session.delete(self.service_url + TENANT.format(tenant)) return response def update_tenant(self, tenant: str, admins: list, clusters: list): session = self._create_session() - data = {'adminRoles': admins, - 'allowedClusters': clusters} + data = {"adminRoles": admins, "allowedClusters": clusters} response = session.post( - self.service_url + TENANT.format(tenant), data=json.dumps(data)) + self.service_url + TENANT.format(tenant), data=json.dumps(data) + ) return response # namespace def get_namespace(self, tenant: str): session = self._create_session() - response = session.get( - self.service_url + 'namespaces/{}'.format(tenant)) + response = session.get(self.service_url + "namespaces/{}".format(tenant)) return response # 'replication_clusters' is always required def create_namespace(self, tenant: str, namespace: str, policies: dict = {}): session = self._create_session() response = session.put( - self.service_url + 'namespaces/{}/{}'.format(tenant, namespace), - data=json.dumps(policies) + self.service_url + "namespaces/{}/{}".format(tenant, namespace), + data=json.dumps(policies), ) return response def delete_namespace(self, tenant: str, namespace: str): session = self._create_session() response = session.delete( - self.service_url + - 
'namespaces/{}/{}'.format(tenant, namespace) + self.service_url + "namespaces/{}/{}".format(tenant, namespace) ) return response def set_clusters_to_namespace(self, tenant: str, namespace: str, clusters: list): session = self._create_session() response = session.post( - self.service_url + 'namespaces/{}/{}/replication'.format(tenant, namespace), json=clusters + self.service_url + "namespaces/{}/{}/replication".format(tenant, namespace), + json=clusters, ) return response @@ -190,16 +205,19 @@ def set_clusters_to_namespace(self, tenant: str, namespace: str, clusters: list) def get_cluster_from_namespace(self, tenant: str, namespace: str): session = self._create_session() response = session.get( - self.service_url + - 'namespaces/{}/{}/replication'.format(tenant, namespace) + self.service_url + "namespaces/{}/{}/replication".format(tenant, namespace) ) return response - def set_subscription_expiration_time(self, tenant: str, namespace: str, mintues: int = 0): + def set_subscription_expiration_time( + self, tenant: str, namespace: str, mintues: int = 0 + ): session = self._create_session() response = session.post( - self.service_url + 'namespaces/{}/{}/subscriptionExpirationTime'.format(tenant, namespace), json=mintues + self.service_url + + "namespaces/{}/{}/subscriptionExpirationTime".format(tenant, namespace), + json=mintues, ) return response @@ -208,48 +226,61 @@ def set_message_ttl(self, tenant: str, namespace: str, mintues: int = 0): session = self._create_session() response = session.post( # the API accepts data as seconds - self.service_url + 'namespaces/{}/{}/messageTTL'.format(tenant, namespace), json=mintues * 60 + self.service_url + "namespaces/{}/{}/messageTTL".format(tenant, namespace), + json=mintues * 60, ) return response - def unsubscribe_namespace_all_topics(self, tenant: str, namespace: str, subscription_name: str): + def unsubscribe_namespace_all_topics( + self, tenant: str, namespace: str, subscription_name: str + ): session = self._create_session() response = session.post( - self.service_url + - 'namespaces/{}/{}/unsubscribe/{}'.format( - tenant, namespace, subscription_name) + self.service_url + + "namespaces/{}/{}/unsubscribe/{}".format( + tenant, namespace, subscription_name + ) ) return response - def set_retention(self, tenant: str, namespace: str, - retention_time_in_minutes: int = 0, retention_size_in_MB: int = 0): + def set_retention( + self, + tenant: str, + namespace: str, + retention_time_in_minutes: int = 0, + retention_size_in_MB: int = 0, + ): session = self._create_session() - data = {'retentionTimeInMinutes': retention_time_in_minutes, - 'retentionSizeInMB': retention_size_in_MB} + data = { + "retentionTimeInMinutes": retention_time_in_minutes, + "retentionSizeInMB": retention_size_in_MB, + } response = session.post( - self.service_url + - 'namespaces/{}/{}/retention'.format(tenant, namespace), data=json.dumps(data) + self.service_url + "namespaces/{}/{}/retention".format(tenant, namespace), + data=json.dumps(data), ) return response def remove_retention(self, tenant: str, namespace: str): session = self._create_session() response = session.delete( - self.service_url + - 'namespaces/{}/{}/retention'.format(tenant, namespace), + self.service_url + "namespaces/{}/{}/retention".format(tenant, namespace), ) return response # topic - def unsubscribe_topic(self, tenant: str, namespace: str, topic: str, subscription_name: str): + def unsubscribe_topic( + self, tenant: str, namespace: str, topic: str, subscription_name: str + ): session = self._create_session() 
response = session.delete( - self.service_url + - 'persistent/{}/{}/{}/subscription/{}'.format( - tenant, namespace, topic, subscription_name) + self.service_url + + "persistent/{}/{}/{}/subscription/{}".format( + tenant, namespace, topic, subscription_name + ) ) return response diff --git a/python/fate/arch/federation/rabbitmq/__init__.py b/python/fate/arch/federation/rabbitmq/__init__.py index e76b8b28e6..cf04e965d6 100644 --- a/python/fate/arch/federation/rabbitmq/__init__.py +++ b/python/fate/arch/federation/rabbitmq/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # -from fate_arch.federation.rabbitmq._federation import Federation, MQ, RabbitManager +from ._federation import MQ, Federation, RabbitManager -__all__ = ['Federation', 'MQ', 'RabbitManager'] +__all__ = ["Federation", "MQ", "RabbitManager"] diff --git a/python/fate/arch/federation/rabbitmq/_federation.py b/python/fate/arch/federation/rabbitmq/_federation.py index a9c1eba0f5..c087ee5845 100644 --- a/python/fate/arch/federation/rabbitmq/_federation.py +++ b/python/fate/arch/federation/rabbitmq/_federation.py @@ -16,12 +16,11 @@ import json -from fate_arch.common import Party -from fate_arch.common import file_utils -from fate_arch.common.log import getLogger -from fate_arch.federation._federation import FederationBase -from fate_arch.federation.rabbitmq._mq_channel import MQChannel -from fate_arch.federation.rabbitmq._rabbit_manager import RabbitManager +from ...common import Party, file_utils +from ...common.log import getLogger +from ._federation import FederationBase +from ._mq_channel import MQChannel +from ._rabbit_manager import RabbitManager LOGGER = getLogger() @@ -48,7 +47,9 @@ def __repr__(self): class _TopicPair(object): - def __init__(self, tenant=None, namespace=None, vhost=None, send=None, receive=None): + def __init__( + self, tenant=None, namespace=None, vhost=None, send=None, receive=None + ): self.tenant = tenant self.namespace = namespace self.vhost = vhost @@ -59,10 +60,7 @@ def __init__(self, tenant=None, namespace=None, vhost=None, send=None, receive=N class Federation(FederationBase): @staticmethod def from_conf( - federation_session_id: str, - party: Party, - runtime_conf: dict, - **kwargs + federation_session_id: str, party: Party, runtime_conf: dict, **kwargs ): rabbitmq_config = kwargs["rabbitmq_config"] LOGGER.debug(f"rabbitmq_config: {rabbitmq_config}") @@ -73,7 +71,9 @@ def from_conf( base_password = rabbitmq_config.get("password") mode = rabbitmq_config.get("mode", "replication") # max_message_sizeï¼› - max_message_size = int(rabbitmq_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE)) + max_message_size = int( + rabbitmq_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE) + ) union_name = federation_session_id policy_id = federation_session_id @@ -81,8 +81,7 @@ def from_conf( rabbitmq_run = runtime_conf.get("job_parameters", {}).get("rabbitmq_run", {}) LOGGER.debug(f"rabbitmq_run: {rabbitmq_run}") - max_message_size = int(rabbitmq_run.get( - "max_message_size", max_message_size)) + max_message_size = int(rabbitmq_run.get("max_message_size", max_message_size)) LOGGER.debug(f"set max message size to {max_message_size} Bytes") @@ -95,16 +94,35 @@ def from_conf( route_table_path = "conf/rabbitmq_route_table.yaml" route_table = file_utils.load_yaml_conf(conf_path=route_table_path) mq = MQ(host, port, union_name, policy_id, route_table) - conf = rabbit_manager.runtime_config.get( - "connection", {} - ) + conf = rabbit_manager.runtime_config.get("connection", {}) return 
Federation( - federation_session_id, party, mq, rabbit_manager, max_message_size, conf, mode + federation_session_id, + party, + mq, + rabbit_manager, + max_message_size, + conf, + mode, ) - def __init__(self, session_id, party: Party, mq: MQ, rabbit_manager: RabbitManager, max_message_size, conf, mode): - super().__init__(session_id=session_id, party=party, mq=mq, max_message_size=max_message_size, conf=conf) + def __init__( + self, + session_id, + party: Party, + mq: MQ, + rabbit_manager: RabbitManager, + max_message_size, + conf, + mode, + ): + super().__init__( + session_id=session_id, + party=party, + mq=mq, + max_message_size=max_message_size, + conf=conf, + ) self._rabbit_manager = rabbit_manager self._vhost_set = set() self._mode = mode @@ -153,7 +171,7 @@ def _create_topic_by_client_mode(self, party, topic_suffix): namespace=self._session_id, vhost=vhost_name, send=send_queue_name, - receive=receive_queue_name + receive=receive_queue_name, ) # initial vhost @@ -181,7 +199,7 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): namespace=self._session_id, vhost=vhost_name, send=send_queue_name, - receive=receive_queue_name + receive=receive_queue_name, ) # initial vhost @@ -196,9 +214,7 @@ def _create_topic_by_replication_mode(self, party, topic_suffix): self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.send) # initial receive queue, the name is receive-${vhost} - self._rabbit_manager.create_queue( - topic_pair.vhost, topic_pair.receive - ) + self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.receive) upstream_uri = self._upstream_uri(party_id=party.party_id) self._rabbit_manager.federate_queue( @@ -219,9 +235,19 @@ def _upstream_uri(self, party_id): return upstream_uri def _get_channel( - self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None): - LOGGER.debug(f"rabbitmq federation _get_channel, src_party_id={src_party_id}, src_role={src_role}," - f"dst_party_id={dst_party_id}, dst_role={dst_role}") + self, + topic_pair, + src_party_id, + src_role, + dst_party_id, + dst_role, + mq=None, + conf: dict = None, + ): + LOGGER.debug( + f"rabbitmq federation _get_channel, src_party_id={src_party_id}, src_role={src_role}," + f"dst_party_id={dst_party_id}, dst_role={dst_role}" + ) return MQChannel( host=mq.host, port=mq.port, @@ -248,7 +274,7 @@ def _get_consume_message(self, channel_info): "message_id": properties.message_id, "correlation_id": properties.correlation_id, "content_type": properties.content_type, - "headers": json.dumps(properties.headers) + "headers": json.dumps(properties.headers), } yield method.delivery_tag, properties, body diff --git a/python/fate/arch/federation/rabbitmq/_mq_channel.py b/python/fate/arch/federation/rabbitmq/_mq_channel.py index ef4e9800dd..8da8e15a69 100644 --- a/python/fate/arch/federation/rabbitmq/_mq_channel.py +++ b/python/fate/arch/federation/rabbitmq/_mq_channel.py @@ -15,30 +15,32 @@ # import json + import pika -from fate_arch.common import log -from fate_arch.federation._nretry import nretry +from ...common import log +from .._nretry import nretry LOGGER = log.getLogger() class MQChannel(object): - - def __init__(self, - host, - port, - user, - password, - namespace, - vhost, - send_queue_name, - receive_queue_name, - src_party_id, - src_role, - dst_party_id, - dst_role, - extra_args: dict): + def __init__( + self, + host, + port, + user, + password, + namespace, + vhost, + send_queue_name, + receive_queue_name, + src_party_id, + src_role, + dst_party_id, + dst_role, 
+ extra_args: dict, + ): self._host = host self._port = port self._credentials = pika.PlainCredentials(user, password) @@ -87,8 +89,12 @@ def produce(self, body, properties: dict): delivery_mode=1, ) - return self._channel.basic_publish(exchange='', routing_key=self._send_queue_name, body=body, - properties=properties) + return self._channel.basic_publish( + exchange="", + routing_key=self._send_queue_name, + body=body, + properties=properties, + ) @nretry def consume(self): @@ -113,10 +119,15 @@ def _get_channel(self): self._clear() if not self._conn: - self._conn = pika.BlockingConnection(pika.ConnectionParameters(host=self._host, port=self._port, - virtual_host=self._vhost, - credentials=self._credentials, - **self._extra_args)) + self._conn = pika.BlockingConnection( + pika.ConnectionParameters( + host=self._host, + port=self._port, + virtual_host=self._vhost, + credentials=self._credentials, + **self._extra_args, + ) + ) if not self._channel: self._channel = self._conn.channel() @@ -137,4 +148,9 @@ def _clear(self): self._channel = None def _check_alive(self): - return self._channel and self._channel.is_open and self._conn and self._conn.is_open + return ( + self._channel + and self._channel.is_open + and self._conn + and self._conn.is_open + ) diff --git a/python/fate/arch/federation/rabbitmq/_rabbit_manager.py b/python/fate/arch/federation/rabbitmq/_rabbit_manager.py index e84bfdc5e8..fd3257180e 100644 --- a/python/fate/arch/federation/rabbitmq/_rabbit_manager.py +++ b/python/fate/arch/federation/rabbitmq/_rabbit_manager.py @@ -14,15 +14,16 @@ # limitations under the License. # -import requests import time -from fate_arch.common import log +import requests + +from ...common import log LOGGER = log.getLogger() C_HTTP_TEMPLATE = "http://{}/api/{}" -C_COMMON_HTTP_HEADER = {'Content-Type': 'application/json'} +C_COMMON_HTTP_HEADER = {"Content-Type": "application/json"} """ APIs are refered to https://rawcdn.githack.com/rabbitmq/rabbitmq-management/v3.8.3/priv/www/api/index.html @@ -39,12 +40,13 @@ def __init__(self, user, password, endpoint, runtime_config=None): def create_user(self, user, password): url = C_HTTP_TEMPLATE.format(self.endpoint, "users/" + user) - body = { - "password": password, - "tags": "" - } - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=body, auth=(self.user, self.password)) + body = {"password": password, "tags": ""} + result = requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=body, + auth=(self.user, self.password), + ) LOGGER.debug(f"[rabbitmanager.create_user] {result}") if result.status_code == 201 or result.status_code == 204: return True @@ -60,7 +62,8 @@ def delete_user(self, user): def create_vhost(self, vhost): url = C_HTTP_TEMPLATE.format(self.endpoint, "vhosts/" + vhost) result = requests.put( - url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password)) + url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password) + ) LOGGER.debug(f"[rabbitmanager.create_vhost] {result}") self.add_user_to_vhost(self.user, vhost) return True @@ -93,15 +96,16 @@ def get_vhosts(self): def add_user_to_vhost(self, user, vhost): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("permissions", vhost, user)) - body = { - "configure": ".*", - "write": ".*", - "read": ".*" - } + self.endpoint, "{}/{}/{}".format("permissions", vhost, user) + ) + body = {"configure": ".*", "write": ".*", "read": ".*"} - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=body, auth=(self.user, self.password)) + result = 
requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=body, + auth=(self.user, self.password), + ) LOGGER.debug(f"[rabbitmanager.add_user_to_vhost] {result}") if result.status_code == 201 or result.status_code == 204: @@ -111,7 +115,8 @@ def add_user_to_vhost(self, user, vhost): def remove_user_from_vhost(self, user, vhost): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("permissions", vhost, user)) + self.endpoint, "{}/{}/{}".format("permissions", vhost, user) + ) result = requests.delete(url, auth=(self.user, self.password)) LOGGER.debug(f"[rabbitmanager.remove_user_from_vhost] {result}") return result @@ -123,7 +128,9 @@ def get_exchanges(self, vhost): try: if result.status_code == 200: exchange_names = [e["name"] for e in result.json()] - LOGGER.debug(f"[rabbitmanager.get_exchanges] exchange_names={exchange_names}") + LOGGER.debug( + f"[rabbitmanager.get_exchanges] exchange_names={exchange_names}" + ) return exchange_names else: return None @@ -132,29 +139,37 @@ def get_exchanges(self, vhost): def create_exchange(self, vhost, exchange_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("exchanges", vhost, exchange_name)) + self.endpoint, "{}/{}/{}".format("exchanges", vhost, exchange_name) + ) basic_config = { "type": "direct", "auto_delete": False, "durable": True, "internal": False, - "arguments": {} + "arguments": {}, } exchange_runtime_config = self.runtime_config.get("exchange", {}) basic_config.update(exchange_runtime_config) - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=basic_config, auth=(self.user, self.password)) + result = requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=basic_config, + auth=(self.user, self.password), + ) LOGGER.debug(result) return result def delete_exchange(self, vhost, exchange_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("exchanges", vhost, exchange_name)) + self.endpoint, "{}/{}/{}".format("exchanges", vhost, exchange_name) + ) result = requests.delete(url, auth=(self.user, self.password)) - LOGGER.debug(f"[rabbitmanager.delete_exchange] vhost={vhost}, exchange_name={exchange_name}, {result}") + LOGGER.debug( + f"[rabbitmanager.delete_exchange] vhost={vhost}, exchange_name={exchange_name}, {result}" + ) return result def get_policies(self, vhost): @@ -164,7 +179,9 @@ def get_policies(self, vhost): try: if result.status_code == 200: policies_names = [e["name"] for e in result.json()] - LOGGER.debug(f"[rabbitmanager.get_policies] policies_names={policies_names}") + LOGGER.debug( + f"[rabbitmanager.get_policies] policies_names={policies_names}" + ) return policies_names else: return None @@ -173,27 +190,31 @@ def get_policies(self, vhost): def delete_policy(self, vhost, policy_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("policies", vhost, policy_name)) + self.endpoint, "{}/{}/{}".format("policies", vhost, policy_name) + ) result = requests.delete(url, auth=(self.user, self.password)) - LOGGER.debug(f"[rabbitmanager.delete_policy] vhost={vhost}, policy_name={policy_name}, {result}") + LOGGER.debug( + f"[rabbitmanager.delete_policy] vhost={vhost}, policy_name={policy_name}, {result}" + ) return result def create_queue(self, vhost, queue_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name)) + self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name) + ) - basic_config = { - "auto_delete": False, - "durable": True, - "arguments": {} - } + basic_config = {"auto_delete": False, 
"durable": True, "arguments": {}} queue_runtime_config = self.runtime_config.get("queue", {}) basic_config.update(queue_runtime_config) LOGGER.debug(basic_config) - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=basic_config, auth=(self.user, self.password)) + result = requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=basic_config, + auth=(self.user, self.password), + ) LOGGER.debug(f"[rabbitmanager.create_queue] {result}") if result.status_code == 201 or result.status_code == 204: @@ -203,15 +224,19 @@ def create_queue(self, vhost, queue_name): def get_queue(self, vhost, queue_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name)) + self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name) + ) - result = requests.get(url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password)) + result = requests.get( + url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password) + ) return result def get_queues(self, vhost): - url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}".format("queues", vhost)) - result = requests.get(url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password)) + url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}".format("queues", vhost)) + result = requests.get( + url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password) + ) try: if result.status_code == 200: queue_names = [e["name"] for e in result.json()] @@ -224,15 +249,19 @@ def get_queues(self, vhost): def delete_queue(self, vhost, queue_name): url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name)) + self.endpoint, "{}/{}/{}".format("queues", vhost, queue_name) + ) result = requests.delete(url, auth=(self.user, self.password)) - LOGGER.debug(f"[rabbitmanager.delete_queue] vhost={vhost}, queue_name={queue_name}, {result}") + LOGGER.debug( + f"[rabbitmanager.delete_queue] vhost={vhost}, queue_name={queue_name}, {result}" + ) return result def get_connections(self): - url = C_HTTP_TEMPLATE.format( - self.endpoint, "connections") - result = requests.get(url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password)) + url = C_HTTP_TEMPLATE.format(self.endpoint, "connections") + result = requests.get( + url, headers=C_COMMON_HTTP_HEADER, auth=(self.user, self.password) + ) LOGGER.debug(f"[rabbitmanager.get_connections] {result}") return result @@ -252,99 +281,118 @@ def delete_connections(self, vhost=None): LOGGER.debug("[rabbitmanager.delete_connections] start....") for name in names: url = C_HTTP_TEMPLATE.format( - self.endpoint, "{}/{}".format("connections", name)) + self.endpoint, "{}/{}".format("connections", name) + ) result = requests.delete(url, auth=(self.user, self.password)) LOGGER.debug(result) def bind_exchange_to_queue(self, vhost, exchange_name, queue_name): - url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}/e/{}/q/{}".format("bindings", - vhost, - exchange_name, - queue_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, + "{}/{}/e/{}/q/{}".format("bindings", vhost, exchange_name, queue_name), + ) - body = { - "routing_key": queue_name, - "arguments": {} - } + body = {"routing_key": queue_name, "arguments": {}} result = requests.post( - url, headers=C_COMMON_HTTP_HEADER, json=body, auth=(self.user, self.password)) + url, + headers=C_COMMON_HTTP_HEADER, + json=body, + auth=(self.user, self.password), + ) LOGGER.debug(result) return result def unbind_exchange_to_queue(self, vhost, exchange_name, queue_name): - url = C_HTTP_TEMPLATE.format(self.endpoint, 
"{}/{}/e/{}/q/{}/{}".format("bindings", - vhost, - exchange_name, - queue_name, - queue_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, + "{}/{}/e/{}/q/{}/{}".format( + "bindings", vhost, exchange_name, queue_name, queue_name + ), + ) result = requests.delete(url, auth=(self.user, self.password)) LOGGER.debug(result) return result def _set_federated_upstream(self, upstream_host, vhost, receive_queue_name): - url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}/{}/{}".format("parameters", - "federation-upstream", - vhost, - receive_queue_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, + "{}/{}/{}/{}".format( + "parameters", "federation-upstream", vhost, receive_queue_name + ), + ) upstream_runtime_config = self.runtime_config.get("upstream", {}) - upstream_runtime_config['uri'] = upstream_host - upstream_runtime_config['queue'] = receive_queue_name.replace( - "receive", "send", 1) + upstream_runtime_config["uri"] = upstream_host + upstream_runtime_config["queue"] = receive_queue_name.replace( + "receive", "send", 1 + ) - body = { - "value": upstream_runtime_config - } - LOGGER.debug(f"[rabbitmanager._set_federated_upstream]set_federated_upstream, url: {url} body: {body}") + body = {"value": upstream_runtime_config} + LOGGER.debug( + f"[rabbitmanager._set_federated_upstream]set_federated_upstream, url: {url} body: {body}" + ) - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=body, auth=(self.user, self.password)) + result = requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=body, + auth=(self.user, self.password), + ) LOGGER.debug(f"[rabbitmanager._set_federated_upstream] {result}") if result.status_code != 201 and result.status_code != 204: - LOGGER.debug(f"[rabbitmanager._set_federated_upstream] _set_federated_upstream fail. {result}") + LOGGER.debug( + f"[rabbitmanager._set_federated_upstream] _set_federated_upstream fail. 
{result}" + ) return False return True def _unset_federated_upstream(self, upstream_name, vhost): - url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}/{}/{}".format("parameters", - "federation-upstream", - vhost, - upstream_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, + "{}/{}/{}/{}".format( + "parameters", "federation-upstream", vhost, upstream_name + ), + ) result = requests.delete(url, auth=(self.user, self.password)) LOGGER.debug(result) return result def _set_federated_queue_policy(self, vhost, receive_queue_name): - url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}/{}".format("policies", - vhost, - receive_queue_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, "{}/{}/{}".format("policies", vhost, receive_queue_name) + ) body = { - "pattern": '^' + receive_queue_name + '$', + "pattern": "^" + receive_queue_name + "$", "apply-to": "queues", - "definition": - { - "federation-upstream": receive_queue_name - } + "definition": {"federation-upstream": receive_queue_name}, } - LOGGER.debug(f"[rabbitmanager._set_federated_queue_policy]set_federated_queue_policy, url: {url} body: {body}") + LOGGER.debug( + f"[rabbitmanager._set_federated_queue_policy]set_federated_queue_policy, url: {url} body: {body}" + ) - result = requests.put(url, headers=C_COMMON_HTTP_HEADER, - json=body, auth=(self.user, self.password)) + result = requests.put( + url, + headers=C_COMMON_HTTP_HEADER, + json=body, + auth=(self.user, self.password), + ) LOGGER.debug(f"[rabbitmanager._set_federated_queue_policy] {result}") if result.status_code != 201 and result.status_code != 204: - LOGGER.debug(f"[rabbitmanager._set_federated_queue_policy] _set_federated_queue_policy fail. {result}") + LOGGER.debug( + f"[rabbitmanager._set_federated_queue_policy] _set_federated_queue_policy fail. 
{result}" + ) return False return True def _unset_federated_queue_policy(self, policy_name, vhost): - url = C_HTTP_TEMPLATE.format(self.endpoint, "{}/{}/{}".format("policies", - vhost, - policy_name)) + url = C_HTTP_TEMPLATE.format( + self.endpoint, "{}/{}/{}".format("policies", vhost, policy_name) + ) result = requests.delete(url, auth=(self.user, self.password)) LOGGER.debug(result) @@ -353,18 +401,18 @@ def _unset_federated_queue_policy(self, policy_name, vhost): # Create federate queue with upstream def federate_queue(self, upstream_host, vhost, send_queue_name, receive_queue_name): time.sleep(0.1) - LOGGER.debug(f"[rabbitmanager.federate_queue] create federate_queue {receive_queue_name}") + LOGGER.debug( + f"[rabbitmanager.federate_queue] create federate_queue {receive_queue_name}" + ) - result = self._set_federated_upstream( - upstream_host, vhost, receive_queue_name) + result = self._set_federated_upstream(upstream_host, vhost, receive_queue_name) if result is False: # should be logged LOGGER.debug(f"[rabbitmanager.federate_queue] result_set_upstream fail.") return False - result = self._set_federated_queue_policy( - vhost, receive_queue_name) + result = self._set_federated_queue_policy(vhost, receive_queue_name) if result is False: LOGGER.debug(f"[rabbitmanager.federate_queue] result_set_policy fail.") @@ -374,12 +422,12 @@ def federate_queue(self, upstream_host, vhost, send_queue_name, receive_queue_na def de_federate_queue(self, vhost, receive_queue_name): result = self._unset_federated_queue_policy(receive_queue_name, vhost) - LOGGER.debug( - f"delete federate queue policy status code: {result.status_code}") + LOGGER.debug(f"delete federate queue policy status code: {result.status_code}") result = self._unset_federated_upstream(receive_queue_name, vhost) LOGGER.debug( - f"delete federate queue upstream status code: {result.status_code}") + f"delete federate queue upstream status code: {result.status_code}" + ) return True diff --git a/python/fate/arch/federation/standalone/__init__.py b/python/fate/arch/federation/standalone/__init__.py index a33500ff58..4982a8a2db 100644 --- a/python/fate/arch/federation/standalone/__init__.py +++ b/python/fate/arch/federation/standalone/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. 
# -from fate_arch.federation.standalone._federation import Federation +from ._federation import Federation -__all__ = ['Federation'] +__all__ = ["Federation"] diff --git a/python/fate/arch/federation/standalone/_federation.py b/python/fate/arch/federation/standalone/_federation.py index be33a1a272..815ba6b959 100644 --- a/python/fate/arch/federation/standalone/_federation.py +++ b/python/fate/arch/federation/standalone/_federation.py @@ -1,10 +1,10 @@ import typing -from fate_arch._standalone import Federation as RawFederation, Table as RawTable -from fate_arch.abc import FederationABC -from fate_arch.abc import GarbageCollectionABC -from fate_arch.common import Party, log -from fate_arch.computing.standalone import Table +from ..._standalone import Federation as RawFederation +from ..._standalone import Table as RawTable +from ...abc import FederationABC, GarbageCollectionABC +from ...common import Party, log +from ...computing.standalone import Table LOGGER = log.getLogger() diff --git a/python/fate/arch/federation/transfer_variable.py b/python/fate/arch/federation/transfer_variable.py index f8c528b594..4f4e66ed6e 100644 --- a/python/fate/arch/federation/transfer_variable.py +++ b/python/fate/arch/federation/transfer_variable.py @@ -18,9 +18,9 @@ import typing from typing import Union -from fate_arch.common import Party, profile -from fate_arch.common.log import getLogger -from fate_arch.federation._gc import IterationGC +from ..common import Party, profile +from ..common.log import getLogger +from ..federation._gc import IterationGC __all__ = ["Variable", "BaseTransferVariables"] @@ -126,7 +126,7 @@ def remote_parties( ------- None """ - from fate_arch.session import get_session + from ..session import get_session session = get_session() if isinstance(parties, Party): @@ -177,7 +177,7 @@ def get_parties( a list of objects/tables get from parties with same order of ``parties`` """ - from fate_arch.session import get_session + from ..session import get_session session = get_session() if not isinstance(parties, list): @@ -220,7 +220,7 @@ def remote(self, obj, role=None, idx=-1, suffix=tuple()): The default is -1, which means sent values to parties regardless their party id suffix: additional tag suffix, the default is tuple() """ - from fate_arch.session import get_parties + from ..session import get_parties party_info = get_parties() if idx >= 0 and role is None: @@ -254,7 +254,7 @@ def get(self, idx=-1, role=None, suffix=tuple()): Returns: object or list of object """ - from fate_arch.session import get_parties + from ..session import get_parties if role is None: src_parties = get_parties().roles_to_parties(roles=self._src, strict=False) @@ -326,7 +326,7 @@ def all_parties(): list of parties """ - from fate_arch.session import get_parties + from ..session import get_parties return get_parties().all_parties @@ -341,6 +341,6 @@ def local_party(): party this program running on """ - from fate_arch.session import get_parties + from ..session import get_parties return get_parties().local_party diff --git a/python/fate/arch/metastore/base_model.py b/python/fate/arch/metastore/base_model.py index a7f61494d1..740e00b935 100644 --- a/python/fate/arch/metastore/base_model.py +++ b/python/fate/arch/metastore/base_model.py @@ -17,21 +17,47 @@ import typing from enum import IntEnum -from peewee import Field, IntegerField, FloatField, BigIntegerField, TextField, Model, CompositeKey, Metadata - -from fate_arch.common import conf_utils, EngineType -from fate_arch.common.base_utils import 
current_timestamp, serialize_b64, deserialize_b64, timestamp_to_date, date_string_to_timestamp, json_dumps, json_loads -from fate_arch.federation import FederationEngine - -is_standalone = conf_utils.get_base_config("default_engines", {}).get( - EngineType.FEDERATION).upper() == FederationEngine.STANDALONE +from peewee import ( + BigIntegerField, + CompositeKey, + Field, + FloatField, + IntegerField, + Metadata, + Model, + TextField, +) + +from ..common import EngineType, conf_utils +from ..common.base_utils import ( + current_timestamp, + date_string_to_timestamp, + deserialize_b64, + json_dumps, + json_loads, + serialize_b64, + timestamp_to_date, +) +from ..federation import FederationEngine + +is_standalone = ( + conf_utils.get_base_config("default_engines", {}).get(EngineType.FEDERATION).upper() + == FederationEngine.STANDALONE +) if is_standalone: from playhouse.apsw_ext import DateTimeField else: from peewee import DateTimeField CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} -AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"} +AUTO_DATE_TIMESTAMP_FIELD_PREFIX = { + "create", + "start", + "end", + "update", + "read_access", + "write_access", +} class SerializedType(IntEnum): @@ -40,7 +66,7 @@ class SerializedType(IntEnum): class LongTextField(TextField): - field_type = 'LONGTEXT' + field_type = "LONGTEXT" class JSONField(LongTextField): @@ -59,7 +85,11 @@ def db_value(self, value): def python_value(self, value): if not value: return self.default_value - return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return json_loads( + value, + object_hook=self._object_hook, + object_pairs_hook=self._object_pairs_hook, + ) class ListField(JSONField): @@ -67,7 +97,13 @@ class ListField(JSONField): class SerializedField(LongTextField): - def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs): + def __init__( + self, + serialized_type=SerializedType.PICKLE, + object_hook=None, + object_pairs_hook=None, + **kwargs, + ): self._serialized_type = serialized_type self._object_hook = object_hook self._object_pairs_hook = object_pairs_hook @@ -81,7 +117,9 @@ def db_value(self, value): return None return json_dumps(value, with_type=True) else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") + raise ValueError( + f"the serialized type {self._serialized_type} is not supported" + ) def python_value(self, value): if self._serialized_type == SerializedType.PICKLE: @@ -89,9 +127,15 @@ def python_value(self, value): elif self._serialized_type == SerializedType.JSON: if value is None: return {} - return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) + return json_loads( + value, + object_hook=self._object_hook, + object_pairs_hook=self._object_pairs_hook, + ) else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") + raise ValueError( + f"the serialized type {self._serialized_type} is not supported" + ) def is_continuous_field(cls: typing.Type) -> bool: @@ -116,7 +160,7 @@ def auto_date_timestamp_db_field(): def remove_field_name_prefix(field_name): - return field_name[2:] if field_name.startswith('f_') else field_name + return field_name[2:] if field_name.startswith("f_") else field_name class BaseModel(Model): @@ -130,10 +174,10 @@ def to_json(self): return self.to_dict() def to_dict(self): - return self.__dict__['__data__'] + 
return self.__dict__["__data__"] def to_human_model_dict(self, only_primary_with: list = None): - model_dict = self.__dict__['__data__'] + model_dict = self.__dict__["__data__"] if not only_primary_with: return {remove_field_name_prefix(k): v for k, v in model_dict.items()} @@ -142,7 +186,7 @@ def to_human_model_dict(self, only_primary_with: list = None): for k in self._meta.primary_key.field_names: human_model_dict[remove_field_name_prefix(k)] = model_dict[k] for k in only_primary_with: - human_model_dict[k] = model_dict[f'f_{k}'] + human_model_dict[k] = model_dict[f"f_{k}"] return human_model_dict @property @@ -151,8 +195,11 @@ def meta(self) -> Metadata: @classmethod def get_primary_keys_name(cls): - return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [ - cls._meta.primary_key.name] + return ( + cls._meta.primary_key.field_names + if isinstance(cls._meta.primary_key, CompositeKey) + else [cls._meta.primary_key.name] + ) @classmethod def getter_by(cls, attr): @@ -162,7 +209,7 @@ def getter_by(cls, attr): def query(cls, reverse=None, order_by=None, **kwargs): filters = [] for f_n, f_v in kwargs.items(): - attr_name = 'f_%s' % f_n + attr_name = "f_%s" % f_n if not hasattr(cls, attr_name) or f_v is None: continue if type(f_v) in {list, set}: @@ -170,17 +217,26 @@ def query(cls, reverse=None, order_by=None, **kwargs): if is_continuous_field(type(getattr(cls, attr_name))): if len(f_v) == 2: for i, v in enumerate(f_v): - if isinstance(v, str) and f_n in auto_date_timestamp_field(): + if ( + isinstance(v, str) + and f_n in auto_date_timestamp_field() + ): # time type: %Y-%m-%d %H:%M:%S f_v[i] = date_string_to_timestamp(v) lt_value = f_v[0] gt_value = f_v[1] if lt_value is not None and gt_value is not None: - filters.append(cls.getter_by(attr_name).between(lt_value, gt_value)) + filters.append( + cls.getter_by(attr_name).between(lt_value, gt_value) + ) elif lt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) >= lt_value) + filters.append( + operator.attrgetter(attr_name)(cls) >= lt_value + ) elif gt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) <= gt_value) + filters.append( + operator.attrgetter(attr_name)(cls) <= gt_value + ) else: filters.append(operator.attrgetter(attr_name)(cls) << f_v) else: @@ -191,9 +247,13 @@ def query(cls, reverse=None, order_by=None, **kwargs): if not order_by or not hasattr(cls, f"f_{order_by}"): order_by = "create_time" if reverse is True: - query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").desc()) + query_records = query_records.order_by( + cls.getter_by(f"f_{order_by}").desc() + ) elif reverse is False: - query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").asc()) + query_records = query_records.order_by( + cls.getter_by(f"f_{order_by}").asc() + ) return [query_record for query_record in query_records] else: return [] @@ -217,10 +277,13 @@ def _normalize_data(cls, data, kwargs): normalized[cls._meta.combined["f_update_time"]] = current_timestamp() for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX: - if {f"f_{f_n}_time", f"f_{f_n}_date"}.issubset(cls._meta.combined.keys()) and \ - cls._meta.combined[f"f_{f_n}_time"] in normalized and \ - normalized[cls._meta.combined[f"f_{f_n}_time"]] is not None: + if ( + {f"f_{f_n}_time", f"f_{f_n}_date"}.issubset(cls._meta.combined.keys()) + and cls._meta.combined[f"f_{f_n}_time"] in normalized + and normalized[cls._meta.combined[f"f_{f_n}_time"]] is not None + ): 
normalized[cls._meta.combined[f"f_{f_n}_date"]] = timestamp_to_date( - normalized[cls._meta.combined[f"f_{f_n}_time"]]) + normalized[cls._meta.combined[f"f_{f_n}_time"]] + ) return normalized diff --git a/python/fate/arch/metastore/db_models.py b/python/fate/arch/metastore/db_models.py index d6dfab89b6..d9bf097406 100644 --- a/python/fate/arch/metastore/db_models.py +++ b/python/fate/arch/metastore/db_models.py @@ -17,20 +17,27 @@ import os import sys -from peewee import CharField, IntegerField, BigIntegerField, TextField, CompositeKey, BooleanField - -from fate_arch.federation import FederationEngine -from fate_arch.metastore.base_model import DateTimeField -from fate_arch.common import file_utils, log, EngineType, conf_utils -from fate_arch.common.conf_utils import decrypt_database_config -from fate_arch.metastore.base_model import JSONField, SerializedField, BaseModel - +from peewee import ( + BigIntegerField, + BooleanField, + CharField, + CompositeKey, + IntegerField, + TextField, +) + +from ..common import EngineType, conf_utils, file_utils, log +from ..common.conf_utils import decrypt_database_config +from ..federation import FederationEngine +from ..metastore.base_model import BaseModel, DateTimeField, JSONField, SerializedField LOGGER = log.getLogger() DATABASE = decrypt_database_config() -is_standalone = conf_utils.get_base_config("default_engines", {}).get(EngineType.FEDERATION).upper() == \ - FederationEngine.STANDALONE +is_standalone = ( + conf_utils.get_base_config("default_engines", {}).get(EngineType.FEDERATION).upper() + == FederationEngine.STANDALONE +) def singleton(cls, *args, **kw): @@ -52,9 +59,13 @@ def __init__(self): db_name = database_config.pop("name") if is_standalone and not bool(int(os.environ.get("FORCE_USE_MYSQL", 0))): from playhouse.apsw_ext import APSWDatabase - self.database_connection = APSWDatabase(file_utils.get_project_base_directory("fate_sqlite.db")) + + self.database_connection = APSWDatabase( + file_utils.get_project_base_directory("fate_sqlite.db") + ) else: from playhouse.pool import PooledMySQLDatabase + self.database_connection = PooledMySQLDatabase(db_name, **database_config) @@ -121,9 +132,9 @@ class StorageTableMetaModel(DataBaseModel): f_schema = SerializedField() f_count = BigIntegerField(null=True) f_part_of_data = SerializedField() - f_origin = CharField(max_length=50, default='') + f_origin = CharField(max_length=50, default="") f_disable = BooleanField(default=False) - f_description = TextField(default='') + f_description = TextField(default="") f_read_access_time = BigIntegerField(null=True) f_read_access_date = DateTimeField(null=True) @@ -132,7 +143,7 @@ class StorageTableMetaModel(DataBaseModel): class Meta: db_table = "t_storage_table_meta" - primary_key = CompositeKey('f_name', 'f_namespace') + primary_key = CompositeKey("f_name", "f_namespace") class SessionRecord(DataBaseModel): diff --git a/python/fate/arch/metastore/db_utils.py b/python/fate/arch/metastore/db_utils.py index b9e93f22c6..558aa4a716 100644 --- a/python/fate/arch/metastore/db_utils.py +++ b/python/fate/arch/metastore/db_utils.py @@ -1,7 +1,7 @@ import operator -from fate_arch.common.base_utils import current_timestamp -from fate_arch.metastore.db_models import DB, StorageConnectorModel +from ..common.base_utils import current_timestamp +from ..metastore.db_models import DB, StorageConnectorModel class StorageConnector: @@ -17,11 +17,10 @@ def create_or_update(self): "f_engine": self.engine, "f_connector_info": self.connector_info, "f_create_time": 
current_timestamp(), - } connector, status = StorageConnectorModel.get_or_create( - f_name=self.name, - defaults=defaults) + f_name=self.name, defaults=defaults + ) if status is False: for key in defaults: setattr(connector, key, defaults[key]) @@ -29,8 +28,12 @@ def create_or_update(self): @DB.connection_context() def get_info(self): - connectors = [connector for connector in StorageConnectorModel.select().where( - operator.attrgetter("f_name")(StorageConnectorModel) == self.name)] + connectors = [ + connector + for connector in StorageConnectorModel.select().where( + operator.attrgetter("f_name")(StorageConnectorModel) == self.name + ) + ] if connectors: return connectors[0].f_connector_info else: diff --git a/python/fate/arch/protobuf/__init__.py b/python/fate/arch/protobuf/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate/arch/protobuf/python/__init__.py b/python/fate/arch/protobuf/python/__init__.py index cf52e67516..7007f3d365 100644 --- a/python/fate/arch/protobuf/python/__init__.py +++ b/python/fate/arch/protobuf/python/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # -import sys import os +import sys sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) diff --git a/python/fate/arch/protobuf/python/basic_meta_pb2.py b/python/fate/arch/protobuf/python/basic_meta_pb2.py index 6c2d589cf1..7c909e35fd 100644 --- a/python/fate/arch/protobuf/python/basic_meta_pb2.py +++ b/python/fate/arch/protobuf/python/basic_meta_pb2.py @@ -7,141 +7,189 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x62\x61sic-meta.proto\x12\x1e\x63om.webank.ai.eggroll.api.core\"6\n\x08\x45ndpoint\x12\n\n\x02ip\x18\x01 \x01(\t\x12\x0c\n\x04port\x18\x02 \x01(\x05\x12\x10\n\x08hostname\x18\x03 \x01(\t\"H\n\tEndpoints\x12;\n\tendpoints\x18\x01 \x03(\x0b\x32(.com.webank.ai.eggroll.api.core.Endpoint\"H\n\x04\x44\x61ta\x12\x0e\n\x06isNull\x18\x01 \x01(\x08\x12\x14\n\x0chostLanguage\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"F\n\x0cRepeatedData\x12\x36\n\x08\x64\x61talist\x18\x01 \x03(\x0b\x32$.com.webank.ai.eggroll.api.core.Data\"u\n\x0b\x43\x61llRequest\x12\x0f\n\x07isAsync\x18\x01 \x01(\x08\x12\x0f\n\x07timeout\x18\x02 \x01(\x03\x12\x0f\n\x07\x63ommand\x18\x03 \x01(\t\x12\x33\n\x05param\x18\x04 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Data\"\x88\x01\n\x0c\x43\x61llResponse\x12\x42\n\x0creturnStatus\x18\x01 \x01(\x0b\x32,.com.webank.ai.eggroll.api.core.ReturnStatus\x12\x34\n\x06result\x18\x02 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Data\"\"\n\x03Job\x12\r\n\x05jobId\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\"Y\n\x04Task\x12\x30\n\x03job\x18\x01 \x01(\x0b\x32#.com.webank.ai.eggroll.api.core.Job\x12\x0e\n\x06taskId\x18\x02 \x01(\x03\x12\x0f\n\x07tableId\x18\x03 \x01(\x03\"N\n\x06Result\x12\x32\n\x04task\x18\x01 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Task\x12\x10\n\x08resultId\x18\x02 \x01(\x03\"-\n\x0cReturnStatus\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xe2\x01\n\x0bSessionInfo\x12\x11\n\tsessionId\x18\x01 \x01(\t\x12\x61\n\x13\x63omputingEngineConf\x18\x02 \x03(\x0b\x32\x44.com.webank.ai.eggroll.api.core.SessionInfo.ComputingEngineConfEntry\x12\x14\n\x0cnamingPolicy\x18\x03 
\x01(\t\x12\x0b\n\x03tag\x18\x04 \x01(\t\x1a:\n\x18\x43omputingEngineConfEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3') - - - -_ENDPOINT = DESCRIPTOR.message_types_by_name['Endpoint'] -_ENDPOINTS = DESCRIPTOR.message_types_by_name['Endpoints'] -_DATA = DESCRIPTOR.message_types_by_name['Data'] -_REPEATEDDATA = DESCRIPTOR.message_types_by_name['RepeatedData'] -_CALLREQUEST = DESCRIPTOR.message_types_by_name['CallRequest'] -_CALLRESPONSE = DESCRIPTOR.message_types_by_name['CallResponse'] -_JOB = DESCRIPTOR.message_types_by_name['Job'] -_TASK = DESCRIPTOR.message_types_by_name['Task'] -_RESULT = DESCRIPTOR.message_types_by_name['Result'] -_RETURNSTATUS = DESCRIPTOR.message_types_by_name['ReturnStatus'] -_SESSIONINFO = DESCRIPTOR.message_types_by_name['SessionInfo'] -_SESSIONINFO_COMPUTINGENGINECONFENTRY = _SESSIONINFO.nested_types_by_name['ComputingEngineConfEntry'] -Endpoint = _reflection.GeneratedProtocolMessageType('Endpoint', (_message.Message,), { - 'DESCRIPTOR' : _ENDPOINT, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Endpoint) - }) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x10\x62\x61sic-meta.proto\x12\x1e\x63om.webank.ai.eggroll.api.core"6\n\x08\x45ndpoint\x12\n\n\x02ip\x18\x01 \x01(\t\x12\x0c\n\x04port\x18\x02 \x01(\x05\x12\x10\n\x08hostname\x18\x03 \x01(\t"H\n\tEndpoints\x12;\n\tendpoints\x18\x01 \x03(\x0b\x32(.com.webank.ai.eggroll.api.core.Endpoint"H\n\x04\x44\x61ta\x12\x0e\n\x06isNull\x18\x01 \x01(\x08\x12\x14\n\x0chostLanguage\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c"F\n\x0cRepeatedData\x12\x36\n\x08\x64\x61talist\x18\x01 \x03(\x0b\x32$.com.webank.ai.eggroll.api.core.Data"u\n\x0b\x43\x61llRequest\x12\x0f\n\x07isAsync\x18\x01 \x01(\x08\x12\x0f\n\x07timeout\x18\x02 \x01(\x03\x12\x0f\n\x07\x63ommand\x18\x03 \x01(\t\x12\x33\n\x05param\x18\x04 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Data"\x88\x01\n\x0c\x43\x61llResponse\x12\x42\n\x0creturnStatus\x18\x01 \x01(\x0b\x32,.com.webank.ai.eggroll.api.core.ReturnStatus\x12\x34\n\x06result\x18\x02 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Data""\n\x03Job\x12\r\n\x05jobId\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t"Y\n\x04Task\x12\x30\n\x03job\x18\x01 \x01(\x0b\x32#.com.webank.ai.eggroll.api.core.Job\x12\x0e\n\x06taskId\x18\x02 \x01(\x03\x12\x0f\n\x07tableId\x18\x03 \x01(\x03"N\n\x06Result\x12\x32\n\x04task\x18\x01 \x01(\x0b\x32$.com.webank.ai.eggroll.api.core.Task\x12\x10\n\x08resultId\x18\x02 \x01(\x03"-\n\x0cReturnStatus\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t"\xe2\x01\n\x0bSessionInfo\x12\x11\n\tsessionId\x18\x01 \x01(\t\x12\x61\n\x13\x63omputingEngineConf\x18\x02 \x03(\x0b\x32\x44.com.webank.ai.eggroll.api.core.SessionInfo.ComputingEngineConfEntry\x12\x14\n\x0cnamingPolicy\x18\x03 \x01(\t\x12\x0b\n\x03tag\x18\x04 \x01(\t\x1a:\n\x18\x43omputingEngineConfEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3' +) + + +_ENDPOINT = DESCRIPTOR.message_types_by_name["Endpoint"] +_ENDPOINTS = DESCRIPTOR.message_types_by_name["Endpoints"] +_DATA = DESCRIPTOR.message_types_by_name["Data"] +_REPEATEDDATA = DESCRIPTOR.message_types_by_name["RepeatedData"] +_CALLREQUEST = DESCRIPTOR.message_types_by_name["CallRequest"] +_CALLRESPONSE = DESCRIPTOR.message_types_by_name["CallResponse"] +_JOB = DESCRIPTOR.message_types_by_name["Job"] +_TASK = 
DESCRIPTOR.message_types_by_name["Task"] +_RESULT = DESCRIPTOR.message_types_by_name["Result"] +_RETURNSTATUS = DESCRIPTOR.message_types_by_name["ReturnStatus"] +_SESSIONINFO = DESCRIPTOR.message_types_by_name["SessionInfo"] +_SESSIONINFO_COMPUTINGENGINECONFENTRY = _SESSIONINFO.nested_types_by_name[ + "ComputingEngineConfEntry" +] +Endpoint = _reflection.GeneratedProtocolMessageType( + "Endpoint", + (_message.Message,), + { + "DESCRIPTOR": _ENDPOINT, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Endpoint) + }, +) _sym_db.RegisterMessage(Endpoint) -Endpoints = _reflection.GeneratedProtocolMessageType('Endpoints', (_message.Message,), { - 'DESCRIPTOR' : _ENDPOINTS, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Endpoints) - }) +Endpoints = _reflection.GeneratedProtocolMessageType( + "Endpoints", + (_message.Message,), + { + "DESCRIPTOR": _ENDPOINTS, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Endpoints) + }, +) _sym_db.RegisterMessage(Endpoints) -Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), { - 'DESCRIPTOR' : _DATA, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Data) - }) +Data = _reflection.GeneratedProtocolMessageType( + "Data", + (_message.Message,), + { + "DESCRIPTOR": _DATA, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Data) + }, +) _sym_db.RegisterMessage(Data) -RepeatedData = _reflection.GeneratedProtocolMessageType('RepeatedData', (_message.Message,), { - 'DESCRIPTOR' : _REPEATEDDATA, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.RepeatedData) - }) +RepeatedData = _reflection.GeneratedProtocolMessageType( + "RepeatedData", + (_message.Message,), + { + "DESCRIPTOR": _REPEATEDDATA, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.RepeatedData) + }, +) _sym_db.RegisterMessage(RepeatedData) -CallRequest = _reflection.GeneratedProtocolMessageType('CallRequest', (_message.Message,), { - 'DESCRIPTOR' : _CALLREQUEST, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.CallRequest) - }) +CallRequest = _reflection.GeneratedProtocolMessageType( + "CallRequest", + (_message.Message,), + { + "DESCRIPTOR": _CALLREQUEST, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.CallRequest) + }, +) _sym_db.RegisterMessage(CallRequest) -CallResponse = _reflection.GeneratedProtocolMessageType('CallResponse', (_message.Message,), { - 'DESCRIPTOR' : _CALLRESPONSE, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.CallResponse) - }) +CallResponse = _reflection.GeneratedProtocolMessageType( + "CallResponse", + (_message.Message,), + { + "DESCRIPTOR": _CALLRESPONSE, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.CallResponse) + }, +) _sym_db.RegisterMessage(CallResponse) -Job = _reflection.GeneratedProtocolMessageType('Job', (_message.Message,), { - 'DESCRIPTOR' : _JOB, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Job) - }) +Job = _reflection.GeneratedProtocolMessageType( + "Job", + 
(_message.Message,), + { + "DESCRIPTOR": _JOB, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Job) + }, +) _sym_db.RegisterMessage(Job) -Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { - 'DESCRIPTOR' : _TASK, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Task) - }) +Task = _reflection.GeneratedProtocolMessageType( + "Task", + (_message.Message,), + { + "DESCRIPTOR": _TASK, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Task) + }, +) _sym_db.RegisterMessage(Task) -Result = _reflection.GeneratedProtocolMessageType('Result', (_message.Message,), { - 'DESCRIPTOR' : _RESULT, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Result) - }) +Result = _reflection.GeneratedProtocolMessageType( + "Result", + (_message.Message,), + { + "DESCRIPTOR": _RESULT, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.Result) + }, +) _sym_db.RegisterMessage(Result) -ReturnStatus = _reflection.GeneratedProtocolMessageType('ReturnStatus', (_message.Message,), { - 'DESCRIPTOR' : _RETURNSTATUS, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.ReturnStatus) - }) +ReturnStatus = _reflection.GeneratedProtocolMessageType( + "ReturnStatus", + (_message.Message,), + { + "DESCRIPTOR": _RETURNSTATUS, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.ReturnStatus) + }, +) _sym_db.RegisterMessage(ReturnStatus) -SessionInfo = _reflection.GeneratedProtocolMessageType('SessionInfo', (_message.Message,), { - - 'ComputingEngineConfEntry' : _reflection.GeneratedProtocolMessageType('ComputingEngineConfEntry', (_message.Message,), { - 'DESCRIPTOR' : _SESSIONINFO_COMPUTINGENGINECONFENTRY, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.SessionInfo.ComputingEngineConfEntry) - }) - , - 'DESCRIPTOR' : _SESSIONINFO, - '__module__' : 'basic_meta_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.SessionInfo) - }) +SessionInfo = _reflection.GeneratedProtocolMessageType( + "SessionInfo", + (_message.Message,), + { + "ComputingEngineConfEntry": _reflection.GeneratedProtocolMessageType( + "ComputingEngineConfEntry", + (_message.Message,), + { + "DESCRIPTOR": _SESSIONINFO_COMPUTINGENGINECONFENTRY, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.SessionInfo.ComputingEngineConfEntry) + }, + ), + "DESCRIPTOR": _SESSIONINFO, + "__module__": "basic_meta_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.core.SessionInfo) + }, +) _sym_db.RegisterMessage(SessionInfo) _sym_db.RegisterMessage(SessionInfo.ComputingEngineConfEntry) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _SESSIONINFO_COMPUTINGENGINECONFENTRY._options = None - _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_options = b'8\001' - _ENDPOINT._serialized_start=52 - _ENDPOINT._serialized_end=106 - _ENDPOINTS._serialized_start=108 - _ENDPOINTS._serialized_end=180 - _DATA._serialized_start=182 - _DATA._serialized_end=254 - _REPEATEDDATA._serialized_start=256 - _REPEATEDDATA._serialized_end=326 - _CALLREQUEST._serialized_start=328 - _CALLREQUEST._serialized_end=445 - 
_CALLRESPONSE._serialized_start=448 - _CALLRESPONSE._serialized_end=584 - _JOB._serialized_start=586 - _JOB._serialized_end=620 - _TASK._serialized_start=622 - _TASK._serialized_end=711 - _RESULT._serialized_start=713 - _RESULT._serialized_end=791 - _RETURNSTATUS._serialized_start=793 - _RETURNSTATUS._serialized_end=838 - _SESSIONINFO._serialized_start=841 - _SESSIONINFO._serialized_end=1067 - _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_start=1009 - _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_end=1067 + DESCRIPTOR._options = None + _SESSIONINFO_COMPUTINGENGINECONFENTRY._options = None + _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_options = b"8\001" + _ENDPOINT._serialized_start = 52 + _ENDPOINT._serialized_end = 106 + _ENDPOINTS._serialized_start = 108 + _ENDPOINTS._serialized_end = 180 + _DATA._serialized_start = 182 + _DATA._serialized_end = 254 + _REPEATEDDATA._serialized_start = 256 + _REPEATEDDATA._serialized_end = 326 + _CALLREQUEST._serialized_start = 328 + _CALLREQUEST._serialized_end = 445 + _CALLRESPONSE._serialized_start = 448 + _CALLRESPONSE._serialized_end = 584 + _JOB._serialized_start = 586 + _JOB._serialized_end = 620 + _TASK._serialized_start = 622 + _TASK._serialized_end = 711 + _RESULT._serialized_start = 713 + _RESULT._serialized_end = 791 + _RETURNSTATUS._serialized_start = 793 + _RETURNSTATUS._serialized_end = 838 + _SESSIONINFO._serialized_start = 841 + _SESSIONINFO._serialized_end = 1067 + _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_start = 1009 + _SESSIONINFO_COMPUTINGENGINECONFENTRY._serialized_end = 1067 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/default_empty_fill_pb2.py b/python/fate/arch/protobuf/python/default_empty_fill_pb2.py index 11cd69fa94..48dd9930d6 100644 --- a/python/fate/arch/protobuf/python/default_empty_fill_pb2.py +++ b/python/fate/arch/protobuf/python/default_empty_fill_pb2.py @@ -7,29 +7,33 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b"\n\x18\x64\x65\x66\x61ult-empty-fill.proto\x12&com.webank.ai.fate.core.mlmodel.buffer\"'\n\x17\x44\x65\x66\x61ultEmptyFillMessage\x12\x0c\n\x04\x66lag\x18\x01 \x01(\tB\x17\x42\x15\x44\x65\x66\x61ultEmptyFillProtob\x06proto3" +) -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x64\x65\x66\x61ult-empty-fill.proto\x12&com.webank.ai.fate.core.mlmodel.buffer\"\'\n\x17\x44\x65\x66\x61ultEmptyFillMessage\x12\x0c\n\x04\x66lag\x18\x01 \x01(\tB\x17\x42\x15\x44\x65\x66\x61ultEmptyFillProtob\x06proto3') - - - -_DEFAULTEMPTYFILLMESSAGE = DESCRIPTOR.message_types_by_name['DefaultEmptyFillMessage'] -DefaultEmptyFillMessage = _reflection.GeneratedProtocolMessageType('DefaultEmptyFillMessage', (_message.Message,), { - 'DESCRIPTOR' : _DEFAULTEMPTYFILLMESSAGE, - '__module__' : 'default_empty_fill_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.core.mlmodel.buffer.DefaultEmptyFillMessage) - }) +_DEFAULTEMPTYFILLMESSAGE = DESCRIPTOR.message_types_by_name["DefaultEmptyFillMessage"] +DefaultEmptyFillMessage = _reflection.GeneratedProtocolMessageType( + "DefaultEmptyFillMessage", + (_message.Message,), + { + "DESCRIPTOR": _DEFAULTEMPTYFILLMESSAGE, + "__module__": "default_empty_fill_pb2" + # 
@@protoc_insertion_point(class_scope:com.webank.ai.fate.core.mlmodel.buffer.DefaultEmptyFillMessage) + }, +) _sym_db.RegisterMessage(DefaultEmptyFillMessage) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'B\025DefaultEmptyFillProto' - _DEFAULTEMPTYFILLMESSAGE._serialized_start=68 - _DEFAULTEMPTYFILLMESSAGE._serialized_end=107 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"B\025DefaultEmptyFillProto" + _DEFAULTEMPTYFILLMESSAGE._serialized_start = 68 + _DEFAULTEMPTYFILLMESSAGE._serialized_end = 107 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/fate_data_structure_pb2.py b/python/fate/arch/protobuf/python/fate_data_structure_pb2.py index 4ef3f593d0..53f9cde0d8 100644 --- a/python/fate/arch/protobuf/python/fate_data_structure_pb2.py +++ b/python/fate/arch/protobuf/python/fate_data_structure_pb2.py @@ -7,62 +7,76 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x19\x66\x61te-data-structure.proto\x12\x1b\x63om.webank.ai.fate.api.core"&\n\x08RawEntry\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c"@\n\x06RawMap\x12\x36\n\x07\x65ntries\x18\x01 \x03(\x0b\x32%.com.webank.ai.fate.api.core.RawEntry"n\n\x04\x44ict\x12\x39\n\x04\x64ict\x18\x01 \x03(\x0b\x32+.com.webank.ai.fate.api.core.Dict.DictEntry\x1a+\n\tDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x42\x0f\x42\rDataStructureb\x06proto3' +) -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19\x66\x61te-data-structure.proto\x12\x1b\x63om.webank.ai.fate.api.core\"&\n\x08RawEntry\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\"@\n\x06RawMap\x12\x36\n\x07\x65ntries\x18\x01 \x03(\x0b\x32%.com.webank.ai.fate.api.core.RawEntry\"n\n\x04\x44ict\x12\x39\n\x04\x64ict\x18\x01 \x03(\x0b\x32+.com.webank.ai.fate.api.core.Dict.DictEntry\x1a+\n\tDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x42\x0f\x42\rDataStructureb\x06proto3') - - - -_RAWENTRY = DESCRIPTOR.message_types_by_name['RawEntry'] -_RAWMAP = DESCRIPTOR.message_types_by_name['RawMap'] -_DICT = DESCRIPTOR.message_types_by_name['Dict'] -_DICT_DICTENTRY = _DICT.nested_types_by_name['DictEntry'] -RawEntry = _reflection.GeneratedProtocolMessageType('RawEntry', (_message.Message,), { - 'DESCRIPTOR' : _RAWENTRY, - '__module__' : 'fate_data_structure_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.RawEntry) - }) +_RAWENTRY = DESCRIPTOR.message_types_by_name["RawEntry"] +_RAWMAP = DESCRIPTOR.message_types_by_name["RawMap"] +_DICT = DESCRIPTOR.message_types_by_name["Dict"] +_DICT_DICTENTRY = _DICT.nested_types_by_name["DictEntry"] +RawEntry = _reflection.GeneratedProtocolMessageType( + "RawEntry", + (_message.Message,), + { + "DESCRIPTOR": _RAWENTRY, + "__module__": "fate_data_structure_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.RawEntry) + }, +) _sym_db.RegisterMessage(RawEntry) -RawMap = _reflection.GeneratedProtocolMessageType('RawMap', (_message.Message,), { - 'DESCRIPTOR' : _RAWMAP, - '__module__' : 'fate_data_structure_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.RawMap) - }) +RawMap = 
_reflection.GeneratedProtocolMessageType( + "RawMap", + (_message.Message,), + { + "DESCRIPTOR": _RAWMAP, + "__module__": "fate_data_structure_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.RawMap) + }, +) _sym_db.RegisterMessage(RawMap) -Dict = _reflection.GeneratedProtocolMessageType('Dict', (_message.Message,), { - - 'DictEntry' : _reflection.GeneratedProtocolMessageType('DictEntry', (_message.Message,), { - 'DESCRIPTOR' : _DICT_DICTENTRY, - '__module__' : 'fate_data_structure_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.Dict.DictEntry) - }) - , - 'DESCRIPTOR' : _DICT, - '__module__' : 'fate_data_structure_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.Dict) - }) +Dict = _reflection.GeneratedProtocolMessageType( + "Dict", + (_message.Message,), + { + "DictEntry": _reflection.GeneratedProtocolMessageType( + "DictEntry", + (_message.Message,), + { + "DESCRIPTOR": _DICT_DICTENTRY, + "__module__": "fate_data_structure_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.Dict.DictEntry) + }, + ), + "DESCRIPTOR": _DICT, + "__module__": "fate_data_structure_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.core.Dict) + }, +) _sym_db.RegisterMessage(Dict) _sym_db.RegisterMessage(Dict.DictEntry) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'B\rDataStructure' - _DICT_DICTENTRY._options = None - _DICT_DICTENTRY._serialized_options = b'8\001' - _RAWENTRY._serialized_start=58 - _RAWENTRY._serialized_end=96 - _RAWMAP._serialized_start=98 - _RAWMAP._serialized_end=162 - _DICT._serialized_start=164 - _DICT._serialized_end=274 - _DICT_DICTENTRY._serialized_start=231 - _DICT_DICTENTRY._serialized_end=274 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"B\rDataStructure" + _DICT_DICTENTRY._options = None + _DICT_DICTENTRY._serialized_options = b"8\001" + _RAWENTRY._serialized_start = 58 + _RAWENTRY._serialized_end = 96 + _RAWMAP._serialized_start = 98 + _RAWMAP._serialized_end = 162 + _DICT._serialized_start = 164 + _DICT._serialized_end = 274 + _DICT_DICTENTRY._serialized_start = 231 + _DICT_DICTENTRY._serialized_end = 274 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/inference_service_pb2.py b/python/fate/arch/protobuf/python/inference_service_pb2.py index ca377a00ea..478b07ec89 100644 --- a/python/fate/arch/protobuf/python/inference_service_pb2.py +++ b/python/fate/arch/protobuf/python/inference_service_pb2.py @@ -7,32 +7,36 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x17inference_service.proto\x12\x1e\x63om.webank.ai.fate.api.serving"0\n\x10InferenceMessage\x12\x0e\n\x06header\x18\x01 \x01(\x0c\x12\x0c\n\x04\x62ody\x18\x02 
\x01(\x0c\x32\xf6\x02\n\x10InferenceService\x12o\n\tinference\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessage\x12w\n\x11startInferenceJob\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessage\x12x\n\x12getInferenceResult\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessageB\x17\x42\x15InferenceServiceProtob\x06proto3' +) -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17inference_service.proto\x12\x1e\x63om.webank.ai.fate.api.serving\"0\n\x10InferenceMessage\x12\x0e\n\x06header\x18\x01 \x01(\x0c\x12\x0c\n\x04\x62ody\x18\x02 \x01(\x0c\x32\xf6\x02\n\x10InferenceService\x12o\n\tinference\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessage\x12w\n\x11startInferenceJob\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessage\x12x\n\x12getInferenceResult\x12\x30.com.webank.ai.fate.api.serving.InferenceMessage\x1a\x30.com.webank.ai.fate.api.serving.InferenceMessageB\x17\x42\x15InferenceServiceProtob\x06proto3') - - - -_INFERENCEMESSAGE = DESCRIPTOR.message_types_by_name['InferenceMessage'] -InferenceMessage = _reflection.GeneratedProtocolMessageType('InferenceMessage', (_message.Message,), { - 'DESCRIPTOR' : _INFERENCEMESSAGE, - '__module__' : 'inference_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.serving.InferenceMessage) - }) +_INFERENCEMESSAGE = DESCRIPTOR.message_types_by_name["InferenceMessage"] +InferenceMessage = _reflection.GeneratedProtocolMessageType( + "InferenceMessage", + (_message.Message,), + { + "DESCRIPTOR": _INFERENCEMESSAGE, + "__module__": "inference_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.serving.InferenceMessage) + }, +) _sym_db.RegisterMessage(InferenceMessage) -_INFERENCESERVICE = DESCRIPTOR.services_by_name['InferenceService'] +_INFERENCESERVICE = DESCRIPTOR.services_by_name["InferenceService"] if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'B\025InferenceServiceProto' - _INFERENCEMESSAGE._serialized_start=59 - _INFERENCEMESSAGE._serialized_end=107 - _INFERENCESERVICE._serialized_start=110 - _INFERENCESERVICE._serialized_end=484 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"B\025InferenceServiceProto" + _INFERENCEMESSAGE._serialized_start = 59 + _INFERENCEMESSAGE._serialized_end = 107 + _INFERENCESERVICE._serialized_start = 110 + _INFERENCESERVICE._serialized_end = 484 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/inference_service_pb2_grpc.py b/python/fate/arch/protobuf/python/inference_service_pb2_grpc.py index e982e15d7c..61f89988ed 100644 --- a/python/fate/arch/protobuf/python/inference_service_pb2_grpc.py +++ b/python/fate/arch/protobuf/python/inference_service_pb2_grpc.py @@ -1,7 +1,6 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc - import inference_service_pb2 as inference__service__pb2 @@ -15,20 +14,20 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.inference = channel.unary_unary( - '/com.webank.ai.fate.api.serving.InferenceService/inference', - request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - response_deserializer=inference__service__pb2.InferenceMessage.FromString, - ) + "/com.webank.ai.fate.api.serving.InferenceService/inference", + request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + response_deserializer=inference__service__pb2.InferenceMessage.FromString, + ) self.startInferenceJob = channel.unary_unary( - '/com.webank.ai.fate.api.serving.InferenceService/startInferenceJob', - request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - response_deserializer=inference__service__pb2.InferenceMessage.FromString, - ) + "/com.webank.ai.fate.api.serving.InferenceService/startInferenceJob", + request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + response_deserializer=inference__service__pb2.InferenceMessage.FromString, + ) self.getInferenceResult = channel.unary_unary( - '/com.webank.ai.fate.api.serving.InferenceService/getInferenceResult', - request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - response_deserializer=inference__service__pb2.InferenceMessage.FromString, - ) + "/com.webank.ai.fate.api.serving.InferenceService/getInferenceResult", + request_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + response_deserializer=inference__service__pb2.InferenceMessage.FromString, + ) class InferenceServiceServicer(object): @@ -37,96 +36,133 @@ class InferenceServiceServicer(object): def inference(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def startInferenceJob(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def getInferenceResult(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_InferenceServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'inference': grpc.unary_unary_rpc_method_handler( - servicer.inference, - request_deserializer=inference__service__pb2.InferenceMessage.FromString, - response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - ), - 'startInferenceJob': grpc.unary_unary_rpc_method_handler( - servicer.startInferenceJob, - request_deserializer=inference__service__pb2.InferenceMessage.FromString, - response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - ), - 'getInferenceResult': grpc.unary_unary_rpc_method_handler( - servicer.getInferenceResult, - request_deserializer=inference__service__pb2.InferenceMessage.FromString, - response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, - ), + 
"inference": grpc.unary_unary_rpc_method_handler( + servicer.inference, + request_deserializer=inference__service__pb2.InferenceMessage.FromString, + response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + ), + "startInferenceJob": grpc.unary_unary_rpc_method_handler( + servicer.startInferenceJob, + request_deserializer=inference__service__pb2.InferenceMessage.FromString, + response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + ), + "getInferenceResult": grpc.unary_unary_rpc_method_handler( + servicer.getInferenceResult, + request_deserializer=inference__service__pb2.InferenceMessage.FromString, + response_serializer=inference__service__pb2.InferenceMessage.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'com.webank.ai.fate.api.serving.InferenceService', rpc_method_handlers) + "com.webank.ai.fate.api.serving.InferenceService", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class InferenceService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def inference(request, + def inference( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.serving.InferenceService/inference', + "/com.webank.ai.fate.api.serving.InferenceService/inference", inference__service__pb2.InferenceMessage.SerializeToString, inference__service__pb2.InferenceMessage.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def startInferenceJob(request, + def startInferenceJob( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.serving.InferenceService/startInferenceJob', + "/com.webank.ai.fate.api.serving.InferenceService/startInferenceJob", inference__service__pb2.InferenceMessage.SerializeToString, inference__service__pb2.InferenceMessage.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def getInferenceResult(request, + def getInferenceResult( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - 
channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.serving.InferenceService/getInferenceResult', + "/com.webank.ai.fate.api.serving.InferenceService/getInferenceResult", inference__service__pb2.InferenceMessage.SerializeToString, inference__service__pb2.InferenceMessage.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/fate/arch/protobuf/python/model_service_pb2.py b/python/fate/arch/protobuf/python/model_service_pb2.py index f19e82c8f5..dbcda1ee15 100644 --- a/python/fate/arch/protobuf/python/model_service_pb2.py +++ b/python/fate/arch/protobuf/python/model_service_pb2.py @@ -7,201 +7,265 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13model_service.proto\x12&com.webank.ai.fate.api.mlmodel.manager"\x18\n\x05Party\x12\x0f\n\x07partyId\x18\x01 \x03(\t"*\n\tLocalInfo\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07partyId\x18\x02 \x01(\t"1\n\tModelInfo\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t"\xd9\x01\n\rRoleModelInfo\x12_\n\rroleModelInfo\x18\x01 \x03(\x0b\x32H.com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo.RoleModelInfoEntry\x1ag\n\x12RoleModelInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12@\n\x05value\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.fate.api.mlmodel.manager.ModelInfo:\x02\x38\x01"5\n\rUnloadRequest\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t"5\n\x0eUnloadResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t"H\n\rUnbindRequest\x12\x11\n\tserviceId\x18\x01 \x01(\t\x12\x11\n\ttableName\x18\x02 \x01(\t\x12\x11\n\tnamespace\x18\x03 \x01(\t"5\n\x0eUnbindResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t"\x85\x01\n\x11QueryModelRequest\x12\x11\n\tserviceId\x18\x01 \x01(\t\x12\x11\n\ttableName\x18\x02 \x01(\t\x12\x11\n\tnamespace\x18\x03 \x01(\t\x12\x12\n\nbeginIndex\x18\x04 \x01(\x05\x12\x10\n\x08\x65ndIndex\x18\x05 \x01(\x05\x12\x11\n\tqueryType\x18\x06 \x01(\x05"\x0f\n\rModelBindInfo"f\n\x0bModelInfoEx\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\x12\x11\n\tserviceId\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\r\n\x05index\x18\x05 \x01(\x05"\x7f\n\x12QueryModelResponse\x12\x0f\n\x07retcode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\x12G\n\nmodelInfos\x18\x03 \x03(\x0b\x32\x33.com.webank.ai.fate.api.mlmodel.manager.ModelInfoEx"\x92\x04\n\x0ePublishRequest\x12@\n\x05local\x18\x01 \x01(\x0b\x32\x31.com.webank.ai.fate.api.mlmodel.manager.LocalInfo\x12N\n\x04role\x18\x02 \x03(\x0b\x32@.com.webank.ai.fate.api.mlmodel.manager.PublishRequest.RoleEntry\x12P\n\x05model\x18\x03 \x03(\x0b\x32\x41.com.webank.ai.fate.api.mlmodel.manager.PublishRequest.ModelEntry\x12\x11\n\tserviceId\x18\x04 \x01(\t\x12\x11\n\ttableName\x18\x05 \x01(\t\x12\x11\n\tnamespace\x18\x06 \x01(\t\x12\x10\n\x08loadType\x18\x07 \x01(\t\x12\x10\n\x08\x66ilePath\x18\x08 
\x01(\t\x1aZ\n\tRoleEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12<\n\x05value\x18\x02 \x01(\x0b\x32-.com.webank.ai.fate.api.mlmodel.manager.Party:\x02\x38\x01\x1a\x63\n\nModelEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x44\n\x05value\x18\x02 \x01(\x0b\x32\x35.com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo:\x02\x38\x01"S\n\x0fPublishResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x32\x89\x06\n\x0cModelService\x12~\n\x0bpublishLoad\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12~\n\x0bpublishBind\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12\x80\x01\n\rpublishOnline\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12\x83\x01\n\nqueryModel\x12\x39.com.webank.ai.fate.api.mlmodel.manager.QueryModelRequest\x1a:.com.webank.ai.fate.api.mlmodel.manager.QueryModelResponse\x12w\n\x06unload\x12\x35.com.webank.ai.fate.api.mlmodel.manager.UnloadRequest\x1a\x36.com.webank.ai.fate.api.mlmodel.manager.UnloadResponse\x12w\n\x06unbind\x12\x35.com.webank.ai.fate.api.mlmodel.manager.UnbindRequest\x1a\x36.com.webank.ai.fate.api.mlmodel.manager.UnbindResponseB\x13\x42\x11ModelServiceProtob\x06proto3' +) -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13model_service.proto\x12&com.webank.ai.fate.api.mlmodel.manager\"\x18\n\x05Party\x12\x0f\n\x07partyId\x18\x01 \x03(\t\"*\n\tLocalInfo\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0f\n\x07partyId\x18\x02 \x01(\t\"1\n\tModelInfo\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xd9\x01\n\rRoleModelInfo\x12_\n\rroleModelInfo\x18\x01 \x03(\x0b\x32H.com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo.RoleModelInfoEntry\x1ag\n\x12RoleModelInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12@\n\x05value\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.fate.api.mlmodel.manager.ModelInfo:\x02\x38\x01\"5\n\rUnloadRequest\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"5\n\x0eUnloadResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\"H\n\rUnbindRequest\x12\x11\n\tserviceId\x18\x01 \x01(\t\x12\x11\n\ttableName\x18\x02 \x01(\t\x12\x11\n\tnamespace\x18\x03 \x01(\t\"5\n\x0eUnbindResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x85\x01\n\x11QueryModelRequest\x12\x11\n\tserviceId\x18\x01 \x01(\t\x12\x11\n\ttableName\x18\x02 \x01(\t\x12\x11\n\tnamespace\x18\x03 \x01(\t\x12\x12\n\nbeginIndex\x18\x04 \x01(\x05\x12\x10\n\x08\x65ndIndex\x18\x05 \x01(\x05\x12\x11\n\tqueryType\x18\x06 \x01(\x05\"\x0f\n\rModelBindInfo\"f\n\x0bModelInfoEx\x12\x11\n\ttableName\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\x12\x11\n\tserviceId\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\r\n\x05index\x18\x05 \x01(\x05\"\x7f\n\x12QueryModelResponse\x12\x0f\n\x07retcode\x18\x01 \x01(\t\x12\x0f\n\x07message\x18\x02 \x01(\t\x12G\n\nmodelInfos\x18\x03 \x03(\x0b\x32\x33.com.webank.ai.fate.api.mlmodel.manager.ModelInfoEx\"\x92\x04\n\x0ePublishRequest\x12@\n\x05local\x18\x01 \x01(\x0b\x32\x31.com.webank.ai.fate.api.mlmodel.manager.LocalInfo\x12N\n\x04role\x18\x02 \x03(\x0b\x32@.com.webank.ai.fate.api.mlmodel.manager.PublishRequest.RoleEntry\x12P\n\x05model\x18\x03 
\x03(\x0b\x32\x41.com.webank.ai.fate.api.mlmodel.manager.PublishRequest.ModelEntry\x12\x11\n\tserviceId\x18\x04 \x01(\t\x12\x11\n\ttableName\x18\x05 \x01(\t\x12\x11\n\tnamespace\x18\x06 \x01(\t\x12\x10\n\x08loadType\x18\x07 \x01(\t\x12\x10\n\x08\x66ilePath\x18\x08 \x01(\t\x1aZ\n\tRoleEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12<\n\x05value\x18\x02 \x01(\x0b\x32-.com.webank.ai.fate.api.mlmodel.manager.Party:\x02\x38\x01\x1a\x63\n\nModelEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x44\n\x05value\x18\x02 \x01(\x0b\x32\x35.com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo:\x02\x38\x01\"S\n\x0fPublishResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x32\x89\x06\n\x0cModelService\x12~\n\x0bpublishLoad\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12~\n\x0bpublishBind\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12\x80\x01\n\rpublishOnline\x12\x36.com.webank.ai.fate.api.mlmodel.manager.PublishRequest\x1a\x37.com.webank.ai.fate.api.mlmodel.manager.PublishResponse\x12\x83\x01\n\nqueryModel\x12\x39.com.webank.ai.fate.api.mlmodel.manager.QueryModelRequest\x1a:.com.webank.ai.fate.api.mlmodel.manager.QueryModelResponse\x12w\n\x06unload\x12\x35.com.webank.ai.fate.api.mlmodel.manager.UnloadRequest\x1a\x36.com.webank.ai.fate.api.mlmodel.manager.UnloadResponse\x12w\n\x06unbind\x12\x35.com.webank.ai.fate.api.mlmodel.manager.UnbindRequest\x1a\x36.com.webank.ai.fate.api.mlmodel.manager.UnbindResponseB\x13\x42\x11ModelServiceProtob\x06proto3') - - - -_PARTY = DESCRIPTOR.message_types_by_name['Party'] -_LOCALINFO = DESCRIPTOR.message_types_by_name['LocalInfo'] -_MODELINFO = DESCRIPTOR.message_types_by_name['ModelInfo'] -_ROLEMODELINFO = DESCRIPTOR.message_types_by_name['RoleModelInfo'] -_ROLEMODELINFO_ROLEMODELINFOENTRY = _ROLEMODELINFO.nested_types_by_name['RoleModelInfoEntry'] -_UNLOADREQUEST = DESCRIPTOR.message_types_by_name['UnloadRequest'] -_UNLOADRESPONSE = DESCRIPTOR.message_types_by_name['UnloadResponse'] -_UNBINDREQUEST = DESCRIPTOR.message_types_by_name['UnbindRequest'] -_UNBINDRESPONSE = DESCRIPTOR.message_types_by_name['UnbindResponse'] -_QUERYMODELREQUEST = DESCRIPTOR.message_types_by_name['QueryModelRequest'] -_MODELBINDINFO = DESCRIPTOR.message_types_by_name['ModelBindInfo'] -_MODELINFOEX = DESCRIPTOR.message_types_by_name['ModelInfoEx'] -_QUERYMODELRESPONSE = DESCRIPTOR.message_types_by_name['QueryModelResponse'] -_PUBLISHREQUEST = DESCRIPTOR.message_types_by_name['PublishRequest'] -_PUBLISHREQUEST_ROLEENTRY = _PUBLISHREQUEST.nested_types_by_name['RoleEntry'] -_PUBLISHREQUEST_MODELENTRY = _PUBLISHREQUEST.nested_types_by_name['ModelEntry'] -_PUBLISHRESPONSE = DESCRIPTOR.message_types_by_name['PublishResponse'] -Party = _reflection.GeneratedProtocolMessageType('Party', (_message.Message,), { - 'DESCRIPTOR' : _PARTY, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.Party) - }) +_PARTY = DESCRIPTOR.message_types_by_name["Party"] +_LOCALINFO = DESCRIPTOR.message_types_by_name["LocalInfo"] +_MODELINFO = DESCRIPTOR.message_types_by_name["ModelInfo"] +_ROLEMODELINFO = DESCRIPTOR.message_types_by_name["RoleModelInfo"] +_ROLEMODELINFO_ROLEMODELINFOENTRY = _ROLEMODELINFO.nested_types_by_name[ + "RoleModelInfoEntry" +] +_UNLOADREQUEST = 
DESCRIPTOR.message_types_by_name["UnloadRequest"] +_UNLOADRESPONSE = DESCRIPTOR.message_types_by_name["UnloadResponse"] +_UNBINDREQUEST = DESCRIPTOR.message_types_by_name["UnbindRequest"] +_UNBINDRESPONSE = DESCRIPTOR.message_types_by_name["UnbindResponse"] +_QUERYMODELREQUEST = DESCRIPTOR.message_types_by_name["QueryModelRequest"] +_MODELBINDINFO = DESCRIPTOR.message_types_by_name["ModelBindInfo"] +_MODELINFOEX = DESCRIPTOR.message_types_by_name["ModelInfoEx"] +_QUERYMODELRESPONSE = DESCRIPTOR.message_types_by_name["QueryModelResponse"] +_PUBLISHREQUEST = DESCRIPTOR.message_types_by_name["PublishRequest"] +_PUBLISHREQUEST_ROLEENTRY = _PUBLISHREQUEST.nested_types_by_name["RoleEntry"] +_PUBLISHREQUEST_MODELENTRY = _PUBLISHREQUEST.nested_types_by_name["ModelEntry"] +_PUBLISHRESPONSE = DESCRIPTOR.message_types_by_name["PublishResponse"] +Party = _reflection.GeneratedProtocolMessageType( + "Party", + (_message.Message,), + { + "DESCRIPTOR": _PARTY, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.Party) + }, +) _sym_db.RegisterMessage(Party) -LocalInfo = _reflection.GeneratedProtocolMessageType('LocalInfo', (_message.Message,), { - 'DESCRIPTOR' : _LOCALINFO, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.LocalInfo) - }) +LocalInfo = _reflection.GeneratedProtocolMessageType( + "LocalInfo", + (_message.Message,), + { + "DESCRIPTOR": _LOCALINFO, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.LocalInfo) + }, +) _sym_db.RegisterMessage(LocalInfo) -ModelInfo = _reflection.GeneratedProtocolMessageType('ModelInfo', (_message.Message,), { - 'DESCRIPTOR' : _MODELINFO, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelInfo) - }) +ModelInfo = _reflection.GeneratedProtocolMessageType( + "ModelInfo", + (_message.Message,), + { + "DESCRIPTOR": _MODELINFO, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelInfo) + }, +) _sym_db.RegisterMessage(ModelInfo) -RoleModelInfo = _reflection.GeneratedProtocolMessageType('RoleModelInfo', (_message.Message,), { - - 'RoleModelInfoEntry' : _reflection.GeneratedProtocolMessageType('RoleModelInfoEntry', (_message.Message,), { - 'DESCRIPTOR' : _ROLEMODELINFO_ROLEMODELINFOENTRY, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo.RoleModelInfoEntry) - }) - , - 'DESCRIPTOR' : _ROLEMODELINFO, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo) - }) +RoleModelInfo = _reflection.GeneratedProtocolMessageType( + "RoleModelInfo", + (_message.Message,), + { + "RoleModelInfoEntry": _reflection.GeneratedProtocolMessageType( + "RoleModelInfoEntry", + (_message.Message,), + { + "DESCRIPTOR": _ROLEMODELINFO_ROLEMODELINFOENTRY, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo.RoleModelInfoEntry) + }, + ), + "DESCRIPTOR": _ROLEMODELINFO, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.RoleModelInfo) + }, +) _sym_db.RegisterMessage(RoleModelInfo) _sym_db.RegisterMessage(RoleModelInfo.RoleModelInfoEntry) -UnloadRequest = 
_reflection.GeneratedProtocolMessageType('UnloadRequest', (_message.Message,), { - 'DESCRIPTOR' : _UNLOADREQUEST, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnloadRequest) - }) +UnloadRequest = _reflection.GeneratedProtocolMessageType( + "UnloadRequest", + (_message.Message,), + { + "DESCRIPTOR": _UNLOADREQUEST, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnloadRequest) + }, +) _sym_db.RegisterMessage(UnloadRequest) -UnloadResponse = _reflection.GeneratedProtocolMessageType('UnloadResponse', (_message.Message,), { - 'DESCRIPTOR' : _UNLOADRESPONSE, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnloadResponse) - }) +UnloadResponse = _reflection.GeneratedProtocolMessageType( + "UnloadResponse", + (_message.Message,), + { + "DESCRIPTOR": _UNLOADRESPONSE, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnloadResponse) + }, +) _sym_db.RegisterMessage(UnloadResponse) -UnbindRequest = _reflection.GeneratedProtocolMessageType('UnbindRequest', (_message.Message,), { - 'DESCRIPTOR' : _UNBINDREQUEST, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnbindRequest) - }) +UnbindRequest = _reflection.GeneratedProtocolMessageType( + "UnbindRequest", + (_message.Message,), + { + "DESCRIPTOR": _UNBINDREQUEST, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnbindRequest) + }, +) _sym_db.RegisterMessage(UnbindRequest) -UnbindResponse = _reflection.GeneratedProtocolMessageType('UnbindResponse', (_message.Message,), { - 'DESCRIPTOR' : _UNBINDRESPONSE, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnbindResponse) - }) +UnbindResponse = _reflection.GeneratedProtocolMessageType( + "UnbindResponse", + (_message.Message,), + { + "DESCRIPTOR": _UNBINDRESPONSE, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.UnbindResponse) + }, +) _sym_db.RegisterMessage(UnbindResponse) -QueryModelRequest = _reflection.GeneratedProtocolMessageType('QueryModelRequest', (_message.Message,), { - 'DESCRIPTOR' : _QUERYMODELREQUEST, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.QueryModelRequest) - }) +QueryModelRequest = _reflection.GeneratedProtocolMessageType( + "QueryModelRequest", + (_message.Message,), + { + "DESCRIPTOR": _QUERYMODELREQUEST, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.QueryModelRequest) + }, +) _sym_db.RegisterMessage(QueryModelRequest) -ModelBindInfo = _reflection.GeneratedProtocolMessageType('ModelBindInfo', (_message.Message,), { - 'DESCRIPTOR' : _MODELBINDINFO, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelBindInfo) - }) +ModelBindInfo = _reflection.GeneratedProtocolMessageType( + "ModelBindInfo", + (_message.Message,), + { + "DESCRIPTOR": _MODELBINDINFO, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelBindInfo) + }, +) _sym_db.RegisterMessage(ModelBindInfo) -ModelInfoEx = 
_reflection.GeneratedProtocolMessageType('ModelInfoEx', (_message.Message,), { - 'DESCRIPTOR' : _MODELINFOEX, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelInfoEx) - }) +ModelInfoEx = _reflection.GeneratedProtocolMessageType( + "ModelInfoEx", + (_message.Message,), + { + "DESCRIPTOR": _MODELINFOEX, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.ModelInfoEx) + }, +) _sym_db.RegisterMessage(ModelInfoEx) -QueryModelResponse = _reflection.GeneratedProtocolMessageType('QueryModelResponse', (_message.Message,), { - 'DESCRIPTOR' : _QUERYMODELRESPONSE, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.QueryModelResponse) - }) +QueryModelResponse = _reflection.GeneratedProtocolMessageType( + "QueryModelResponse", + (_message.Message,), + { + "DESCRIPTOR": _QUERYMODELRESPONSE, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.QueryModelResponse) + }, +) _sym_db.RegisterMessage(QueryModelResponse) -PublishRequest = _reflection.GeneratedProtocolMessageType('PublishRequest', (_message.Message,), { - - 'RoleEntry' : _reflection.GeneratedProtocolMessageType('RoleEntry', (_message.Message,), { - 'DESCRIPTOR' : _PUBLISHREQUEST_ROLEENTRY, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest.RoleEntry) - }) - , - - 'ModelEntry' : _reflection.GeneratedProtocolMessageType('ModelEntry', (_message.Message,), { - 'DESCRIPTOR' : _PUBLISHREQUEST_MODELENTRY, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest.ModelEntry) - }) - , - 'DESCRIPTOR' : _PUBLISHREQUEST, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest) - }) +PublishRequest = _reflection.GeneratedProtocolMessageType( + "PublishRequest", + (_message.Message,), + { + "RoleEntry": _reflection.GeneratedProtocolMessageType( + "RoleEntry", + (_message.Message,), + { + "DESCRIPTOR": _PUBLISHREQUEST_ROLEENTRY, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest.RoleEntry) + }, + ), + "ModelEntry": _reflection.GeneratedProtocolMessageType( + "ModelEntry", + (_message.Message,), + { + "DESCRIPTOR": _PUBLISHREQUEST_MODELENTRY, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest.ModelEntry) + }, + ), + "DESCRIPTOR": _PUBLISHREQUEST, + "__module__": "model_service_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishRequest) + }, +) _sym_db.RegisterMessage(PublishRequest) _sym_db.RegisterMessage(PublishRequest.RoleEntry) _sym_db.RegisterMessage(PublishRequest.ModelEntry) -PublishResponse = _reflection.GeneratedProtocolMessageType('PublishResponse', (_message.Message,), { - 'DESCRIPTOR' : _PUBLISHRESPONSE, - '__module__' : 'model_service_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishResponse) - }) +PublishResponse = _reflection.GeneratedProtocolMessageType( + "PublishResponse", + (_message.Message,), + { + "DESCRIPTOR": _PUBLISHRESPONSE, + "__module__": "model_service_pb2" + # 
@@protoc_insertion_point(class_scope:com.webank.ai.fate.api.mlmodel.manager.PublishResponse) + }, +) _sym_db.RegisterMessage(PublishResponse) -_MODELSERVICE = DESCRIPTOR.services_by_name['ModelService'] +_MODELSERVICE = DESCRIPTOR.services_by_name["ModelService"] if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'B\021ModelServiceProto' - _ROLEMODELINFO_ROLEMODELINFOENTRY._options = None - _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_options = b'8\001' - _PUBLISHREQUEST_ROLEENTRY._options = None - _PUBLISHREQUEST_ROLEENTRY._serialized_options = b'8\001' - _PUBLISHREQUEST_MODELENTRY._options = None - _PUBLISHREQUEST_MODELENTRY._serialized_options = b'8\001' - _PARTY._serialized_start=63 - _PARTY._serialized_end=87 - _LOCALINFO._serialized_start=89 - _LOCALINFO._serialized_end=131 - _MODELINFO._serialized_start=133 - _MODELINFO._serialized_end=182 - _ROLEMODELINFO._serialized_start=185 - _ROLEMODELINFO._serialized_end=402 - _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_start=299 - _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_end=402 - _UNLOADREQUEST._serialized_start=404 - _UNLOADREQUEST._serialized_end=457 - _UNLOADRESPONSE._serialized_start=459 - _UNLOADRESPONSE._serialized_end=512 - _UNBINDREQUEST._serialized_start=514 - _UNBINDREQUEST._serialized_end=586 - _UNBINDRESPONSE._serialized_start=588 - _UNBINDRESPONSE._serialized_end=641 - _QUERYMODELREQUEST._serialized_start=644 - _QUERYMODELREQUEST._serialized_end=777 - _MODELBINDINFO._serialized_start=779 - _MODELBINDINFO._serialized_end=794 - _MODELINFOEX._serialized_start=796 - _MODELINFOEX._serialized_end=898 - _QUERYMODELRESPONSE._serialized_start=900 - _QUERYMODELRESPONSE._serialized_end=1027 - _PUBLISHREQUEST._serialized_start=1030 - _PUBLISHREQUEST._serialized_end=1560 - _PUBLISHREQUEST_ROLEENTRY._serialized_start=1369 - _PUBLISHREQUEST_ROLEENTRY._serialized_end=1459 - _PUBLISHREQUEST_MODELENTRY._serialized_start=1461 - _PUBLISHREQUEST_MODELENTRY._serialized_end=1560 - _PUBLISHRESPONSE._serialized_start=1562 - _PUBLISHRESPONSE._serialized_end=1645 - _MODELSERVICE._serialized_start=1648 - _MODELSERVICE._serialized_end=2425 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"B\021ModelServiceProto" + _ROLEMODELINFO_ROLEMODELINFOENTRY._options = None + _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_options = b"8\001" + _PUBLISHREQUEST_ROLEENTRY._options = None + _PUBLISHREQUEST_ROLEENTRY._serialized_options = b"8\001" + _PUBLISHREQUEST_MODELENTRY._options = None + _PUBLISHREQUEST_MODELENTRY._serialized_options = b"8\001" + _PARTY._serialized_start = 63 + _PARTY._serialized_end = 87 + _LOCALINFO._serialized_start = 89 + _LOCALINFO._serialized_end = 131 + _MODELINFO._serialized_start = 133 + _MODELINFO._serialized_end = 182 + _ROLEMODELINFO._serialized_start = 185 + _ROLEMODELINFO._serialized_end = 402 + _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_start = 299 + _ROLEMODELINFO_ROLEMODELINFOENTRY._serialized_end = 402 + _UNLOADREQUEST._serialized_start = 404 + _UNLOADREQUEST._serialized_end = 457 + _UNLOADRESPONSE._serialized_start = 459 + _UNLOADRESPONSE._serialized_end = 512 + _UNBINDREQUEST._serialized_start = 514 + _UNBINDREQUEST._serialized_end = 586 + _UNBINDRESPONSE._serialized_start = 588 + _UNBINDRESPONSE._serialized_end = 641 + _QUERYMODELREQUEST._serialized_start = 644 + _QUERYMODELREQUEST._serialized_end = 777 + _MODELBINDINFO._serialized_start = 779 + _MODELBINDINFO._serialized_end = 794 + _MODELINFOEX._serialized_start = 796 + 
_MODELINFOEX._serialized_end = 898 + _QUERYMODELRESPONSE._serialized_start = 900 + _QUERYMODELRESPONSE._serialized_end = 1027 + _PUBLISHREQUEST._serialized_start = 1030 + _PUBLISHREQUEST._serialized_end = 1560 + _PUBLISHREQUEST_ROLEENTRY._serialized_start = 1369 + _PUBLISHREQUEST_ROLEENTRY._serialized_end = 1459 + _PUBLISHREQUEST_MODELENTRY._serialized_start = 1461 + _PUBLISHREQUEST_MODELENTRY._serialized_end = 1560 + _PUBLISHRESPONSE._serialized_start = 1562 + _PUBLISHRESPONSE._serialized_end = 1645 + _MODELSERVICE._serialized_start = 1648 + _MODELSERVICE._serialized_end = 2425 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/model_service_pb2_grpc.py b/python/fate/arch/protobuf/python/model_service_pb2_grpc.py index 4e081ad62a..f9d5d247c2 100644 --- a/python/fate/arch/protobuf/python/model_service_pb2_grpc.py +++ b/python/fate/arch/protobuf/python/model_service_pb2_grpc.py @@ -1,7 +1,6 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc - import model_service_pb2 as model__service__pb2 @@ -15,35 +14,35 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.publishLoad = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishLoad', - request_serializer=model__service__pb2.PublishRequest.SerializeToString, - response_deserializer=model__service__pb2.PublishResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishLoad", + request_serializer=model__service__pb2.PublishRequest.SerializeToString, + response_deserializer=model__service__pb2.PublishResponse.FromString, + ) self.publishBind = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishBind', - request_serializer=model__service__pb2.PublishRequest.SerializeToString, - response_deserializer=model__service__pb2.PublishResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishBind", + request_serializer=model__service__pb2.PublishRequest.SerializeToString, + response_deserializer=model__service__pb2.PublishResponse.FromString, + ) self.publishOnline = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishOnline', - request_serializer=model__service__pb2.PublishRequest.SerializeToString, - response_deserializer=model__service__pb2.PublishResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishOnline", + request_serializer=model__service__pb2.PublishRequest.SerializeToString, + response_deserializer=model__service__pb2.PublishResponse.FromString, + ) self.queryModel = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/queryModel', - request_serializer=model__service__pb2.QueryModelRequest.SerializeToString, - response_deserializer=model__service__pb2.QueryModelResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/queryModel", + request_serializer=model__service__pb2.QueryModelRequest.SerializeToString, + response_deserializer=model__service__pb2.QueryModelResponse.FromString, + ) self.unload = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/unload', - request_serializer=model__service__pb2.UnloadRequest.SerializeToString, - response_deserializer=model__service__pb2.UnloadResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/unload", + 
request_serializer=model__service__pb2.UnloadRequest.SerializeToString, + response_deserializer=model__service__pb2.UnloadResponse.FromString, + ) self.unbind = channel.unary_unary( - '/com.webank.ai.fate.api.mlmodel.manager.ModelService/unbind', - request_serializer=model__service__pb2.UnbindRequest.SerializeToString, - response_deserializer=model__service__pb2.UnbindResponse.FromString, - ) + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/unbind", + request_serializer=model__service__pb2.UnbindRequest.SerializeToString, + response_deserializer=model__service__pb2.UnbindResponse.FromString, + ) class ModelServiceServicer(object): @@ -52,180 +51,253 @@ class ModelServiceServicer(object): def publishLoad(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def publishBind(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def publishOnline(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def queryModel(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def unload(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def unbind(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_ModelServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'publishLoad': grpc.unary_unary_rpc_method_handler( - servicer.publishLoad, - request_deserializer=model__service__pb2.PublishRequest.FromString, - response_serializer=model__service__pb2.PublishResponse.SerializeToString, - ), - 'publishBind': grpc.unary_unary_rpc_method_handler( - servicer.publishBind, - request_deserializer=model__service__pb2.PublishRequest.FromString, - response_serializer=model__service__pb2.PublishResponse.SerializeToString, - ), - 'publishOnline': grpc.unary_unary_rpc_method_handler( - servicer.publishOnline, - request_deserializer=model__service__pb2.PublishRequest.FromString, - 
response_serializer=model__service__pb2.PublishResponse.SerializeToString, - ), - 'queryModel': grpc.unary_unary_rpc_method_handler( - servicer.queryModel, - request_deserializer=model__service__pb2.QueryModelRequest.FromString, - response_serializer=model__service__pb2.QueryModelResponse.SerializeToString, - ), - 'unload': grpc.unary_unary_rpc_method_handler( - servicer.unload, - request_deserializer=model__service__pb2.UnloadRequest.FromString, - response_serializer=model__service__pb2.UnloadResponse.SerializeToString, - ), - 'unbind': grpc.unary_unary_rpc_method_handler( - servicer.unbind, - request_deserializer=model__service__pb2.UnbindRequest.FromString, - response_serializer=model__service__pb2.UnbindResponse.SerializeToString, - ), + "publishLoad": grpc.unary_unary_rpc_method_handler( + servicer.publishLoad, + request_deserializer=model__service__pb2.PublishRequest.FromString, + response_serializer=model__service__pb2.PublishResponse.SerializeToString, + ), + "publishBind": grpc.unary_unary_rpc_method_handler( + servicer.publishBind, + request_deserializer=model__service__pb2.PublishRequest.FromString, + response_serializer=model__service__pb2.PublishResponse.SerializeToString, + ), + "publishOnline": grpc.unary_unary_rpc_method_handler( + servicer.publishOnline, + request_deserializer=model__service__pb2.PublishRequest.FromString, + response_serializer=model__service__pb2.PublishResponse.SerializeToString, + ), + "queryModel": grpc.unary_unary_rpc_method_handler( + servicer.queryModel, + request_deserializer=model__service__pb2.QueryModelRequest.FromString, + response_serializer=model__service__pb2.QueryModelResponse.SerializeToString, + ), + "unload": grpc.unary_unary_rpc_method_handler( + servicer.unload, + request_deserializer=model__service__pb2.UnloadRequest.FromString, + response_serializer=model__service__pb2.UnloadResponse.SerializeToString, + ), + "unbind": grpc.unary_unary_rpc_method_handler( + servicer.unbind, + request_deserializer=model__service__pb2.UnbindRequest.FromString, + response_serializer=model__service__pb2.UnbindResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'com.webank.ai.fate.api.mlmodel.manager.ModelService', rpc_method_handlers) + "com.webank.ai.fate.api.mlmodel.manager.ModelService", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. 
class ModelService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def publishLoad(request, + def publishLoad( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishLoad', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishLoad", model__service__pb2.PublishRequest.SerializeToString, model__service__pb2.PublishResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def publishBind(request, + def publishBind( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishBind', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishBind", model__service__pb2.PublishRequest.SerializeToString, model__service__pb2.PublishResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def publishOnline(request, + def publishOnline( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishOnline', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/publishOnline", model__service__pb2.PublishRequest.SerializeToString, model__service__pb2.PublishResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def queryModel(request, + def queryModel( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - 
metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/queryModel', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/queryModel", model__service__pb2.QueryModelRequest.SerializeToString, model__service__pb2.QueryModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def unload(request, + def unload( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/unload', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/unload", model__service__pb2.UnloadRequest.SerializeToString, model__service__pb2.UnloadResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def unbind(request, + def unbind( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.fate.api.mlmodel.manager.ModelService/unbind', + "/com.webank.ai.fate.api.mlmodel.manager.ModelService/unbind", model__service__pb2.UnbindRequest.SerializeToString, model__service__pb2.UnbindResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/fate/arch/protobuf/python/proxy_pb2.py b/python/fate/arch/protobuf/python/proxy_pb2.py index 4c8b785d7f..1fd36d1078 100644 --- a/python/fate/arch/protobuf/python/proxy_pb2.py +++ b/python/fate/arch/protobuf/python/proxy_pb2.py @@ -2,12 +2,13 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: proxy.proto """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import enum_type_wrapper + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -15,10 +16,11 @@ import basic_meta_pb2 as basic__meta__pb2 +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0bproxy.proto\x12*com.webank.ai.eggroll.api.networking.proxy\x1a\x10\x62\x61sic-meta.proto"&\n\x05Model\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07\x64\x61taKey\x18\x02 \x01(\t"X\n\x04Task\x12\x0e\n\x06taskId\x18\x01 \x01(\t\x12@\n\x05model\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Model"p\n\x05Topic\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07partyId\x18\x02 \x01(\t\x12\x0c\n\x04role\x18\x03 \x01(\t\x12:\n\x08\x63\x61llback\x18\x04 \x01(\x0b\x32(.com.webank.ai.eggroll.api.core.Endpoint"\x17\n\x07\x43ommand\x12\x0c\n\x04name\x18\x01 \x01(\t"p\n\x04\x43onf\x12\x16\n\x0eoverallTimeout\x18\x01 \x01(\x03\x12\x1d\n\x15\x63ompletionWaitTimeout\x18\x02 \x01(\x03\x12\x1d\n\x15packetIntervalTimeout\x18\x03 \x01(\x03\x12\x12\n\nmaxRetries\x18\x04 \x01(\x05"\x9a\x03\n\x08Metadata\x12>\n\x04task\x18\x01 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Task\x12>\n\x03src\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x12>\n\x03\x64st\x18\x03 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x12\x44\n\x07\x63ommand\x18\x04 \x01(\x0b\x32\x33.com.webank.ai.eggroll.api.networking.proxy.Command\x12\x10\n\x08operator\x18\x05 \x01(\t\x12\x0b\n\x03seq\x18\x06 \x01(\x03\x12\x0b\n\x03\x61\x63k\x18\x07 \x01(\x03\x12>\n\x04\x63onf\x18\x08 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Conf\x12\x0b\n\x03\x65xt\x18\t \x01(\x0c\x12\x0f\n\x07version\x18\x64 \x01(\t""\n\x04\x44\x61ta\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c"\x8e\x01\n\x06Packet\x12\x44\n\x06header\x18\x01 \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12>\n\x04\x62ody\x18\x02 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Data"\xa3\x01\n\x11HeartbeatResponse\x12\x44\n\x06header\x18\x01 \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12H\n\toperation\x18\x02 \x01(\x0e\x32\x35.com.webank.ai.eggroll.api.networking.proxy.Operation"\xc5\x01\n\x0cPollingFrame\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x0b\n\x03seq\x18\x02 \x01(\x03\x12\x46\n\x08metadata\x18\n \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12\x42\n\x06packet\x18\x14 \x01(\x0b\x32\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x12\x0c\n\x04\x64\x65sc\x18\x1e 
\x01(\t*O\n\tOperation\x12\t\n\x05START\x10\x00\x12\x07\n\x03RUN\x10\x01\x12\x08\n\x04STOP\x10\x02\x12\x08\n\x04KILL\x10\x03\x12\x0c\n\x08GET_DATA\x10\x04\x12\x0c\n\x08PUT_DATA\x10\x05\x32\xf6\x03\n\x13\x44\x61taTransferService\x12r\n\x04push\x12\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x1a\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata(\x01\x12r\n\x04pull\x12\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x1a\x32.com.webank.ai.eggroll.api.networking.proxy.Packet0\x01\x12s\n\tunaryCall\x12\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x1a\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x12\x81\x01\n\x07polling\x12\x38.com.webank.ai.eggroll.api.networking.proxy.PollingFrame\x1a\x38.com.webank.ai.eggroll.api.networking.proxy.PollingFrame(\x01\x30\x01\x32t\n\x0cRouteService\x12\x64\n\x05query\x12\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x1a(.com.webank.ai.eggroll.api.core.Endpointb\x06proto3' +) -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bproxy.proto\x12*com.webank.ai.eggroll.api.networking.proxy\x1a\x10\x62\x61sic-meta.proto\"&\n\x05Model\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07\x64\x61taKey\x18\x02 \x01(\t\"X\n\x04Task\x12\x0e\n\x06taskId\x18\x01 \x01(\t\x12@\n\x05model\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Model\"p\n\x05Topic\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07partyId\x18\x02 \x01(\t\x12\x0c\n\x04role\x18\x03 \x01(\t\x12:\n\x08\x63\x61llback\x18\x04 \x01(\x0b\x32(.com.webank.ai.eggroll.api.core.Endpoint\"\x17\n\x07\x43ommand\x12\x0c\n\x04name\x18\x01 \x01(\t\"p\n\x04\x43onf\x12\x16\n\x0eoverallTimeout\x18\x01 \x01(\x03\x12\x1d\n\x15\x63ompletionWaitTimeout\x18\x02 \x01(\x03\x12\x1d\n\x15packetIntervalTimeout\x18\x03 \x01(\x03\x12\x12\n\nmaxRetries\x18\x04 \x01(\x05\"\x9a\x03\n\x08Metadata\x12>\n\x04task\x18\x01 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Task\x12>\n\x03src\x18\x02 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x12>\n\x03\x64st\x18\x03 \x01(\x0b\x32\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x12\x44\n\x07\x63ommand\x18\x04 \x01(\x0b\x32\x33.com.webank.ai.eggroll.api.networking.proxy.Command\x12\x10\n\x08operator\x18\x05 \x01(\t\x12\x0b\n\x03seq\x18\x06 \x01(\x03\x12\x0b\n\x03\x61\x63k\x18\x07 \x01(\x03\x12>\n\x04\x63onf\x18\x08 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Conf\x12\x0b\n\x03\x65xt\x18\t \x01(\x0c\x12\x0f\n\x07version\x18\x64 \x01(\t\"\"\n\x04\x44\x61ta\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\"\x8e\x01\n\x06Packet\x12\x44\n\x06header\x18\x01 \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12>\n\x04\x62ody\x18\x02 \x01(\x0b\x32\x30.com.webank.ai.eggroll.api.networking.proxy.Data\"\xa3\x01\n\x11HeartbeatResponse\x12\x44\n\x06header\x18\x01 \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12H\n\toperation\x18\x02 \x01(\x0e\x32\x35.com.webank.ai.eggroll.api.networking.proxy.Operation\"\xc5\x01\n\x0cPollingFrame\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x0b\n\x03seq\x18\x02 \x01(\x03\x12\x46\n\x08metadata\x18\n \x01(\x0b\x32\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x12\x42\n\x06packet\x18\x14 \x01(\x0b\x32\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x12\x0c\n\x04\x64\x65sc\x18\x1e 
\x01(\t*O\n\tOperation\x12\t\n\x05START\x10\x00\x12\x07\n\x03RUN\x10\x01\x12\x08\n\x04STOP\x10\x02\x12\x08\n\x04KILL\x10\x03\x12\x0c\n\x08GET_DATA\x10\x04\x12\x0c\n\x08PUT_DATA\x10\x05\x32\xf6\x03\n\x13\x44\x61taTransferService\x12r\n\x04push\x12\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x1a\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata(\x01\x12r\n\x04pull\x12\x34.com.webank.ai.eggroll.api.networking.proxy.Metadata\x1a\x32.com.webank.ai.eggroll.api.networking.proxy.Packet0\x01\x12s\n\tunaryCall\x12\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x1a\x32.com.webank.ai.eggroll.api.networking.proxy.Packet\x12\x81\x01\n\x07polling\x12\x38.com.webank.ai.eggroll.api.networking.proxy.PollingFrame\x1a\x38.com.webank.ai.eggroll.api.networking.proxy.PollingFrame(\x01\x30\x01\x32t\n\x0cRouteService\x12\x64\n\x05query\x12\x31.com.webank.ai.eggroll.api.networking.proxy.Topic\x1a(.com.webank.ai.eggroll.api.core.Endpointb\x06proto3') - -_OPERATION = DESCRIPTOR.enum_types_by_name['Operation'] +_OPERATION = DESCRIPTOR.enum_types_by_name["Operation"] Operation = enum_type_wrapper.EnumTypeWrapper(_OPERATION) START = 0 RUN = 1 @@ -28,115 +30,155 @@ PUT_DATA = 5 -_MODEL = DESCRIPTOR.message_types_by_name['Model'] -_TASK = DESCRIPTOR.message_types_by_name['Task'] -_TOPIC = DESCRIPTOR.message_types_by_name['Topic'] -_COMMAND = DESCRIPTOR.message_types_by_name['Command'] -_CONF = DESCRIPTOR.message_types_by_name['Conf'] -_METADATA = DESCRIPTOR.message_types_by_name['Metadata'] -_DATA = DESCRIPTOR.message_types_by_name['Data'] -_PACKET = DESCRIPTOR.message_types_by_name['Packet'] -_HEARTBEATRESPONSE = DESCRIPTOR.message_types_by_name['HeartbeatResponse'] -_POLLINGFRAME = DESCRIPTOR.message_types_by_name['PollingFrame'] -Model = _reflection.GeneratedProtocolMessageType('Model', (_message.Message,), { - 'DESCRIPTOR' : _MODEL, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Model) - }) +_MODEL = DESCRIPTOR.message_types_by_name["Model"] +_TASK = DESCRIPTOR.message_types_by_name["Task"] +_TOPIC = DESCRIPTOR.message_types_by_name["Topic"] +_COMMAND = DESCRIPTOR.message_types_by_name["Command"] +_CONF = DESCRIPTOR.message_types_by_name["Conf"] +_METADATA = DESCRIPTOR.message_types_by_name["Metadata"] +_DATA = DESCRIPTOR.message_types_by_name["Data"] +_PACKET = DESCRIPTOR.message_types_by_name["Packet"] +_HEARTBEATRESPONSE = DESCRIPTOR.message_types_by_name["HeartbeatResponse"] +_POLLINGFRAME = DESCRIPTOR.message_types_by_name["PollingFrame"] +Model = _reflection.GeneratedProtocolMessageType( + "Model", + (_message.Message,), + { + "DESCRIPTOR": _MODEL, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Model) + }, +) _sym_db.RegisterMessage(Model) -Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { - 'DESCRIPTOR' : _TASK, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Task) - }) +Task = _reflection.GeneratedProtocolMessageType( + "Task", + (_message.Message,), + { + "DESCRIPTOR": _TASK, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Task) + }, +) _sym_db.RegisterMessage(Task) -Topic = _reflection.GeneratedProtocolMessageType('Topic', (_message.Message,), { - 'DESCRIPTOR' : _TOPIC, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Topic) - }) +Topic = 
_reflection.GeneratedProtocolMessageType( + "Topic", + (_message.Message,), + { + "DESCRIPTOR": _TOPIC, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Topic) + }, +) _sym_db.RegisterMessage(Topic) -Command = _reflection.GeneratedProtocolMessageType('Command', (_message.Message,), { - 'DESCRIPTOR' : _COMMAND, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Command) - }) +Command = _reflection.GeneratedProtocolMessageType( + "Command", + (_message.Message,), + { + "DESCRIPTOR": _COMMAND, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Command) + }, +) _sym_db.RegisterMessage(Command) -Conf = _reflection.GeneratedProtocolMessageType('Conf', (_message.Message,), { - 'DESCRIPTOR' : _CONF, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Conf) - }) +Conf = _reflection.GeneratedProtocolMessageType( + "Conf", + (_message.Message,), + { + "DESCRIPTOR": _CONF, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Conf) + }, +) _sym_db.RegisterMessage(Conf) -Metadata = _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), { - 'DESCRIPTOR' : _METADATA, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Metadata) - }) +Metadata = _reflection.GeneratedProtocolMessageType( + "Metadata", + (_message.Message,), + { + "DESCRIPTOR": _METADATA, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Metadata) + }, +) _sym_db.RegisterMessage(Metadata) -Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), { - 'DESCRIPTOR' : _DATA, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Data) - }) +Data = _reflection.GeneratedProtocolMessageType( + "Data", + (_message.Message,), + { + "DESCRIPTOR": _DATA, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Data) + }, +) _sym_db.RegisterMessage(Data) -Packet = _reflection.GeneratedProtocolMessageType('Packet', (_message.Message,), { - 'DESCRIPTOR' : _PACKET, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Packet) - }) +Packet = _reflection.GeneratedProtocolMessageType( + "Packet", + (_message.Message,), + { + "DESCRIPTOR": _PACKET, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.Packet) + }, +) _sym_db.RegisterMessage(Packet) -HeartbeatResponse = _reflection.GeneratedProtocolMessageType('HeartbeatResponse', (_message.Message,), { - 'DESCRIPTOR' : _HEARTBEATRESPONSE, - '__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.HeartbeatResponse) - }) +HeartbeatResponse = _reflection.GeneratedProtocolMessageType( + "HeartbeatResponse", + (_message.Message,), + { + "DESCRIPTOR": _HEARTBEATRESPONSE, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.HeartbeatResponse) + }, +) _sym_db.RegisterMessage(HeartbeatResponse) -PollingFrame = _reflection.GeneratedProtocolMessageType('PollingFrame', (_message.Message,), { - 'DESCRIPTOR' : _POLLINGFRAME, - 
'__module__' : 'proxy_pb2' - # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.PollingFrame) - }) +PollingFrame = _reflection.GeneratedProtocolMessageType( + "PollingFrame", + (_message.Message,), + { + "DESCRIPTOR": _POLLINGFRAME, + "__module__": "proxy_pb2" + # @@protoc_insertion_point(class_scope:com.webank.ai.eggroll.api.networking.proxy.PollingFrame) + }, +) _sym_db.RegisterMessage(PollingFrame) -_DATATRANSFERSERVICE = DESCRIPTOR.services_by_name['DataTransferService'] -_ROUTESERVICE = DESCRIPTOR.services_by_name['RouteService'] +_DATATRANSFERSERVICE = DESCRIPTOR.services_by_name["DataTransferService"] +_ROUTESERVICE = DESCRIPTOR.services_by_name["RouteService"] if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _OPERATION._serialized_start=1420 - _OPERATION._serialized_end=1499 - _MODEL._serialized_start=77 - _MODEL._serialized_end=115 - _TASK._serialized_start=117 - _TASK._serialized_end=205 - _TOPIC._serialized_start=207 - _TOPIC._serialized_end=319 - _COMMAND._serialized_start=321 - _COMMAND._serialized_end=344 - _CONF._serialized_start=346 - _CONF._serialized_end=458 - _METADATA._serialized_start=461 - _METADATA._serialized_end=871 - _DATA._serialized_start=873 - _DATA._serialized_end=907 - _PACKET._serialized_start=910 - _PACKET._serialized_end=1052 - _HEARTBEATRESPONSE._serialized_start=1055 - _HEARTBEATRESPONSE._serialized_end=1218 - _POLLINGFRAME._serialized_start=1221 - _POLLINGFRAME._serialized_end=1418 - _DATATRANSFERSERVICE._serialized_start=1502 - _DATATRANSFERSERVICE._serialized_end=2004 - _ROUTESERVICE._serialized_start=2006 - _ROUTESERVICE._serialized_end=2122 + DESCRIPTOR._options = None + _OPERATION._serialized_start = 1420 + _OPERATION._serialized_end = 1499 + _MODEL._serialized_start = 77 + _MODEL._serialized_end = 115 + _TASK._serialized_start = 117 + _TASK._serialized_end = 205 + _TOPIC._serialized_start = 207 + _TOPIC._serialized_end = 319 + _COMMAND._serialized_start = 321 + _COMMAND._serialized_end = 344 + _CONF._serialized_start = 346 + _CONF._serialized_end = 458 + _METADATA._serialized_start = 461 + _METADATA._serialized_end = 871 + _DATA._serialized_start = 873 + _DATA._serialized_end = 907 + _PACKET._serialized_start = 910 + _PACKET._serialized_end = 1052 + _HEARTBEATRESPONSE._serialized_start = 1055 + _HEARTBEATRESPONSE._serialized_end = 1218 + _POLLINGFRAME._serialized_start = 1221 + _POLLINGFRAME._serialized_end = 1418 + _DATATRANSFERSERVICE._serialized_start = 1502 + _DATATRANSFERSERVICE._serialized_end = 2004 + _ROUTESERVICE._serialized_start = 2006 + _ROUTESERVICE._serialized_end = 2122 # @@protoc_insertion_point(module_scope) diff --git a/python/fate/arch/protobuf/python/proxy_pb2_grpc.py b/python/fate/arch/protobuf/python/proxy_pb2_grpc.py index d9f99065c2..1d58ededce 100644 --- a/python/fate/arch/protobuf/python/proxy_pb2_grpc.py +++ b/python/fate/arch/protobuf/python/proxy_pb2_grpc.py @@ -1,14 +1,12 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" -import grpc - import basic_meta_pb2 as basic__meta__pb2 +import grpc import proxy_pb2 as proxy__pb2 class DataTransferServiceStub(object): - """data transfer service - """ + """data transfer service""" def __init__(self, channel): """Constructor. @@ -17,156 +15,204 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.push = channel.stream_unary( - '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/push', - request_serializer=proxy__pb2.Packet.SerializeToString, - response_deserializer=proxy__pb2.Metadata.FromString, - ) + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/push", + request_serializer=proxy__pb2.Packet.SerializeToString, + response_deserializer=proxy__pb2.Metadata.FromString, + ) self.pull = channel.unary_stream( - '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/pull', - request_serializer=proxy__pb2.Metadata.SerializeToString, - response_deserializer=proxy__pb2.Packet.FromString, - ) + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/pull", + request_serializer=proxy__pb2.Metadata.SerializeToString, + response_deserializer=proxy__pb2.Packet.FromString, + ) self.unaryCall = channel.unary_unary( - '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/unaryCall', - request_serializer=proxy__pb2.Packet.SerializeToString, - response_deserializer=proxy__pb2.Packet.FromString, - ) + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/unaryCall", + request_serializer=proxy__pb2.Packet.SerializeToString, + response_deserializer=proxy__pb2.Packet.FromString, + ) self.polling = channel.stream_stream( - '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/polling', - request_serializer=proxy__pb2.PollingFrame.SerializeToString, - response_deserializer=proxy__pb2.PollingFrame.FromString, - ) + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/polling", + request_serializer=proxy__pb2.PollingFrame.SerializeToString, + response_deserializer=proxy__pb2.PollingFrame.FromString, + ) class DataTransferServiceServicer(object): - """data transfer service - """ + """data transfer service""" def push(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def pull(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def unaryCall(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def polling(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_DataTransferServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'push': grpc.stream_unary_rpc_method_handler( - servicer.push, - request_deserializer=proxy__pb2.Packet.FromString, - response_serializer=proxy__pb2.Metadata.SerializeToString, - ), - 'pull': 
grpc.unary_stream_rpc_method_handler( - servicer.pull, - request_deserializer=proxy__pb2.Metadata.FromString, - response_serializer=proxy__pb2.Packet.SerializeToString, - ), - 'unaryCall': grpc.unary_unary_rpc_method_handler( - servicer.unaryCall, - request_deserializer=proxy__pb2.Packet.FromString, - response_serializer=proxy__pb2.Packet.SerializeToString, - ), - 'polling': grpc.stream_stream_rpc_method_handler( - servicer.polling, - request_deserializer=proxy__pb2.PollingFrame.FromString, - response_serializer=proxy__pb2.PollingFrame.SerializeToString, - ), + "push": grpc.stream_unary_rpc_method_handler( + servicer.push, + request_deserializer=proxy__pb2.Packet.FromString, + response_serializer=proxy__pb2.Metadata.SerializeToString, + ), + "pull": grpc.unary_stream_rpc_method_handler( + servicer.pull, + request_deserializer=proxy__pb2.Metadata.FromString, + response_serializer=proxy__pb2.Packet.SerializeToString, + ), + "unaryCall": grpc.unary_unary_rpc_method_handler( + servicer.unaryCall, + request_deserializer=proxy__pb2.Packet.FromString, + response_serializer=proxy__pb2.Packet.SerializeToString, + ), + "polling": grpc.stream_stream_rpc_method_handler( + servicer.polling, + request_deserializer=proxy__pb2.PollingFrame.FromString, + response_serializer=proxy__pb2.PollingFrame.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'com.webank.ai.eggroll.api.networking.proxy.DataTransferService', rpc_method_handlers) + "com.webank.ai.eggroll.api.networking.proxy.DataTransferService", + rpc_method_handlers, + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class DataTransferService(object): - """data transfer service - """ + """data transfer service""" @staticmethod - def push(request_iterator, + def push( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_unary( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_unary(request_iterator, target, '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/push', + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/push", proxy__pb2.Packet.SerializeToString, proxy__pb2.Metadata.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def pull(request, + def pull( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/pull', + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/pull", proxy__pb2.Metadata.SerializeToString, proxy__pb2.Packet.FromString, - options, 
channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def unaryCall(request, + def unaryCall( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/unaryCall', + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/unaryCall", proxy__pb2.Packet.SerializeToString, proxy__pb2.Packet.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) @staticmethod - def polling(request_iterator, + def polling( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream(request_iterator, target, '/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/polling', + "/com.webank.ai.eggroll.api.networking.proxy.DataTransferService/polling", proxy__pb2.PollingFrame.SerializeToString, proxy__pb2.PollingFrame.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) class RouteServiceStub(object): @@ -179,10 +225,10 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.query = channel.unary_unary( - '/com.webank.ai.eggroll.api.networking.proxy.RouteService/query', - request_serializer=proxy__pb2.Topic.SerializeToString, - response_deserializer=basic__meta__pb2.Endpoint.FromString, - ) + "/com.webank.ai.eggroll.api.networking.proxy.RouteService/query", + request_serializer=proxy__pb2.Topic.SerializeToString, + response_deserializer=basic__meta__pb2.Endpoint.FromString, + ) class RouteServiceServicer(object): @@ -191,40 +237,53 @@ class RouteServiceServicer(object): def query(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_RouteServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'query': grpc.unary_unary_rpc_method_handler( - servicer.query, - request_deserializer=proxy__pb2.Topic.FromString, - response_serializer=basic__meta__pb2.Endpoint.SerializeToString, - ), + "query": grpc.unary_unary_rpc_method_handler( + servicer.query, + request_deserializer=proxy__pb2.Topic.FromString, + response_serializer=basic__meta__pb2.Endpoint.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'com.webank.ai.eggroll.api.networking.proxy.RouteService', rpc_method_handlers) + "com.webank.ai.eggroll.api.networking.proxy.RouteService", rpc_method_handlers + ) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class RouteService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def query(request, + def query( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/com.webank.ai.eggroll.api.networking.proxy.RouteService/query', + "/com.webank.ai.eggroll.api.networking.proxy.RouteService/query", proxy__pb2.Topic.SerializeToString, basic__meta__pb2.Endpoint.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/fate/arch/relation_ship.py b/python/fate/arch/relation_ship.py index d53c3c8d07..33b4d50d8c 100644 --- a/python/fate/arch/relation_ship.py +++ b/python/fate/arch/relation_ship.py @@ -13,13 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from fate_arch.computing import ComputingEngine -from fate_arch.federation import FederationEngine -from fate_arch.storage import StorageEngine -from fate_arch.common.address import StandaloneAddress, EggRollAddress, HDFSAddress, \ - MysqlAddress, \ - PathAddress, LocalFSAddress, HiveAddress, LinkisHiveAddress, ApiAddress -from fate_arch.common import EngineType +from .common import EngineType +from .common.address import ( + ApiAddress, + EggRollAddress, + HDFSAddress, + HiveAddress, + LinkisHiveAddress, + LocalFSAddress, + MysqlAddress, + PathAddress, + StandaloneAddress, +) +from .computing import ComputingEngine +from .federation import FederationEngine +from .storage import StorageEngine class Relationship(object): @@ -27,43 +35,55 @@ class Relationship(object): ComputingEngine.STANDALONE: { EngineType.STORAGE: { "default": StorageEngine.STANDALONE, - "support": [StorageEngine.STANDALONE] + "support": [StorageEngine.STANDALONE], }, EngineType.FEDERATION: { "default": FederationEngine.STANDALONE, - "support": [FederationEngine.STANDALONE, FederationEngine.RABBITMQ, FederationEngine.PULSAR] + "support": [ + FederationEngine.STANDALONE, + FederationEngine.RABBITMQ, + FederationEngine.PULSAR, + ], }, }, ComputingEngine.EGGROLL: { EngineType.STORAGE: { "default": StorageEngine.EGGROLL, - "support": [StorageEngine.EGGROLL] + "support": [StorageEngine.EGGROLL], }, EngineType.FEDERATION: { "default": FederationEngine.EGGROLL, - "support": [FederationEngine.EGGROLL, FederationEngine.RABBITMQ, FederationEngine.PULSAR] + "support": [ + FederationEngine.EGGROLL, + FederationEngine.RABBITMQ, + FederationEngine.PULSAR, + ], }, }, ComputingEngine.SPARK: { EngineType.STORAGE: { "default": StorageEngine.HDFS, - "support": [StorageEngine.HDFS, StorageEngine.HIVE, StorageEngine.LOCALFS] + "support": [ + StorageEngine.HDFS, + StorageEngine.HIVE, + StorageEngine.LOCALFS, + ], }, EngineType.FEDERATION: { "default": FederationEngine.RABBITMQ, - "support": [FederationEngine.PULSAR, FederationEngine.RABBITMQ] + "support": [FederationEngine.PULSAR, FederationEngine.RABBITMQ], }, }, ComputingEngine.LINKIS_SPARK: { EngineType.STORAGE: { "default": StorageEngine.LINKIS_HIVE, - "support": [StorageEngine.LINKIS_HIVE] + "support": [StorageEngine.LINKIS_HIVE], }, EngineType.FEDERATION: { "default": FederationEngine.RABBITMQ, - "support": [FederationEngine.PULSAR, FederationEngine.RABBITMQ] + "support": [FederationEngine.PULSAR, FederationEngine.RABBITMQ], }, - } + }, } EngineToAddress = { @@ -75,14 +95,14 @@ class Relationship(object): StorageEngine.LINKIS_HIVE: LinkisHiveAddress, StorageEngine.LOCALFS: LocalFSAddress, StorageEngine.PATH: PathAddress, - StorageEngine.API: ApiAddress + StorageEngine.API: ApiAddress, } EngineConfMap = { "fate_on_standalone": { EngineType.COMPUTING: [(ComputingEngine.STANDALONE, "standalone")], EngineType.STORAGE: [(StorageEngine.STANDALONE, "standalone")], - EngineType.FEDERATION: [(FederationEngine.STANDALONE, "standalone")] + EngineType.FEDERATION: [(FederationEngine.STANDALONE, "standalone")], }, "fate_on_eggroll": { EngineType.COMPUTING: [(ComputingEngine.EGGROLL, "clustermanager")], @@ -90,9 +110,19 @@ class Relationship(object): EngineType.FEDERATION: [(FederationEngine.EGGROLL, "rollsite")], }, "fate_on_spark": { - EngineType.COMPUTING: [(ComputingEngine.SPARK, "spark"), (ComputingEngine.LINKIS_SPARK, "linkis_spark")], - EngineType.STORAGE: [(StorageEngine.HDFS, "hdfs"), (StorageEngine.HIVE, "hive"), - (StorageEngine.LINKIS_HIVE, "linkis_hive"), (StorageEngine.LOCALFS, 
"localfs")], - EngineType.FEDERATION: [(FederationEngine.RABBITMQ, "rabbitmq"), (FederationEngine.PULSAR, "pulsar")] + EngineType.COMPUTING: [ + (ComputingEngine.SPARK, "spark"), + (ComputingEngine.LINKIS_SPARK, "linkis_spark"), + ], + EngineType.STORAGE: [ + (StorageEngine.HDFS, "hdfs"), + (StorageEngine.HIVE, "hive"), + (StorageEngine.LINKIS_HIVE, "linkis_hive"), + (StorageEngine.LOCALFS, "localfs"), + ], + EngineType.FEDERATION: [ + (FederationEngine.RABBITMQ, "rabbitmq"), + (FederationEngine.PULSAR, "pulsar"), + ], }, } diff --git a/python/fate/arch/session/__init__.py b/python/fate/arch/session/__init__.py index ad3752e87a..21ca7d2dcf 100644 --- a/python/fate/arch/session/__init__.py +++ b/python/fate/arch/session/__init__.py @@ -15,16 +15,23 @@ # -from fate_arch.computing import is_table -from fate_arch.common._parties import PartiesInfo, Role -from fate_arch.session._session import Session, computing_session, get_session, get_parties, get_computing_session +from ..common._parties import PartiesInfo, Role +from ..computing import is_table +from ._session import ( + Session, + computing_session, + get_computing_session, + get_parties, + get_session, +) __all__ = [ - 'is_table', - 'Session', - 'PartiesInfo', - 'computing_session', - 'get_session', - 'get_parties', - 'get_computing_session', - 'Role'] + "is_table", + "Session", + "PartiesInfo", + "computing_session", + "get_session", + "get_parties", + "get_computing_session", + "Role", +] diff --git a/python/fate/arch/session/_session.py b/python/fate/arch/session/_session.py index e31020a6aa..458ef38f5a 100644 --- a/python/fate/arch/session/_session.py +++ b/python/fate/arch/session/_session.py @@ -19,15 +19,20 @@ import peewee -from fate_arch.abc import CSessionABC, FederationABC, CTableABC, StorageSessionABC, StorageTableABC, StorageTableMetaABC -from fate_arch.common import engine_utils, EngineType, Party -from fate_arch.common import log, base_utils -from fate_arch.common import remote_status -from fate_arch.common._parties import PartiesInfo -from fate_arch.computing import ComputingEngine -from fate_arch.federation import FederationEngine -from fate_arch.metastore.db_models import DB, SessionRecord, init_database_tables -from fate_arch.storage import StorageEngine, StorageSessionBase +from ..abc import ( + CSessionABC, + CTableABC, + FederationABC, + StorageSessionABC, + StorageTableABC, + StorageTableMetaABC, +) +from ..common import EngineType, Party, base_utils, engine_utils, log, remote_status +from ..common._parties import PartiesInfo +from ..computing import ComputingEngine +from ..federation import FederationEngine +from ..metastore.db_models import DB, SessionRecord, init_database_tables +from ..storage import StorageEngine, StorageSessionBase LOGGER = log.getLogger() @@ -65,7 +70,11 @@ def __init__(self, session_id: str = None, options=None): self._parties_info: typing.Optional[PartiesInfo] = None self._all_party_info: typing.List[Party] = [] self._session_id = str(uuid.uuid1()) if not session_id else session_id - self._logger = LOGGER if options.get("logger", None) is None else options.get("logger", None) + self._logger = ( + LOGGER + if options.get("logger", None) is None + else options.get("logger", None) + ) self._logger.info(f"create manager session {self._session_id}") @@ -90,29 +99,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._logger.exception("", exc_info=(exc_type, exc_val, exc_tb)) return self._close() - def init_computing(self, - computing_session_id: str = None, - record: bool = True, - 
**kwargs): - computing_session_id = f"{self._session_id}_computing_{uuid.uuid1()}" if not computing_session_id else computing_session_id + def init_computing( + self, computing_session_id: typing.Optional[str] = None, record: bool = True, **kwargs + ): + computing_session_id = ( + f"{self._session_id}_computing_{uuid.uuid1()}" + if not computing_session_id + else computing_session_id + ) if self.is_computing_valid: raise RuntimeError(f"computing session already valid") if record: - self.save_record(engine_type=EngineType.COMPUTING, - engine_name=self._computing_type, - engine_session_id=computing_session_id) + self.save_record( + engine_type=EngineType.COMPUTING, + engine_name=self._computing_type, + engine_session_id=computing_session_id, + ) if self._computing_type == ComputingEngine.STANDALONE: - from fate_arch.computing.standalone import CSession + from ..computing.standalone import CSession options = kwargs.get("options", {}) - self._computing_session = CSession(session_id=computing_session_id, options=options) + self._computing_session = CSession( + session_id=computing_session_id, options=options + ) self._computing_type = ComputingEngine.STANDALONE return self if self._computing_type == ComputingEngine.EGGROLL: - from fate_arch.computing.eggroll import CSession + from ..computing.eggroll import CSession options = kwargs.get("options", {}) self._computing_session = CSession( @@ -121,14 +137,15 @@ def init_computing(self, return self if self._computing_type == ComputingEngine.SPARK: - from fate_arch.computing.spark import CSession + from ..computing.spark import CSession self._computing_session = CSession(session_id=computing_session_id) self._computing_type = ComputingEngine.SPARK return self if self._computing_type == ComputingEngine.LINKIS_SPARK: - from fate_arch.computing.spark import CSession + from ..computing.spark import CSession + self._computing_session = CSession(session_id=computing_session_id) self._computing_type = ComputingEngine.LINKIS_SPARK return self @@ -136,35 +153,39 @@ def init_computing(self, raise RuntimeError(f"{self._computing_type} not supported") def init_federation( - self, - federation_session_id: str, - *, - runtime_conf: typing.Optional[dict] = None, - parties_info: typing.Optional[PartiesInfo] = None, - service_conf: typing.Optional[dict] = None, - record: bool = True, + self, + federation_session_id: str, + *, + runtime_conf: typing.Optional[dict] = None, + parties_info: typing.Optional[PartiesInfo] = None, + service_conf: typing.Optional[dict] = None, + record: bool = True, ): if record: - self.save_record(engine_type=EngineType.FEDERATION, - engine_name=self._federation_type, - engine_session_id=federation_session_id, - engine_runtime_conf={"runtime_conf": runtime_conf, "service_conf": service_conf}) + self.save_record( + engine_type=EngineType.FEDERATION, + engine_name=self._federation_type, + engine_session_id=federation_session_id, + engine_runtime_conf={ + "runtime_conf": runtime_conf, + "service_conf": service_conf, + }, + ) if parties_info is None: if runtime_conf is None: raise RuntimeError(f"`party_info` and `runtime_conf` are both `None`") parties_info = PartiesInfo.from_conf(runtime_conf) self._parties_info = parties_info - self._all_party_info = [Party(k, p) for k, v in runtime_conf['role'].items() for p in v] if self.is_federation_valid: raise RuntimeError("federation session already valid") if self._federation_type == FederationEngine.STANDALONE: - from fate_arch.computing.standalone import CSession - from 
fate_arch.federation.standalone import Federation + from ..computing.standalone import CSession + from ..federation.standalone import Federation if not self.is_computing_valid or not isinstance( - self._computing_session, CSession + self._computing_session, CSession ): raise RuntimeError( f"require computing with type {ComputingEngine.STANDALONE} valid" @@ -178,11 +199,11 @@ def init_federation( return self if self._federation_type == FederationEngine.EGGROLL: - from fate_arch.computing.eggroll import CSession - from fate_arch.federation.eggroll import Federation + from ..computing.eggroll import CSession + from ..federation.eggroll import Federation if not self.is_computing_valid or not isinstance( - self._computing_session, CSession + self._computing_session, CSession ): raise RuntimeError( f"require computing with type {ComputingEngine.EGGROLL} valid" @@ -197,7 +218,7 @@ def init_federation( return self if self._federation_type == FederationEngine.RABBITMQ: - from fate_arch.federation.rabbitmq import Federation + from ..federation.rabbitmq import Federation self._federation_session = Federation.from_conf( federation_session_id=federation_session_id, @@ -209,7 +230,7 @@ def init_federation( # Add pulsar support if self._federation_type == FederationEngine.PULSAR: - from fate_arch.federation.pulsar import Federation + from ..federation.pulsar import Federation self._federation_session = Federation.from_conf( federation_session_id=federation_session_id, @@ -221,12 +242,18 @@ def init_federation( raise RuntimeError(f"{self._federation_type} not supported") - def _get_or_create_storage(self, - storage_session_id=None, - storage_engine=None, - record: bool = True, - **kwargs) -> StorageSessionABC: - storage_session_id = f"{self._session_id}_storage_{uuid.uuid1()}" if not storage_session_id else storage_session_id + def _get_or_create_storage( + self, + storage_session_id=None, + storage_engine=None, + record: bool = True, + **kwargs, + ) -> StorageSessionABC: + storage_session_id = ( + f"{self._session_id}_storage_{uuid.uuid1()}" + if not storage_session_id + else storage_session_id + ) if storage_session_id in self._storage_session: return self._storage_session[storage_session_id] @@ -239,54 +266,87 @@ def _get_or_create_storage(self, return session if record: - self.save_record(engine_type=EngineType.STORAGE, - engine_name=storage_engine, - engine_session_id=storage_session_id) + self.save_record( + engine_type=EngineType.STORAGE, + engine_name=storage_engine, + engine_session_id=storage_session_id, + ) if storage_engine == StorageEngine.EGGROLL: - from fate_arch.storage.eggroll import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.eggroll import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.STANDALONE: - from fate_arch.storage.standalone import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.standalone import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.MYSQL: - from fate_arch.storage.mysql import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.mysql import StorageSession + + storage_session = StorageSession( + 
session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.HDFS: - from fate_arch.storage.hdfs import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.hdfs import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.HIVE: - from fate_arch.storage.hive import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.hive import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.LINKIS_HIVE: - from fate_arch.storage.linkis_hive import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.linkis_hive import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.PATH: - from fate_arch.storage.path import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.path import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.LOCALFS: - from fate_arch.storage.localfs import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.localfs import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) elif storage_engine == StorageEngine.API: - from fate_arch.storage.api import StorageSession - storage_session = StorageSession(session_id=storage_session_id, options=kwargs.get("options", {})) + from ..storage.api import StorageSession + + storage_session = StorageSession( + session_id=storage_session_id, options=kwargs.get("options", {}) + ) else: - raise NotImplementedError(f"can not be initialized with storage engine: {storage_engine}") + raise NotImplementedError( + f"can not be initialized with storage engine: {storage_engine}" + ) self._storage_session[storage_session_id] = storage_session return storage_session - def get_table(self, name, namespace, ignore_disable=False) -> typing.Union[StorageTableABC, None]: + def get_table( + self, name, namespace, ignore_disable=False + ) -> typing.Union[StorageTableABC, None]: meta = Session.get_table_meta(name=name, namespace=namespace) if meta is None: return None @@ -303,17 +363,29 @@ def get_table_meta(cls, name, namespace) -> typing.Union[StorageTableMetaABC, No return meta @classmethod - def persistent(cls, computing_table: CTableABC, namespace, name, schema=None, part_of_data=None, - engine=None, engine_address=None, store_type=None, token: typing.Dict = None) -> StorageTableMetaABC: - return StorageSessionBase.persistent(computing_table=computing_table, - namespace=namespace, - name=name, - schema=schema, - part_of_data=part_of_data, - engine=engine, - engine_address=engine_address, - store_type=store_type, - token=token) + def persistent( + cls, + computing_table: CTableABC, + namespace, + name, + schema=None, + part_of_data=None, + engine=None, + engine_address=None, + store_type=None, + token: typing.Dict = None, + ) -> 
StorageTableMetaABC: + return StorageSessionBase.persistent( + computing_table=computing_table, + namespace=namespace, + name=name, + schema=schema, + part_of_data=part_of_data, + engine=engine, + engine_address=engine_address, + store_type=store_type, + token=token, + ) @property def computing(self) -> CSessionABC: @@ -339,19 +411,26 @@ def is_federation_valid(self): return self._federation_session is not None @DB.connection_context() - def save_record(self, engine_type, engine_name, engine_session_id, engine_runtime_conf=None): + def save_record( + self, engine_type, engine_name, engine_session_id, engine_runtime_conf=None + ): self._logger.info( f"try to save session record for manager {self._session_id}, {engine_type} {engine_name}" - f" {engine_session_id}") + f" {engine_session_id}" + ) session_record = SessionRecord() session_record.f_manager_session_id = self._session_id session_record.f_engine_type = engine_type session_record.f_engine_name = engine_name session_record.f_engine_session_id = engine_session_id - session_record.f_engine_address = engine_runtime_conf if engine_runtime_conf else {} + session_record.f_engine_address = ( + engine_runtime_conf if engine_runtime_conf else {} + ) session_record.f_create_time = base_utils.current_timestamp() - msg = f"save storage session record for manager {self._session_id}, {engine_type} {engine_name} " \ - f"{engine_session_id}" + msg = ( + f"save storage session record for manager {self._session_id}, {engine_type} {engine_name} " + f"{engine_session_id}" + ) try: effect_count = session_record.save(force_insert=True) if effect_count != 1: @@ -362,15 +441,26 @@ def save_record(self, engine_type, engine_name, engine_session_id, engine_runtim raise RuntimeError(f"{msg} exception", e) self._logger.info( f"save session record for manager {self._session_id}, {engine_type} {engine_name} " - f"{engine_session_id} successfully") + f"{engine_session_id} successfully" + ) @DB.connection_context() def delete_session_record(self, engine_session_id, manager_session_id=None): if not manager_session_id: - rows = SessionRecord.delete().where(SessionRecord.f_engine_session_id == engine_session_id).execute() + rows = ( + SessionRecord.delete() + .where(SessionRecord.f_engine_session_id == engine_session_id) + .execute() + ) else: - rows = SessionRecord.delete().where(SessionRecord.f_engine_session_id == engine_session_id, - SessionRecord.f_manager_session_id == manager_session_id).execute() + rows = ( + SessionRecord.delete() + .where( + SessionRecord.f_engine_session_id == engine_session_id, + SessionRecord.f_manager_session_id == manager_session_id, + ) + .execute() + ) if rows > 0: self._logger.info(f"delete session {engine_session_id} record successfully") else: @@ -380,7 +470,9 @@ def delete_session_record(self, engine_session_id, manager_session_id=None): @DB.connection_context() def query_sessions(cls, reverse=None, order_by=None, **kwargs): try: - session_records = SessionRecord.query(reverse=reverse, order_by=order_by, **kwargs) + session_records = SessionRecord.query( + reverse=reverse, order_by=order_by, **kwargs + ) return session_records except BaseException: return [] @@ -388,24 +480,38 @@ def query_sessions(cls, reverse=None, order_by=None, **kwargs): @DB.connection_context() def get_session_from_record(self, **kwargs): self._logger.info(f"query by manager session id {self._session_id}") - session_records = self.query_sessions(manager_session_id=self.session_id, **kwargs) - self._logger.info([session_record.f_engine_session_id for 
session_record in session_records]) + session_records = self.query_sessions( + manager_session_id=self.session_id, **kwargs + ) + self._logger.info( + [session_record.f_engine_session_id for session_record in session_records] + ) for session_record in session_records: try: engine_session_id = session_record.f_engine_session_id if session_record.f_engine_type == EngineType.COMPUTING: - self._init_computing_if_not_valid(computing_session_id=engine_session_id) + self._init_computing_if_not_valid( + computing_session_id=engine_session_id + ) elif session_record.f_engine_type == EngineType.STORAGE: - self._get_or_create_storage(storage_session_id=engine_session_id, - storage_engine=session_record.f_engine_name, - record=False) + self._get_or_create_storage( + storage_session_id=engine_session_id, + storage_engine=session_record.f_engine_name, + record=False, + ) elif session_record.f_engine_type == EngineType.FEDERATION: - self._logger.info(f"engine runtime conf: {session_record.f_engine_address}") - self._init_federation_if_not_valid(federation_session_id=engine_session_id, - engine_runtime_conf=session_record.f_engine_address) + self._logger.info( + f"engine runtime conf: {session_record.f_engine_address}" + ) + self._init_federation_if_not_valid( + federation_session_id=engine_session_id, + engine_runtime_conf=session_record.f_engine_address, + ) except Exception as e: self._logger.info(e) - self.delete_session_record(engine_session_id=session_record.f_engine_session_id) + self.delete_session_record( + engine_session_id=session_record.f_engine_session_id + ) def _init_computing_if_not_valid(self, computing_session_id): if not self.is_computing_valid: @@ -414,7 +520,8 @@ def _init_computing_if_not_valid(self, computing_session_id): elif self._computing_session.session_id != computing_session_id: self._logger.warning( f"manager session had computing session {self._computing_session.session_id} " - f"different with query from db session {computing_session_id}") + f"different with query from db session {computing_session_id}" + ) return False else: # already exists @@ -423,42 +530,61 @@ def _init_computing_if_not_valid(self, computing_session_id): def _init_federation_if_not_valid(self, federation_session_id, engine_runtime_conf): if not self.is_federation_valid: try: - self._logger.info(f"init federation session {federation_session_id} type {self._federation_type}") - self.init_federation(federation_session_id=federation_session_id, - runtime_conf=engine_runtime_conf.get("runtime_conf"), - service_conf=engine_runtime_conf.get("service_conf"), - record=False) - self._logger.info(f"init federation session {federation_session_id} type {self._federation_type} done") + self._logger.info( + f"init federation session {federation_session_id} type {self._federation_type}" + ) + self.init_federation( + federation_session_id=federation_session_id, + runtime_conf=engine_runtime_conf.get("runtime_conf"), + service_conf=engine_runtime_conf.get("service_conf"), + record=False, + ) + self._logger.info( + f"init federation session {federation_session_id} type {self._federation_type} done" + ) return True except Exception as e: self._logger.warning( - f"init federation session {federation_session_id} type {self._federation_type} failed: {e}") + f"init federation session {federation_session_id} type {self._federation_type} failed: {e}" + ) return False elif self._federation_session.session_id != federation_session_id: self._logger.warning( - f"manager session had federation session 
{self._federation_session.session_id} different with query from db session {federation_session_id}") + f"manager session had federation session {self._federation_session.session_id} different with query from db session {federation_session_id}" + ) return False else: # already exists return True def destroy_all_sessions(self, **kwargs): - self._logger.info(f"start destroy manager session {self._session_id} all sessions") + self._logger.info( + f"start destroy manager session {self._session_id} all sessions" + ) self.get_session_from_record(**kwargs) self.destroy_federation_session() self.destroy_storage_session() self.destroy_computing_session() - self._logger.info(f"finish destroy manager session {self._session_id} all sessions") + self._logger.info( + f"finish destroy manager session {self._session_id} all sessions" + ) def destroy_computing_session(self): if self.is_computing_valid: try: - self._logger.info(f"try to destroy computing session {self._computing_session.session_id}") + self._logger.info( + f"try to destroy computing session {self._computing_session.session_id}" + ) self._computing_session.destroy() except Exception as e: - self._logger.info(f"destroy computing session {self._computing_session.session_id} failed", e) + self._logger.info( + f"destroy computing session {self._computing_session.session_id} failed", + e, + ) - self.delete_session_record(engine_session_id=self._computing_session.session_id) + self.delete_session_record( + engine_session_id=self._computing_session.session_id + ) self._computing_session = None def destroy_storage_session(self): @@ -468,7 +594,9 @@ def destroy_storage_session(self): session.destroy() self._logger.info(f"destroy storage session {session_id} successfully") except Exception as e: - self._logger.exception(f"destroy storage session {session_id} failed", e) + self._logger.exception( + f"destroy storage session {session_id} failed", e + ) self.delete_session_record(engine_session_id=session_id) @@ -480,14 +608,19 @@ def destroy_federation_session(self): if self._parties_info.local_party.role != "local": self._logger.info( f"try to destroy federation session {self._federation_session.session_id} type" - f" {EngineType.FEDERATION} role {self._parties_info.local_party.role}") - self._federation_session.destroy(parties=self._all_party_info) - self._logger.info(f"destroy federation session {self._federation_session.session_id} done") + f" {EngineType.FEDERATION} role {self._parties_info.local_party.role}" + ) + self._federation_session.destroy(parties=self._parties_info.all_parties) + self._logger.info( + f"destroy federation session {self._federation_session.session_id} done" + ) except Exception as e: self._logger.info(f"destroy federation failed: {e}") - self.delete_session_record(engine_session_id=self._federation_session.session_id, - manager_session_id=self.session_id) + self.delete_session_record( + engine_session_id=self._federation_session.session_id, + manager_session_id=self.session_id, + ) self._federation_session = None def wait_remote_all_done(self, timeout=None): @@ -515,8 +648,12 @@ def init(session_id, options=None): Session(options=options).as_global().init_computing(session_id) @staticmethod - def parallelize(data: typing.Iterable, partition: int, include_key: bool, **kwargs) -> CTableABC: - return get_computing_session().parallelize(data, partition=partition, include_key=include_key, **kwargs) + def parallelize( + data: typing.Iterable, partition: int, include_key: bool, **kwargs + ) -> CTableABC: + return 
get_computing_session().parallelize( + data, partition=partition, include_key=include_key, **kwargs + ) @staticmethod def stop(): diff --git a/python/fate/arch/storage/__init__.py b/python/fate/arch/storage/__init__.py index 17430ce0e6..7692db3fdc 100644 --- a/python/fate/arch/storage/__init__.py +++ b/python/fate/arch/storage/__init__.py @@ -1,7 +1,17 @@ -from fate_arch.storage._types import StorageTableMetaType, StorageEngine -from fate_arch.storage._types import StandaloneStoreType, EggRollStoreType, \ - HDFSStoreType, MySQLStoreType, \ - PathStoreType, HiveStoreType, LinkisHiveStoreType, LocalFSStoreType, ApiStoreType -from fate_arch.storage._types import DEFAULT_ID_DELIMITER, StorageTableOrigin -from fate_arch.storage._session import StorageSessionBase -from fate_arch.storage._table import StorageTableBase, StorageTableMeta +from ._session import StorageSessionBase +from ._table import StorageTableBase, StorageTableMeta +from ._types import ( + DEFAULT_ID_DELIMITER, + ApiStoreType, + EggRollStoreType, + HDFSStoreType, + HiveStoreType, + LinkisHiveStoreType, + LocalFSStoreType, + MySQLStoreType, + PathStoreType, + StandaloneStoreType, + StorageEngine, + StorageTableMetaType, + StorageTableOrigin, +) diff --git a/python/fate/arch/storage/_session.py b/python/fate/arch/storage/_session.py index b211b8d402..67806496ee 100644 --- a/python/fate/arch/storage/_session.py +++ b/python/fate/arch/storage/_session.py @@ -13,19 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import os.path import typing -from fate_arch.abc import StorageSessionABC, CTableABC -from fate_arch.common import EngineType, engine_utils -from fate_arch.common.data_utils import default_output_fs_path -from fate_arch.common.log import getLogger -from fate_arch.storage._table import StorageTableMeta -from fate_arch.storage._types import StorageEngine, EggRollStoreType, StandaloneStoreType, HDFSStoreType, HiveStoreType, \ - LinkisHiveStoreType, LocalFSStoreType, PathStoreType, StorageTableOrigin -from fate_arch.relation_ship import Relationship -from fate_arch.common.base_utils import current_timestamp - +from ..abc import CTableABC, StorageSessionABC +from ..common import EngineType, engine_utils +from ..common.base_utils import current_timestamp +from ..common.log import getLogger +from ..storage._table import StorageTableMeta +from ..storage._types import ( + EggRollStoreType, + HDFSStoreType, + HiveStoreType, + LinkisHiveStoreType, + LocalFSStoreType, + PathStoreType, + StandaloneStoreType, + StorageEngine, + StorageTableOrigin, +) LOGGER = getLogger() @@ -36,19 +41,27 @@ def __init__(self, session_id, engine): self._engine = engine def create_table(self, address, name, namespace, partitions=None, **kwargs): - table = self.table(address=address, name=name, namespace=namespace, partitions=partitions, **kwargs) + table = self.table( + address=address, + name=name, + namespace=namespace, + partitions=partitions, + **kwargs, + ) table.create_meta(**kwargs) return table def get_table(self, name, namespace): meta = StorageTableMeta(name=name, namespace=namespace) if meta and meta.exists(): - table = self.table(name=meta.get_name(), - namespace=meta.get_namespace(), - address=meta.get_address(), - partitions=meta.get_partitions(), - store_type=meta.get_store_type(), - options=meta.get_options()) + table = self.table( + name=meta.get_name(), + namespace=meta.get_namespace(), + address=meta.get_address(), + partitions=meta.get_partitions(), + 
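(Illustrative sketch, not part of the patch.) A minimal walk through the refactored Session lifecycle above, assuming a standalone deployment; the session ids are arbitrary.

from fate.arch.session import Session

sess = Session(session_id="demo_manager")
sess.init_computing(computing_session_id="demo_manager_computing")
assert sess.is_computing_valid
# ... use sess.computing (a CSessionABC) to create and transform tables ...
sess.destroy_all_sessions()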
store_type=meta.get_store_type(), + options=meta.get_options(), + ) table.meta = meta return table else: @@ -63,56 +76,105 @@ def get_table_meta(cls, name, namespace): return None @classmethod - def persistent(cls, computing_table: CTableABC, namespace, name, schema=None, - part_of_data=None, engine=None, engine_address=None, - store_type=None, token: typing.Dict = None) -> StorageTableMeta: + def persistent( + cls, + computing_table: CTableABC, + namespace, + name, + schema=None, + part_of_data=None, + engine=None, + engine_address=None, + store_type=None, + token: typing.Dict = None, + ) -> StorageTableMeta: + + from ..relation_ship import Relationship if engine: - if engine != StorageEngine.PATH and engine not in Relationship.Computing.get( - computing_table.engine, {}).get(EngineType.STORAGE, {}).get("support", []): - raise Exception(f"storage engine {engine} not supported with computing engine {computing_table.engine}") + if ( + engine != StorageEngine.PATH + and engine + not in Relationship.Computing.get(computing_table.engine, {}) + .get(EngineType.STORAGE, {}) + .get("support", []) + ): + raise Exception( + f"storage engine {engine} not supported with computing engine {computing_table.engine}" + ) else: - engine = Relationship.Computing.get( - computing_table.engine, - {}).get( - EngineType.STORAGE, - {}).get( - "default", - None) + engine = ( + Relationship.Computing.get(computing_table.engine, {}) + .get(EngineType.STORAGE, {}) + .get("default", None) + ) if not engine: - raise Exception(f"can not found {computing_table.engine} default storage engine") + raise Exception( + f"can not found {computing_table.engine} default storage engine" + ) if engine_address is None: # find engine address from service_conf.yaml - engine_address = engine_utils.get_engines_config_from_conf().get(EngineType.STORAGE, {}).get(engine, {}) + engine_address = ( + engine_utils.get_engines_config_from_conf() + .get(EngineType.STORAGE, {}) + .get(engine, {}) + ) address_dict = engine_address.copy() partitions = computing_table.partitions if engine == StorageEngine.STANDALONE: address_dict.update({"name": name, "namespace": namespace}) - store_type = StandaloneStoreType.ROLLPAIR_LMDB if store_type is None else store_type + store_type = ( + StandaloneStoreType.ROLLPAIR_LMDB if store_type is None else store_type + ) elif engine == StorageEngine.EGGROLL: address_dict.update({"name": name, "namespace": namespace}) - store_type = EggRollStoreType.ROLLPAIR_LMDB if store_type is None else store_type + store_type = ( + EggRollStoreType.ROLLPAIR_LMDB if store_type is None else store_type + ) elif engine == StorageEngine.HIVE: address_dict.update({"database": namespace, "name": f"{name}"}) store_type = HiveStoreType.DEFAULT if store_type is None else store_type elif engine == StorageEngine.LINKIS_HIVE: - address_dict.update({"database": None, "name": f"{namespace}_{name}", - "username": token.get("username", "")}) - store_type = LinkisHiveStoreType.DEFAULT if store_type is None else store_type + address_dict.update( + { + "database": None, + "name": f"{namespace}_{name}", + "username": token.get("username", ""), + } + ) + store_type = ( + LinkisHiveStoreType.DEFAULT if store_type is None else store_type + ) elif engine == StorageEngine.HDFS: + from ..common.data_utils import default_output_fs_path if not address_dict.get("path"): - address_dict.update({"path": default_output_fs_path( - name=name, namespace=namespace, prefix=address_dict.get("path_prefix"))}) + address_dict.update( + { + "path": default_output_fs_path( 
+ name=name, + namespace=namespace, + prefix=address_dict.get("path_prefix"), + ) + } + ) store_type = HDFSStoreType.DISK if store_type is None else store_type elif engine == StorageEngine.LOCALFS: + from ..common.data_utils import default_output_fs_path if not address_dict.get("path"): - address_dict.update({"path": default_output_fs_path( - name=name, namespace=namespace, storage_engine=StorageEngine.LOCALFS)}) + address_dict.update( + { + "path": default_output_fs_path( + name=name, + namespace=namespace, + storage_engine=StorageEngine.LOCALFS, + ) + } + ) store_type = LocalFSStoreType.DISK if store_type is None else store_type elif engine == StorageEngine.PATH: @@ -120,9 +182,13 @@ def persistent(cls, computing_table: CTableABC, namespace, name, schema=None, else: raise RuntimeError(f"{engine} storage is not supported") - address = StorageTableMeta.create_address(storage_engine=engine, address_dict=address_dict) + address = StorageTableMeta.create_address( + storage_engine=engine, address_dict=address_dict + ) schema = schema if schema else {} - computing_table.save(address, schema=schema, partitions=partitions, store_type=store_type) + computing_table.save( + address, schema=schema, partitions=partitions, store_type=store_type + ) table_count = computing_table.count() table_meta = StorageTableMeta(name=name, namespace=namespace, new=True) table_meta.address = address @@ -147,7 +213,9 @@ def destroy(self): try: self.stop() except Exception as e: - LOGGER.warning(f"stop storage session {self._session_id} failed, try to kill", e) + LOGGER.warning( + f"stop storage session {self._session_id} failed, try to kill", e + ) self.kill() def table(self, name, namespace, address, store_type, partitions=None, **kwargs): diff --git a/python/fate/arch/storage/_table.py b/python/fate/arch/storage/_table.py index 4e461c171d..f03064b994 100644 --- a/python/fate/arch/storage/_table.py +++ b/python/fate/arch/storage/_table.py @@ -20,17 +20,18 @@ import peewee -from fate_arch.abc import StorageTableMetaABC, StorageTableABC, AddressABC -from fate_arch.common.base_utils import current_timestamp -from fate_arch.common.log import getLogger -from fate_arch.relation_ship import Relationship -from fate_arch.metastore.db_models import DB, StorageTableMetaModel +from ..abc import AddressABC, StorageTableABC, StorageTableMetaABC +from ..common.base_utils import current_timestamp +from ..common.log import getLogger +from ..metastore.db_models import DB, StorageTableMetaModel LOGGER = getLogger() class StorageTableBase(StorageTableABC): - def __init__(self, name, namespace, address, partitions, options, engine, store_type): + def __init__( + self, name, namespace, address, partitions, options, engine, store_type + ): self._name = name self._namespace = namespace self._address = address @@ -87,22 +88,28 @@ def read_access_time(self): def write_access_time(self): return self._write_access_time - def update_meta(self, - schema=None, - count=None, - part_of_data=None, - description=None, - partitions=None, - **kwargs): - self._meta.update_metas(schema=schema, - count=count, - part_of_data=part_of_data, - description=description, - partitions=partitions, - **kwargs) + def update_meta( + self, + schema=None, + count=None, + part_of_data=None, + description=None, + partitions=None, + **kwargs + ): + self._meta.update_metas( + schema=schema, + count=count, + part_of_data=part_of_data, + description=description, + partitions=partitions, + **kwargs + ) def create_meta(self, **kwargs): - table_meta = 
StorageTableMeta(name=self._name, namespace=self._namespace, new=True) + table_meta = StorageTableMeta( + name=self._name, namespace=self._namespace, new=True + ) table_meta.set_metas(**kwargs) table_meta.address = self._address table_meta.partitions = self._partitions @@ -145,11 +152,15 @@ def save_as(self, address, name, namespace, partitions=None, **kwargs): return table def _update_read_access_time(self, read_access_time=None): - read_access_time = current_timestamp() if not read_access_time else read_access_time + read_access_time = ( + current_timestamp() if not read_access_time else read_access_time + ) self._meta.update_metas(read_access_time=read_access_time) def _update_write_access_time(self, write_access_time=None): - write_access_time = current_timestamp() if not write_access_time else write_access_time + write_access_time = ( + current_timestamp() if not write_access_time else write_access_time + ) self._meta.update_metas(write_access_time=write_access_time) # to be implemented @@ -168,12 +179,13 @@ def _read(self): def _destroy(self): raise NotImplementedError() - def _save_as(self, address, name, namespace, partitions=None, schema=None, **kwargs): + def _save_as( + self, address, name, namespace, partitions=None, schema=None, **kwargs + ): raise NotImplementedError() class StorageTableMeta(StorageTableMetaABC): - def __init__(self, name, namespace, new=False, create_address=True): self.name = name self.namespace = namespace @@ -210,14 +222,18 @@ def build(self, create_address): for k, v in self.table_meta.__dict__["__data__"].items(): setattr(self, k.lstrip("f_"), v) if create_address: - self.address = self.create_address(storage_engine=self.engine, address_dict=self.address) + self.address = self.create_address( + storage_engine=self.engine, address_dict=self.address + ) def __new__(cls, *args, **kwargs): if not kwargs.get("new", False): name, namespace = kwargs.get("name"), kwargs.get("namespace") if not name or not namespace: return None - tables_meta = cls.query_table_meta(filter_fields=dict(name=name, namespace=namespace)) + tables_meta = cls.query_table_meta( + filter_fields=dict(name=name, namespace=namespace) + ) if not tables_meta: return None self = super().__new__(cls) @@ -239,9 +255,13 @@ def create(self): table_meta.f_schema = {} table_meta.f_part_of_data = [] for k, v in self.to_dict().items(): - attr_name = 'f_%s' % k + attr_name = "f_%s" % k if hasattr(StorageTableMetaModel, attr_name): - setattr(table_meta, attr_name, v if not issubclass(type(v), AddressABC) else v.__dict__) + setattr( + table_meta, + attr_name, + v if not issubclass(type(v), AddressABC) else v.__dict__, + ) try: rows = table_meta.save(force_insert=True) if rows != 1: @@ -268,14 +288,18 @@ def query_table_meta(cls, filter_fields, query_fields=None): filters = [] querys = [] for f_n, f_v in filter_fields.items(): - attr_name = 'f_%s' % f_n + attr_name = "f_%s" % f_n if hasattr(StorageTableMetaModel, attr_name): - filters.append(operator.attrgetter('f_%s' % f_n)(StorageTableMetaModel) == f_v) + filters.append( + operator.attrgetter("f_%s" % f_n)(StorageTableMetaModel) == f_v + ) if query_fields: for f_n in query_fields: - attr_name = 'f_%s' % f_n + attr_name = "f_%s" % f_n if hasattr(StorageTableMetaModel, attr_name): - querys.append(operator.attrgetter('f_%s' % f_n)(StorageTableMetaModel)) + querys.append( + operator.attrgetter("f_%s" % f_n)(StorageTableMetaModel) + ) if filters: if querys: tables_meta = StorageTableMetaModel.select(querys).where(*filters) @@ -287,8 +311,16 @@ def 
query_table_meta(cls, filter_fields, query_fields=None): return [] @DB.connection_context() - def update_metas(self, schema=None, count=None, part_of_data=None, description=None, partitions=None, - in_serialized=None, **kwargs): + def update_metas( + self, + schema=None, + count=None, + part_of_data=None, + description=None, + partitions=None, + in_serialized=None, + **kwargs + ): meta_info = {} for k, v in locals().items(): if k not in ["self", "kwargs", "meta_info"] and v is not None: @@ -299,20 +331,30 @@ def update_metas(self, schema=None, count=None, part_of_data=None, description=N update_filters = [] primary_keys = StorageTableMetaModel._meta.primary_key.field_names for p_k in primary_keys: - update_filters.append(operator.attrgetter(p_k)(StorageTableMetaModel) == meta_info[p_k.lstrip("f_")]) + update_filters.append( + operator.attrgetter(p_k)(StorageTableMetaModel) + == meta_info[p_k.lstrip("f_")] + ) table_meta = StorageTableMetaModel() update_fields = {} for k, v in meta_info.items(): - attr_name = 'f_%s' % k - if hasattr(StorageTableMetaModel, attr_name) and attr_name not in primary_keys: + attr_name = "f_%s" % k + if ( + hasattr(StorageTableMetaModel, attr_name) + and attr_name not in primary_keys + ): if k == "part_of_data": if len(v) < 100: tmp = v else: tmp = v[:100] - update_fields[operator.attrgetter(attr_name)(StorageTableMetaModel)] = tmp + update_fields[ + operator.attrgetter(attr_name)(StorageTableMetaModel) + ] = tmp else: - update_fields[operator.attrgetter(attr_name)(StorageTableMetaModel)] = v + update_fields[ + operator.attrgetter(attr_name)(StorageTableMetaModel) + ] = v if update_filters: operate = table_meta.update(update_fields).where(*update_filters) else: @@ -325,14 +367,14 @@ def update_metas(self, schema=None, count=None, part_of_data=None, description=N @DB.connection_context() def destroy_metas(self): - StorageTableMetaModel \ - .delete() \ - .where(StorageTableMetaModel.f_name == self.name, - StorageTableMetaModel.f_namespace == self.namespace) \ - .execute() + StorageTableMetaModel.delete().where( + StorageTableMetaModel.f_name == self.name, + StorageTableMetaModel.f_namespace == self.namespace, + ).execute() @classmethod def create_address(cls, storage_engine, address_dict): + from ..relation_ship import Relationship address_class = Relationship.EngineToAddress.get(storage_engine) kwargs = {} for k in address_class.__init__.__code__.co_varnames: diff --git a/python/fate/arch/storage/_types.py b/python/fate/arch/storage/_types.py index ae9fb2f5b3..c8f9c85ba5 100644 --- a/python/fate/arch/storage/_types.py +++ b/python/fate/arch/storage/_types.py @@ -24,55 +24,55 @@ class StorageTableOrigin(object): class StorageEngine(object): - STANDALONE = 'STANDALONE' - EGGROLL = 'EGGROLL' - HDFS = 'HDFS' - MYSQL = 'MYSQL' - SIMPLE = 'SIMPLE' - PATH = 'PATH' - HIVE = 'HIVE' - LINKIS_HIVE = 'LINKIS_HIVE' - LOCALFS = 'LOCALFS' - API = 'API' + STANDALONE = "STANDALONE" + EGGROLL = "EGGROLL" + HDFS = "HDFS" + MYSQL = "MYSQL" + SIMPLE = "SIMPLE" + PATH = "PATH" + HIVE = "HIVE" + LINKIS_HIVE = "LINKIS_HIVE" + LOCALFS = "LOCALFS" + API = "API" class StandaloneStoreType(object): - ROLLPAIR_IN_MEMORY = 'IN_MEMORY' - ROLLPAIR_LMDB = 'LMDB' + ROLLPAIR_IN_MEMORY = "IN_MEMORY" + ROLLPAIR_LMDB = "LMDB" DEFAULT = ROLLPAIR_LMDB class EggRollStoreType(object): - ROLLPAIR_IN_MEMORY = 'IN_MEMORY' - ROLLPAIR_LMDB = 'LMDB' - ROLLPAIR_LEVELDB = 'LEVEL_DB' - ROLLFRAME_FILE = 'ROLL_FRAME_FILE' - ROLLPAIR_ROLLSITE = 'ROLL_SITE' - ROLLPAIR_FILE = 'ROLL_PAIR_FILE' - ROLLPAIR_MMAP = 
'ROLL_PAIR_MMAP' - ROLLPAIR_CACHE = 'ROLL_PAIR_CACHE' - ROLLPAIR_QUEUE = 'ROLL_PAIR_QUEUE' + ROLLPAIR_IN_MEMORY = "IN_MEMORY" + ROLLPAIR_LMDB = "LMDB" + ROLLPAIR_LEVELDB = "LEVEL_DB" + ROLLFRAME_FILE = "ROLL_FRAME_FILE" + ROLLPAIR_ROLLSITE = "ROLL_SITE" + ROLLPAIR_FILE = "ROLL_PAIR_FILE" + ROLLPAIR_MMAP = "ROLL_PAIR_MMAP" + ROLLPAIR_CACHE = "ROLL_PAIR_CACHE" + ROLLPAIR_QUEUE = "ROLL_PAIR_QUEUE" DEFAULT = ROLLPAIR_LMDB class HDFSStoreType(object): - RAM_DISK = 'RAM_DISK' - SSD = 'SSD' - DISK = 'DISK' - ARCHIVE = 'ARCHIVE' + RAM_DISK = "RAM_DISK" + SSD = "SSD" + DISK = "DISK" + ARCHIVE = "ARCHIVE" DEFAULT = None class PathStoreType(object): - PICTURE = 'PICTURE' + PICTURE = "PICTURE" class FileStoreType(object): - CSV = 'CSV' + CSV = "CSV" class ApiStoreType(object): - EXTERNAL = 'EXTERNAL' + EXTERNAL = "EXTERNAL" class MySQLStoreType(object): @@ -92,10 +92,10 @@ class LinkisHiveStoreType(object): class LocalFSStoreType(object): - RAM_DISK = 'RAM_DISK' - SSD = 'SSD' - DISK = 'DISK' - ARCHIVE = 'ARCHIVE' + RAM_DISK = "RAM_DISK" + SSD = "SSD" + DISK = "DISK" + ARCHIVE = "ARCHIVE" DEFAULT = None diff --git a/python/fate/arch/storage/_utils.py b/python/fate/arch/storage/_utils.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate/arch/storage/api/__init__.py b/python/fate/arch/storage/api/__init__.py index 19fb78f24f..01f9bd810d 100644 --- a/python/fate/arch/storage/api/__init__.py +++ b/python/fate/arch/storage/api/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_arch.storage.api._table import StorageTable -from fate_arch.storage.api._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/api/_session.py b/python/fate/arch/storage/api/_session.py index cf673c1fb0..dccea39d12 100644 --- a/python/fate/arch/storage/api/_session.py +++ b/python/fate/arch/storage/api/_session.py @@ -14,29 +14,47 @@ # limitations under the License. 
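
The engine-selection logic in persistent() above chains Relationship.Computing lookups to validate or pick a storage engine, then falls back to the per-engine store-type defaults defined in _types.py. A minimal, self-contained sketch of that resolution, assuming only that the mapping keeps a {computing engine -> {storage: {"support": [...], "default": ...}}} shape; the toy table below is illustrative, not the real fate.arch.relation_ship data:

    # Illustrative stand-in for Relationship.Computing; keys and values are assumptions.
    COMPUTING_TO_STORAGE = {
        "STANDALONE": {"storage": {"support": ["STANDALONE"], "default": "STANDALONE"}},
        "EGGROLL": {"storage": {"support": ["EGGROLL", "HDFS"], "default": "EGGROLL"}},
    }

    def resolve_storage_engine(computing_engine: str, requested: str = None) -> str:
        cfg = COMPUTING_TO_STORAGE.get(computing_engine, {}).get("storage", {})
        if requested:
            # An explicitly requested engine must be PATH or appear in the support list.
            if requested != "PATH" and requested not in cfg.get("support", []):
                raise Exception(
                    f"storage engine {requested} not supported with computing engine {computing_engine}"
                )
            return requested
        default = cfg.get("default")
        if not default:
            raise Exception(f"cannot find {computing_engine} default storage engine")
        return default

For example, resolve_storage_engine("EGGROLL") returns "EGGROLL", while resolve_storage_engine("EGGROLL", "MYSQL") raises, mirroring the two branches of the refactored code.
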
# import os -import shutil -import traceback -from fate_arch.common import file_utils -from fate_arch.storage import StorageSessionBase, StorageEngine -from fate_arch.abc import AddressABC -from fate_arch.common.address import ApiAddress +from ...abc import AddressABC +from ...common import file_utils +from ...common.address import ApiAddress +from ...storage import StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.PATH) - self.base_dir = os.path.join(file_utils.get_project_base_directory(), "api_data", session_id) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.PATH + ) + self.base_dir = os.path.join( + file_utils.get_project_base_directory(), "api_data", session_id + ) - def table(self, address: AddressABC, name, namespace, partitions, store_type=None, options=None, **kwargs): + def table( + self, + address: AddressABC, + name, + namespace, + partitions, + store_type=None, + options=None, + **kwargs, + ): if isinstance(address, ApiAddress): - from fate_arch.storage.api._table import StorageTable - return StorageTable(path=os.path.join(self.base_dir, namespace, name), - address=address, - name=name, - namespace=namespace, - partitions=partitions, store_type=store_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with api storage") + from ._table import StorageTable + + return StorageTable( + path=os.path.join(self.base_dir, namespace, name), + address=address, + name=name, + namespace=namespace, + partitions=partitions, + store_type=store_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with api storage" + ) def cleanup(self, name, namespace): # path = os.path.join(self.base_dir, namespace, name) diff --git a/python/fate/arch/storage/api/_table.py b/python/fate/arch/storage/api/_table.py index c2c239d220..ef212bae38 100644 --- a/python/fate/arch/storage/api/_table.py +++ b/python/fate/arch/storage/api/_table.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import os from contextlib import closing import requests -import os -from fate_arch.common.log import getLogger -from fate_arch.storage import StorageEngine, ApiStoreType -from fate_arch.storage import StorageTableBase +from ...common.log import getLogger +from ...storage import ApiStoreType, StorageEngine, StorageTableBase LOGGER = getLogger() @@ -51,11 +50,17 @@ def __init__( def _collect(self, **kwargs) -> list: self.request = getattr(requests, self.address.method.lower(), None) id_delimiter = self._meta.get_id_delimiter() - with closing(self.request(url=self.address.url, json=self.address.body, headers=self.address.header, - stream=True)) as response: + with closing( + self.request( + url=self.address.url, + json=self.address.body, + headers=self.address.header, + stream=True, + ) + ) as response: if response.status_code == 200: os.makedirs(os.path.dirname(self.path), exist_ok=True) - with open(self.path, 'wb') as fw: + with open(self.path, "wb") as fw: for chunk in response.iter_content(1024): if chunk: fw.write(chunk) @@ -66,10 +71,14 @@ def _collect(self, **kwargs) -> list: for line in lines: self.data_count += 1 id = line.split(id_delimiter)[0] - feature = id_delimiter.join(line.split(id_delimiter)[1:]) + feature = id_delimiter.join( + line.split(id_delimiter)[1:] + ) yield id, feature else: - _, self._meta = self._meta.update_metas(count=self.data_count) + _, self._meta = self._meta.update_metas( + count=self.data_count + ) break else: raise Exception(response.status_code, response.text) diff --git a/python/fate/arch/storage/eggroll/__init__.py b/python/fate/arch/storage/eggroll/__init__.py index d78427e173..01f9bd810d 100644 --- a/python/fate/arch/storage/eggroll/__init__.py +++ b/python/fate/arch/storage/eggroll/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_arch.storage.eggroll._table import StorageTable -from fate_arch.storage.eggroll._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/eggroll/_session.py b/python/fate/arch/storage/eggroll/_session.py index 7c1b9802bd..3b58ec700a 100644 --- a/python/fate/arch/storage/eggroll/_session.py +++ b/python/fate/arch/storage/eggroll/_session.py @@ -14,30 +14,52 @@ # limitations under the License. 
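
The API table's _collect above streams an HTTP response to a local file and then yields (id, feature) pairs split on the meta's id delimiter. A condensed, self-contained sketch of that pattern; the url, path, and delimiter are placeholders, and requests.post stands in for the method the real code picks via getattr:

    import os
    from contextlib import closing

    import requests

    def stream_kv(url, path, delimiter=","):
        # Stream the response body to disk in 1 KiB chunks, as _collect does.
        with closing(requests.post(url, stream=True)) as response:
            response.raise_for_status()
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as fw:
                for chunk in response.iter_content(1024):
                    if chunk:
                        fw.write(chunk)
        # Re-read the file line by line and split each record into (id, feature).
        with open(path, encoding="utf-8") as fr:
            for line in fr:
                line = line.rstrip("\n")
                if line:
                    key, _, feature = line.partition(delimiter)
                    yield key, feature
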
# -from fate_arch.storage import StorageSessionBase, StorageEngine, EggRollStoreType -from fate_arch.abc import AddressABC -from fate_arch.common.address import EggRollAddress from eggroll.core.session import session_init from eggroll.roll_pair.roll_pair import RollPairContext +from ...abc import AddressABC +from ...common.address import EggRollAddress +from ...storage import EggRollStoreType, StorageEngine, StorageSessionBase + class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.EGGROLL) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.EGGROLL + ) self._options = options if options else {} - self._options['eggroll.session.deploy.mode'] = "cluster" - self._rp_session = session_init(session_id=self._session_id, options=self._options) + self._options["eggroll.session.deploy.mode"] = "cluster" + self._rp_session = session_init( + session_id=self._session_id, options=self._options + ) self._rpc = RollPairContext(session=self._rp_session) self._session_id = self._rp_session.get_session_id() - def table(self, name, namespace, - address: AddressABC, partitions, - store_type: EggRollStoreType = EggRollStoreType.ROLLPAIR_LMDB, options=None, **kwargs): + def table( + self, + name, + namespace, + address: AddressABC, + partitions, + store_type: EggRollStoreType = EggRollStoreType.ROLLPAIR_LMDB, + options=None, + **kwargs, + ): if isinstance(address, EggRollAddress): - from fate_arch.storage.eggroll._table import StorageTable - return StorageTable(context=self._rpc, name=name, namespace=namespace, address=address, - partitions=partitions, store_type=store_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll storage") + from ._table import StorageTable + + return StorageTable( + context=self._rpc, + name=name, + namespace=namespace, + address=address, + partitions=partitions, + store_type=store_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with eggroll storage" + ) def cleanup(self, name, namespace): self._rpc.cleanup(name=name, namespace=namespace) diff --git a/python/fate/arch/storage/eggroll/_table.py b/python/fate/arch/storage/eggroll/_table.py index ad096827f6..7b2d5130bb 100644 --- a/python/fate/arch/storage/eggroll/_table.py +++ b/python/fate/arch/storage/eggroll/_table.py @@ -15,7 +15,8 @@ # from typing import Iterable -from fate_arch.storage import StorageTableBase, StorageEngine, EggRollStoreType + +from ...storage import EggRollStoreType, StorageEngine, StorageTableBase class StorageTable(StorageTableBase): @@ -54,7 +55,7 @@ def _save_as(self, address, name, namespace, partitions=None, **kwargs): address=address, partitions=partitions, name=name, - namespace=namespace + namespace=namespace, ) return table diff --git a/python/fate/arch/storage/hdfs/__init__.py b/python/fate/arch/storage/hdfs/__init__.py index 428baf0a05..01f9bd810d 100644 --- a/python/fate/arch/storage/hdfs/__init__.py +++ b/python/fate/arch/storage/hdfs/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
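
Every engine-specific session in this patch implements table() with the same shape: check the concrete address type, import the matching StorageTable lazily inside the branch (so an engine's dependencies are only pulled in when that engine is actually used), and otherwise fail with NotImplementedError. A toy, self-contained version of that convention; the address and session classes here are illustrative, not the fate.arch ones:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class LocalAddress:
        path: str

    class DemoStorageSession:
        def table(self, address, name, namespace, partitions, store_type=None, options=None, **kwargs):
            if isinstance(address, LocalAddress):
                # The real sessions defer `from ._table import StorageTable` to this point.
                return {
                    "name": name,
                    "namespace": namespace,
                    "path": address.path,
                    "partitions": partitions,
                    "store_type": store_type,
                }
            raise NotImplementedError(
                f"address type {type(address)} not supported with this storage"
            )

DemoStorageSession().table(LocalAddress("/tmp/t1"), "t1", "ns", partitions=4) returns the populated mapping; any other address type raises, which is how the eggroll, HDFS, and Hive sessions reject mismatched addresses.
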
# -from fate_arch.storage.hdfs._table import StorageTable -from fate_arch.storage.hdfs._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/hdfs/_session.py b/python/fate/arch/storage/hdfs/_session.py index 923e5d6263..a51a11b70d 100644 --- a/python/fate/arch/storage/hdfs/_session.py +++ b/python/fate/arch/storage/hdfs/_session.py @@ -14,21 +14,41 @@ # limitations under the License. # -from fate_arch.storage import StorageSessionBase, StorageEngine -from fate_arch.abc import AddressABC -from fate_arch.common.address import HDFSAddress +from ...abc import AddressABC +from ...common.address import HDFSAddress +from ...storage import StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.HDFS) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.HDFS + ) - def table(self, address: AddressABC, name, namespace, partitions, store_type=None, options=None, **kwargs): + def table( + self, + address: AddressABC, + name, + namespace, + partitions, + store_type=None, + options=None, + **kwargs, + ): if isinstance(address, HDFSAddress): - from fate_arch.storage.hdfs._table import StorageTable - return StorageTable(address=address, name=name, namespace=namespace, - partitions=partitions, store_type=store_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with hdfs storage") + from ._table import StorageTable + + return StorageTable( + address=address, + name=name, + namespace=namespace, + partitions=partitions, + store_type=store_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with hdfs storage" + ) def cleanup(self, name, namespace): pass diff --git a/python/fate/arch/storage/hdfs/_table.py b/python/fate/arch/storage/hdfs/_table.py index f7cfe6ad84..50ddc65af7 100644 --- a/python/fate/arch/storage/hdfs/_table.py +++ b/python/fate/arch/storage/hdfs/_table.py @@ -18,23 +18,22 @@ from pyarrow import fs -from fate_arch.common import hdfs_utils -from fate_arch.common.log import getLogger -from fate_arch.storage import StorageEngine, HDFSStoreType -from fate_arch.storage import StorageTableBase +from ...common import hdfs_utils +from ...common.log import getLogger +from ...storage import HDFSStoreType, StorageEngine, StorageTableBase LOGGER = getLogger() class StorageTable(StorageTableBase): def __init__( - self, - address=None, - name: str = None, - namespace: str = None, - partitions: int = 1, - store_type: HDFSStoreType = HDFSStoreType.DISK, - options=None, + self, + address=None, + name: str = None, + namespace: str = None, + partitions: int = 1, + store_type: HDFSStoreType = HDFSStoreType.DISK, + options=None, ): super(StorageTable, self).__init__( name=name, @@ -58,7 +57,7 @@ def check_address(self): return self._exist() def _put_all( - self, kv_list: Iterable, append=True, assume_file_exist=False, **kwargs + self, kv_list: Iterable, append=True, assume_file_exist=False, **kwargs ): LOGGER.info(f"put in hdfs file: {self.file_path}") if append and (assume_file_exist or self._exist()): @@ -97,9 +96,7 @@ def _count(self): count += 1 return count - def _save_as( - self, address, partitions=None, name=None, namespace=None, **kwargs - ): + def _save_as(self, address, partitions=None, name=None, 
namespace=None, **kwargs): self._hdfs_client.copy_file(src=self.file_path, dst=address.path) table = StorageTable( address=address, @@ -144,8 +141,8 @@ def _as_generator(self): file_info.is_file ), f"{self.path} is directory contains a subdirectory: {file_info.path}" with io.TextIOWrapper( - buffer=self._hdfs_client.open_input_stream(file_info.path), - encoding="utf-8", + buffer=self._hdfs_client.open_input_stream(file_info.path), + encoding="utf-8", ) as reader: for line in reader: yield line @@ -185,6 +182,8 @@ def _read_buffer_lines(self, path=None): offset += len(buffer_block[:end_index]) def _read_lines(self, buffer_block): - with io.TextIOWrapper(buffer=io.BytesIO(buffer_block), encoding="utf-8") as reader: + with io.TextIOWrapper( + buffer=io.BytesIO(buffer_block), encoding="utf-8" + ) as reader: for line in reader: yield line diff --git a/python/fate/arch/storage/hive/__init__.py b/python/fate/arch/storage/hive/__init__.py index cf92370dda..39d4b8ca70 100644 --- a/python/fate/arch/storage/hive/__init__.py +++ b/python/fate/arch/storage/hive/__init__.py @@ -1,4 +1,4 @@ -from fate_arch.storage.hive._table import StorageTable -from fate_arch.storage.hive._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/hive/_session.py b/python/fate/arch/storage/hive/_session.py index 6b596354fc..b337253a96 100644 --- a/python/fate/arch/storage/hive/_session.py +++ b/python/fate/arch/storage/hive/_session.py @@ -17,21 +17,32 @@ from impala.dbapi import connect -from fate_arch.common.address import HiveAddress -from fate_arch.storage import StorageSessionBase, StorageEngine, HiveStoreType -from fate_arch.abc import AddressABC +from ...abc import AddressABC +from ...common.address import HiveAddress +from ...storage import HiveStoreType, StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.HIVE) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.HIVE + ) self._db_con = {} - def table(self, name, namespace, address: AddressABC, partitions, - storage_type: HiveStoreType = HiveStoreType.DEFAULT, options=None, **kwargs): + def table( + self, + name, + namespace, + address: AddressABC, + partitions, + storage_type: HiveStoreType = HiveStoreType.DEFAULT, + options=None, + **kwargs, + ): if isinstance(address, HiveAddress): - from fate_arch.storage.hive._table import StorageTable + from ...storage.hive._table import StorageTable + address_key = HiveAddress( host=address.host, username=None, @@ -39,23 +50,35 @@ def table(self, name, namespace, address: AddressABC, partitions, database=address.database, auth_mechanism=None, password=None, - name=None) + name=None, + ) if address_key in self._db_con: con, cur = self._db_con[address_key] else: self._create_db_if_not_exists(address) - con = connect(host=address.host, - port=address.port, - database=address.database, - auth_mechanism=address.auth_mechanism, - password=address.password, - user=address.username - ) + con = connect( + host=address.host, + port=address.port, + database=address.database, + auth_mechanism=address.auth_mechanism, + password=address.password, + user=address.username, + ) cur = con.cursor() self._db_con[address_key] = (con, cur) - return StorageTable(cur=cur, con=con, address=address, name=name, namespace=namespace, - 
storage_type=storage_type, partitions=partitions, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll storage") + return StorageTable( + cur=cur, + con=con, + address=address, + name=name, + namespace=namespace, + storage_type=storage_type, + partitions=partitions, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with eggroll storage" + ) def cleanup(self, name, namespace): pass @@ -74,14 +97,17 @@ def kill(self): return self.stop() def _create_db_if_not_exists(self, address): - connection = connect(host=address.host, - port=address.port, - user=address.username, - auth_mechanism=address.auth_mechanism, - password=address.password - ) + connection = connect( + host=address.host, + port=address.port, + user=address.username, + auth_mechanism=address.auth_mechanism, + password=address.password, + ) with connection: with connection.cursor() as cursor: - cursor.execute("create database if not exists {}".format(address.database)) - print('create db {} success'.format(address.database)) + cursor.execute( + "create database if not exists {}".format(address.database) + ) + print("create db {} success".format(address.database)) connection.commit() diff --git a/python/fate/arch/storage/hive/_table.py b/python/fate/arch/storage/hive/_table.py index bb577930c2..4929498ad2 100644 --- a/python/fate/arch/storage/hive/_table.py +++ b/python/fate/arch/storage/hive/_table.py @@ -16,10 +16,9 @@ import os import uuid -from fate_arch.common import hive_utils -from fate_arch.common.file_utils import get_project_base_directory -from fate_arch.storage import StorageEngine, HiveStoreType -from fate_arch.storage import StorageTableBase +from ...common import hive_utils +from ...common.file_utils import get_project_base_directory +from ...storage import HiveStoreType, StorageEngine, StorageTableBase class StorageTable(StorageTableBase): @@ -60,7 +59,7 @@ def execute(self, sql, select=True): return result def _count(self, **kwargs): - sql = 'select count(*) from {}'.format(self._address.name) + sql = "select count(*) from {}".format(self._address.name) try: self._cur.execute(sql) self._con.commit() @@ -87,23 +86,29 @@ def _read(self) -> list: def _put_all(self, kv_list, **kwargs): id_name, feature_name_list, id_delimiter = self.get_id_feature_name() - create_table = "create table if not exists {}(k varchar(128) NOT NULL, v string) row format delimited fields terminated by" \ - " '{}'".format(self._address.name, id_delimiter) + create_table = ( + "create table if not exists {}(k varchar(128) NOT NULL, v string) row format delimited fields terminated by" + " '{}'".format(self._address.name, id_delimiter) + ) self._cur.execute(create_table) # load local file or hdfs file - temp_path = os.path.join(get_project_base_directory(), 'temp_data', uuid.uuid1().hex) + temp_path = os.path.join( + get_project_base_directory(), "temp_data", uuid.uuid1().hex + ) os.makedirs(os.path.dirname(temp_path), exist_ok=True) - with open(temp_path, 'w') as f: + with open(temp_path, "w") as f: for k, v in kv_list: f.write(hive_utils.serialize_line(k, v)) - sql = "load data local inpath '{}' into table {}".format(temp_path, self._address.name) + sql = "load data local inpath '{}' into table {}".format( + temp_path, self._address.name + ) self._cur.execute(sql) self._con.commit() os.remove(temp_path) def get_id_feature_name(self): - id = self.meta.get_schema().get('sid', 'id') - header = self.meta.get_schema().get('header') + id = 
self.meta.get_schema().get("sid", "id") + header = self.meta.get_schema().get("header") id_delimiter = self.meta.get_id_delimiter() if header: if isinstance(header, str): @@ -121,13 +126,17 @@ def _destroy(self): return self.execute(sql) def _save_as(self, address, name, namespace, partitions=None, **kwargs): - sql = "create table {}.{} like {}.{};".format(namespace, name, self._namespace, self._name) + sql = "create table {}.{} like {}.{};".format( + namespace, name, self._namespace, self._name + ) return self.execute(sql) def check_address(self): schema = self.meta.get_schema() if schema: - sql = 'SELECT {},{} FROM {}'.format(schema.get('sid'), schema.get('header'), self._address.name) + sql = "SELECT {},{} FROM {}".format( + schema.get("sid"), schema.get("header"), self._address.name + ) feature_data = self.execute(sql) for feature in feature_data: if feature: @@ -136,11 +145,11 @@ def check_address(self): @staticmethod def get_meta_header(feature_name_list): - create_features = '' + create_features = "" feature_list = [] feature_size = "varchar(255)" for feature_name in feature_name_list: - create_features += '{} {},'.format(feature_name, feature_size) + create_features += "{} {},".format(feature_name, feature_size) feature_list.append(feature_name) return create_features, feature_list diff --git a/python/fate/arch/storage/linkis_hive/__init__.py b/python/fate/arch/storage/linkis_hive/__init__.py index 83c287bca4..39d4b8ca70 100644 --- a/python/fate/arch/storage/linkis_hive/__init__.py +++ b/python/fate/arch/storage/linkis_hive/__init__.py @@ -1,4 +1,4 @@ -from fate_arch.storage.linkis_hive._table import StorageTable -from fate_arch.storage.linkis_hive._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/linkis_hive/_session.py b/python/fate/arch/storage/linkis_hive/_session.py index 09ca1b006a..6f2cb0f40b 100644 --- a/python/fate/arch/storage/linkis_hive/_session.py +++ b/python/fate/arch/storage/linkis_hive/_session.py @@ -13,31 +13,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from fate_arch.common.address import LinkisHiveAddress -from fate_arch.storage import StorageSessionBase, StorageEngine, LinkisHiveStoreType -from fate_arch.abc import AddressABC +from ...abc import AddressABC +from ...common.address import LinkisHiveAddress +from ...storage import LinkisHiveStoreType, StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.LINKIS_HIVE) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.LINKIS_HIVE + ) self.con = None self.cur = None self.address = None - def table(self, name, namespace, address: AddressABC, partitions, - storage_type: LinkisHiveStoreType = LinkisHiveStoreType.DEFAULT, options=None, **kwargs): + def table( + self, + name, + namespace, + address: AddressABC, + partitions, + storage_type: LinkisHiveStoreType = LinkisHiveStoreType.DEFAULT, + options=None, + **kwargs, + ): self.address = address if isinstance(address, LinkisHiveAddress): - from fate_arch.storage.linkis_hive._table import StorageTable + from ...storage.linkis_hive._table import StorageTable + return StorageTable( address=address, name=name, namespace=namespace, storage_type=storage_type, partitions=partitions, - options=options) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll storage") + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with eggroll storage" + ) def cleanup(self, name, namespace): pass diff --git a/python/fate/arch/storage/linkis_hive/_table.py b/python/fate/arch/storage/linkis_hive/_table.py index 00e5004952..470cdb463d 100644 --- a/python/fate/arch/storage/linkis_hive/_table.py +++ b/python/fate/arch/storage/linkis_hive/_table.py @@ -17,14 +17,8 @@ import requests -from fate_arch.storage import StorageEngine, LinkisHiveStoreType -from fate_arch.storage import StorageTableBase -from fate_arch.storage.linkis_hive._settings import ( - Token_Code, - Token_User, - STATUS_URI, - EXECUTE_URI, -) +from ...storage import LinkisHiveStoreType, StorageEngine, StorageTableBase +from ._settings import EXECUTE_URI, STATUS_URI, Token_Code, Token_User class StorageTable(StorageTableBase): diff --git a/python/fate/arch/storage/localfs/__init__.py b/python/fate/arch/storage/localfs/__init__.py index 2f5a9c8339..39d4b8ca70 100644 --- a/python/fate/arch/storage/localfs/__init__.py +++ b/python/fate/arch/storage/localfs/__init__.py @@ -1,4 +1,4 @@ -from fate_arch.storage.localfs._table import StorageTable -from fate_arch.storage.localfs._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/localfs/_session.py b/python/fate/arch/storage/localfs/_session.py index fff465074b..d0ffbee2d1 100644 --- a/python/fate/arch/storage/localfs/_session.py +++ b/python/fate/arch/storage/localfs/_session.py @@ -14,21 +14,41 @@ # limitations under the License. 
# -from fate_arch.storage import StorageSessionBase, StorageEngine -from fate_arch.abc import AddressABC -from fate_arch.common.address import LocalFSAddress +from ...abc import AddressABC +from ...common.address import LocalFSAddress +from ...storage import StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.LOCALFS) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.LOCALFS + ) - def table(self, address: AddressABC, name, namespace, partitions, storage_type=None, options=None, **kwargs): + def table( + self, + address: AddressABC, + name, + namespace, + partitions, + storage_type=None, + options=None, + **kwargs, + ): if isinstance(address, LocalFSAddress): - from fate_arch.storage.localfs._table import StorageTable - return StorageTable(address=address, name=name, namespace=namespace, - partitions=partitions, storage_type=storage_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with hdfs storage") + from ._table import StorageTable + + return StorageTable( + address=address, + name=name, + namespace=namespace, + partitions=partitions, + storage_type=storage_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with hdfs storage" + ) def cleanup(self, name, namespace): pass diff --git a/python/fate/arch/storage/localfs/_table.py b/python/fate/arch/storage/localfs/_table.py index beb4ffa0e5..d90b9245c6 100644 --- a/python/fate/arch/storage/localfs/_table.py +++ b/python/fate/arch/storage/localfs/_table.py @@ -15,15 +15,13 @@ # import io -import os from typing import Iterable from pyarrow import fs -from fate_arch.common import hdfs_utils -from fate_arch.common.log import getLogger -from fate_arch.storage import StorageEngine, LocalFSStoreType -from fate_arch.storage import StorageTableBase +from ...common import hdfs_utils +from ...common.log import getLogger +from ...storage import LocalFSStoreType, StorageEngine, StorageTableBase LOGGER = getLogger() @@ -99,9 +97,7 @@ def _count(self): count += 1 return count - def _save_as( - self, address, partitions=None, name=None, namespace=None, **kwargs - ): + def _save_as(self, address, partitions=None, name=None, namespace=None, **kwargs): self._local_fs_client.copy_file(src=self.path, dst=address.path) return StorageTable( address=address, @@ -130,7 +126,9 @@ def _as_generator(self): selector = fs.FileSelector(self.path) file_infos = self._local_fs_client.get_file_info(selector) for file_info in file_infos: - if file_info.base_name.startswith(".") or file_info.base_name.startswith("_"): + if file_info.base_name.startswith( + "." 
+ ) or file_info.base_name.startswith("_"): continue assert ( file_info.is_file @@ -179,6 +177,8 @@ def _read_buffer_lines(self, path=None): offset += len(buffer_block[:end_index]) def _read_lines(self, buffer_block): - with io.TextIOWrapper(buffer=io.BytesIO(buffer_block), encoding="utf-8") as reader: + with io.TextIOWrapper( + buffer=io.BytesIO(buffer_block), encoding="utf-8" + ) as reader: for line in reader: yield line diff --git a/python/fate/arch/storage/mysql/__init__.py b/python/fate/arch/storage/mysql/__init__.py index 4808a3933f..01f9bd810d 100644 --- a/python/fate/arch/storage/mysql/__init__.py +++ b/python/fate/arch/storage/mysql/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_arch.storage.mysql._table import StorageTable -from fate_arch.storage.mysql._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/mysql/_session.py b/python/fate/arch/storage/mysql/_session.py index 8a011e563c..006ee1c715 100644 --- a/python/fate/arch/storage/mysql/_session.py +++ b/python/fate/arch/storage/mysql/_session.py @@ -17,44 +17,69 @@ import pymysql -from fate_arch.storage import StorageSessionBase, StorageEngine, MySQLStoreType -from fate_arch.abc import AddressABC -from fate_arch.common.address import MysqlAddress +from ...abc import AddressABC +from ...common.address import MysqlAddress +from ...storage import MySQLStoreType, StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.MYSQL) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.MYSQL + ) self._db_con = {} - def table(self, name, namespace, address: AddressABC, partitions, - store_type: MySQLStoreType = MySQLStoreType.InnoDB, options=None, **kwargs): + def table( + self, + name, + namespace, + address: AddressABC, + partitions, + store_type: MySQLStoreType = MySQLStoreType.InnoDB, + options=None, + **kwargs, + ): if isinstance(address, MysqlAddress): - from fate_arch.storage.mysql._table import StorageTable - address_key = MysqlAddress(user=None, - passwd=None, - host=address.host, - port=address.port, - db=address.db, - name=None) + from ...storage.mysql._table import StorageTable + + address_key = MysqlAddress( + user=None, + passwd=None, + host=address.host, + port=address.port, + db=address.db, + name=None, + ) if address_key in self._db_con: con, cur = self._db_con[address_key] else: self._create_db_if_not_exists(address) - con = pymysql.connect(host=address.host, - user=address.user, - passwd=address.passwd, - port=address.port, - db=address.db) + con = pymysql.connect( + host=address.host, + user=address.user, + passwd=address.passwd, + port=address.port, + db=address.db, + ) cur = con.cursor() self._db_con[address_key] = (con, cur) - return StorageTable(cur=cur, con=con, address=address, name=name, namespace=namespace, - store_type=store_type, partitions=partitions, options=options) + return StorageTable( + cur=cur, + con=con, + address=address, + name=name, + namespace=namespace, + store_type=store_type, + partitions=partitions, + options=options, + ) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll storage") + raise NotImplementedError( + f"address type {type(address)} not 
supported with eggroll storage" + ) def cleanup(self, name, namespace): pass @@ -73,13 +98,15 @@ def kill(self): return self.stop() def _create_db_if_not_exists(self, address): - connection = pymysql.connect(host=address.host, - user=address.user, - password=address.passwd, - port=address.port) + connection = pymysql.connect( + host=address.host, + user=address.user, + password=address.passwd, + port=address.port, + ) with connection: with connection.cursor() as cursor: cursor.execute("create database if not exists {}".format(address.db)) - print('create db {} success'.format(address.db)) + print("create db {} success".format(address.db)) connection.commit() diff --git a/python/fate/arch/storage/mysql/_table.py b/python/fate/arch/storage/mysql/_table.py index 3cf0af3406..e602c9329b 100644 --- a/python/fate/arch/storage/mysql/_table.py +++ b/python/fate/arch/storage/mysql/_table.py @@ -14,8 +14,7 @@ # limitations under the License. # -from fate_arch.storage import StorageEngine, MySQLStoreType -from fate_arch.storage import StorageTableBase +from ...storage import MySQLStoreType, StorageEngine, StorageTableBase class StorageTable(StorageTableBase): @@ -50,9 +49,7 @@ def check_address(self): schema.get("sid"), schema.get("header"), self._address.name ) else: - sql = "SELECT {} FROM {}".format( - schema.get("sid"), self._address.name - ) + sql = "SELECT {} FROM {}".format(schema.get("sid"), self._address.name) feature_data = self.execute(sql) for feature in feature_data: if feature: @@ -115,7 +112,9 @@ def _destroy(self): self._con.commit() def _save_as(self, address, name, namespace, partitions=None, **kwargs): - sql = "create table {}.{} select * from {};".format(namespace, name, self._address.name) + sql = "create table {}.{} select * from {};".format( + namespace, name, self._address.name + ) self._cur.execute(sql) self._con.commit() diff --git a/python/fate/arch/storage/path/__init__.py b/python/fate/arch/storage/path/__init__.py index 4f77ff19c5..01f9bd810d 100644 --- a/python/fate/arch/storage/path/__init__.py +++ b/python/fate/arch/storage/path/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_arch.storage.path._table import StorageTable -from fate_arch.storage.path._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/path/_session.py b/python/fate/arch/storage/path/_session.py index df384a1be4..075d53eed1 100644 --- a/python/fate/arch/storage/path/_session.py +++ b/python/fate/arch/storage/path/_session.py @@ -14,21 +14,41 @@ # limitations under the License. 
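
The MySQL and Hive sessions above keep one (connection, cursor) pair per target database, keyed by an address object with its credential fields blanked out; that only works because those address classes can serve as dict keys. A minimal sketch of the same idea with a plain tuple key and an injected connect callable (for example pymysql.connect), so nothing here depends on a live database:

    _pool = {}

    def get_connection(host, port, db, connect):
        key = (host, port, db)  # credentials deliberately left out of the cache key
        if key not in _pool:
            _pool[key] = connect(host=host, port=port, db=db)
        return _pool[key]

Reusing the cached connection across repeated table() calls against the same database is the point of the address_key construction in the refactored sessions.
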
# -from fate_arch.storage import StorageSessionBase, StorageEngine -from fate_arch.abc import AddressABC -from fate_arch.common.address import PathAddress +from ...abc import AddressABC +from ...common.address import PathAddress +from ...storage import StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.PATH) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.PATH + ) - def table(self, address: AddressABC, name, namespace, partitions, store_type=None, options=None, **kwargs): + def table( + self, + address: AddressABC, + name, + namespace, + partitions, + store_type=None, + options=None, + **kwargs, + ): if isinstance(address, PathAddress): - from fate_arch.storage.path._table import StorageTable - return StorageTable(address=address, name=name, namespace=namespace, - partitions=partitions, store_type=store_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with hdfs storage") + from ...storage.path._table import StorageTable + + return StorageTable( + address=address, + name=name, + namespace=namespace, + partitions=partitions, + store_type=store_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with hdfs storage" + ) def cleanup(self, name, namespace): pass diff --git a/python/fate/arch/storage/path/_table.py b/python/fate/arch/storage/path/_table.py index 9f7e9ed49f..9607be6365 100644 --- a/python/fate/arch/storage/path/_table.py +++ b/python/fate/arch/storage/path/_table.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Iterable -from fate_arch.common import path_utils -from fate_arch.common.log import getLogger -from fate_arch.storage import StorageEngine, PathStoreType -from fate_arch.storage import StorageTableBase +from ...common import path_utils +from ...common.log import getLogger +from ...storage import PathStoreType, StorageEngine, StorageTableBase LOGGER = getLogger() diff --git a/python/fate/arch/storage/standalone/__init__.py b/python/fate/arch/storage/standalone/__init__.py index 88e5d1b3a0..01f9bd810d 100644 --- a/python/fate/arch/storage/standalone/__init__.py +++ b/python/fate/arch/storage/standalone/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate_arch.storage.standalone._table import StorageTable -from fate_arch.storage.standalone._session import StorageSession +from ._session import StorageSession +from ._table import StorageTable __all__ = ["StorageTable", "StorageSession"] diff --git a/python/fate/arch/storage/standalone/_session.py b/python/fate/arch/storage/standalone/_session.py index 8f7f9464df..139af98742 100644 --- a/python/fate/arch/storage/standalone/_session.py +++ b/python/fate/arch/storage/standalone/_session.py @@ -13,24 +13,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from fate_arch.abc import AddressABC -from fate_arch.common.address import StandaloneAddress -from fate_arch.storage import StorageSessionBase, StorageEngine -from fate_arch._standalone import Session +from ..._standalone import Session +from ...abc import AddressABC +from ...common.address import StandaloneAddress +from ...storage import StorageEngine, StorageSessionBase class StorageSession(StorageSessionBase): def __init__(self, session_id, options=None): - super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.STANDALONE) + super(StorageSession, self).__init__( + session_id=session_id, engine=StorageEngine.STANDALONE + ) self._options = options if options else {} self._session = Session(session_id=self._session_id) - def table(self, address: AddressABC, name, namespace, partitions, store_type=None, options=None, **kwargs): + def table( + self, + address: AddressABC, + name, + namespace, + partitions, + store_type=None, + options=None, + **kwargs, + ): if isinstance(address, StandaloneAddress): - from fate_arch.storage.standalone._table import StorageTable - return StorageTable(session=self._session, name=name, namespace=namespace, address=address, - partitions=partitions, store_type=store_type, options=options) - raise NotImplementedError(f"address type {type(address)} not supported with standalone storage") + from ...storage.standalone._table import StorageTable + + return StorageTable( + session=self._session, + name=name, + namespace=namespace, + address=address, + partitions=partitions, + store_type=store_type, + options=options, + ) + raise NotImplementedError( + f"address type {type(address)} not supported with standalone storage" + ) def cleanup(self, name, namespace): self._session.cleanup(name=name, namespace=namespace) diff --git a/python/fate/arch/storage/standalone/_table.py b/python/fate/arch/storage/standalone/_table.py index 2ccf6936e6..2ea20e771a 100644 --- a/python/fate/arch/storage/standalone/_table.py +++ b/python/fate/arch/storage/standalone/_table.py @@ -15,9 +15,8 @@ # from typing import Iterable -from fate_arch._standalone import Session -from fate_arch.storage import StorageEngine, StandaloneStoreType -from fate_arch.storage import StorageTableBase +from ..._standalone import Session +from ...storage import StandaloneStoreType, StorageEngine, StorageTableBase class StorageTable(StorageTableBase): diff --git a/python/fate/arch/tests/computing/__init__.py b/python/fate/arch/tests/computing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate/arch/tests/computing/spark_test.py b/python/fate/arch/tests/computing/spark_test.py index 578959614d..ae576fb4d3 100644 --- a/python/fate/arch/tests/computing/spark_test.py +++ b/python/fate/arch/tests/computing/spark_test.py @@ -14,7 +14,8 @@ # limitations under the License. # from pyspark import SparkContext -sc = SparkContext('local', 'test') + +sc = SparkContext("local", "test") a = [] for i in range(10): a.append((i, str(i))) diff --git a/python/fate/arch/tests/storage/__init__.py b/python/fate/arch/tests/storage/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate/arch/tests/storage/metastore_test.py b/python/fate/arch/tests/storage/metastore_test.py index c9dd94a6cb..2e2e0bfc6d 100644 --- a/python/fate/arch/tests/storage/metastore_test.py +++ b/python/fate/arch/tests/storage/metastore_test.py @@ -14,20 +14,48 @@ # limitations under the License. 
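
Since every file in this patch lives under python/fate/arch/, the rewritten relative imports resolve to the same modules the old fate_arch.* imports pointed at; from storage/standalone/_session.py, each leading dot climbs one package level:

    # python/fate/arch/storage/standalone/_session.py
    from ..._standalone import Session                        # -> fate.arch._standalone
    from ...abc import AddressABC                              # -> fate.arch.abc
    from ...common.address import StandaloneAddress            # -> fate.arch.common.address
    from ...storage import StorageEngine, StorageSessionBase   # -> fate.arch.storage

One dot fewer suffices one level up (the two-dot imports in storage/_table.py), and one more would be needed one level further down.
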
# import unittest -from fate_arch.metastore import base_model + +from ...metastore import base_model class TestBaseModel(unittest.TestCase): def test_auto_date_timestamp_field(self): self.assertEqual( - base_model.auto_date_timestamp_field(), { - 'write_access_time', 'create_time', 'read_access_time', 'end_time', 'update_time', 'start_time'}) + base_model.auto_date_timestamp_field(), + { + "write_access_time", + "create_time", + "read_access_time", + "end_time", + "update_time", + "start_time", + }, + ) def test(self): - from peewee import IntegerField, FloatField, AutoField, BigAutoField, BigIntegerField, BitField - from peewee import CharField, TextField, BooleanField, BigBitField - from fate_arch.metastore.base_model import JSONField, LongTextField - for f in {IntegerField, FloatField, AutoField, BigAutoField, BigIntegerField, BitField}: + from peewee import ( + AutoField, + BigAutoField, + BigBitField, + BigIntegerField, + BitField, + BooleanField, + CharField, + FloatField, + IntegerField, + TextField, + ) + + from ...metastore.base_model import JSONField, LongTextField + + for f in { + IntegerField, + FloatField, + AutoField, + BigAutoField, + BigIntegerField, + BitField, + }: self.assertEqual(base_model.is_continuous_field(f), True) for f in {CharField, TextField, BooleanField, BigBitField}: self.assertEqual(base_model.is_continuous_field(f), False) @@ -35,5 +63,5 @@ def test(self): self.assertEqual(base_model.is_continuous_field(f), False) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/python/fate/arch/tests/test_arch_api.py b/python/fate/arch/tests/test_arch_api.py index ff749e08d9..545ff9a637 100644 --- a/python/fate/arch/tests/test_arch_api.py +++ b/python/fate/arch/tests/test_arch_api.py @@ -16,8 +16,8 @@ import uuid import numpy as np -from fate_arch import session +from .. import session sess = session.Session() sess.init_computing() @@ -33,10 +33,14 @@ print(v) print() -table_meta = sess.persistent(computing_table=c_table, namespace="experiment", name=str(uuid.uuid1())) +table_meta = sess.persistent( + computing_table=c_table, namespace="experiment", name=str(uuid.uuid1()) +) storage_session = sess.storage() -s_table = storage_session.get_table(namespace=table_meta.get_namespace(), name=table_meta.get_name()) +s_table = storage_session.get_table( + namespace=table_meta.get_namespace(), name=table_meta.get_name() +) for k, v in s_table.collect(): print(v) print() @@ -44,7 +48,8 @@ t2 = sess.computing.load( table_meta.get_address(), partitions=table_meta.get_partitions(), - schema=table_meta.get_schema()) + schema=table_meta.get_schema(), +) for k, v in t2.collect(): print(v)
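
update_metas() above builds its peewee UPDATE from two pieces: equality filters over the primary-key fields and an update dict keyed by the field objects themselves. A hedged sketch of that pattern against a stand-in model (peewee's Model.update accepts a dict keyed by fields; the real code additionally truncates part_of_data to 100 entries before updating):

    import operator

    def build_update(model, meta_info, primary_keys):
        # Filters: each primary-key field must equal the value carried in meta_info.
        # (str.lstrip strips a character set, not a prefix; it works here only because
        # the remaining key names do not start with 'f' or '_' once "f_" is gone.)
        filters = [
            operator.attrgetter(pk)(model) == meta_info[pk.lstrip("f_")]
            for pk in primary_keys
        ]
        # Update dict: {Model.f_field: value} for every non-primary-key field present.
        fields = {
            operator.attrgetter(f"f_{k}")(model): v
            for k, v in meta_info.items()
            if hasattr(model, f"f_{k}") and f"f_{k}" not in primary_keys
        }
        return model.update(fields).where(*filters)

Calling build_update(StorageTableMetaModel, {"name": "t1", "namespace": "ns", "count": 10}, ["f_name", "f_namespace"]) would produce the same kind of query update_metas issues, minus the part_of_data truncation.
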