diff --git a/python/fate/arch/abc/__init__.py b/python/fate/arch/abc/__init__.py index bd826e72ab..10e18c7a65 100644 --- a/python/fate/arch/abc/__init__.py +++ b/python/fate/arch/abc/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._address import AddressABC from ._components import ComponentMeta, Components from ._computing import CSessionABC, CTableABC diff --git a/python/fate/arch/abc/_address.py b/python/fate/arch/abc/_address.py index 43c123a56c..6dc0a5f6e2 100644 --- a/python/fate/arch/abc/_address.py +++ b/python/fate/arch/abc/_address.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc diff --git a/python/fate/arch/abc/_components.py b/python/fate/arch/abc/_components.py index 3736d2179a..ffe015e695 100644 --- a/python/fate/arch/abc/_components.py +++ b/python/fate/arch/abc/_components.py @@ -6,15 +6,12 @@ # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 - # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# - import typing from abc import ABCMeta diff --git a/python/fate/arch/abc/_federation.py b/python/fate/arch/abc/_federation.py index b0bf22eee0..4fefba7f74 100644 --- a/python/fate/arch/abc/_federation.py +++ b/python/fate/arch/abc/_federation.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc import typing from abc import ABCMeta @@ -19,9 +33,7 @@ def session_id(self) -> str: ... 
@abc.abstractmethod - def get( - self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC - ) -> typing.List: + def get(self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC) -> typing.List: """ get objects/tables from ``parties`` diff --git a/python/fate/arch/abc/_gc.py b/python/fate/arch/abc/_gc.py index 1dfabed890..50e2ac5da9 100644 --- a/python/fate/arch/abc/_gc.py +++ b/python/fate/arch/abc/_gc.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc diff --git a/python/fate/arch/abc/_path.py b/python/fate/arch/abc/_path.py index 2575487db5..7c3c13673c 100644 --- a/python/fate/arch/abc/_path.py +++ b/python/fate/arch/abc/_path.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABCMeta diff --git a/python/fate/arch/common/__init__.py b/python/fate/arch/common/__init__.py index cbd5efb30d..e19e313a89 100644 --- a/python/fate/arch/common/__init__.py +++ b/python/fate/arch/common/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._types import ( BaseType, CoordinationCommunicationProtocol, diff --git a/python/fate/arch/common/_types.py b/python/fate/arch/common/_types.py index 21f95cf3fe..2071faaba5 100644 --- a/python/fate/arch/common/_types.py +++ b/python/fate/arch/common/_types.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. class EngineType(object): COMPUTING = "computing" STORAGE = "storage" diff --git a/python/fate/arch/common/address.py b/python/fate/arch/common/address.py index f7567f0fe3..631e624752 100644 --- a/python/fate/arch/common/address.py +++ b/python/fate/arch/common/address.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ..abc import AddressABC from ..metastore.db_utils import StorageConnector @@ -107,9 +121,7 @@ def __repr__(self): class ApiAddress(AddressBase): - def __init__( - self, method="POST", url=None, header=None, body=None, connector_name=None - ): + def __init__(self, method="POST", url=None, header=None, body=None, connector_name=None): self.method = method self.url = url self.header = header if header else {} diff --git a/python/fate/arch/common/data_utils.py b/python/fate/arch/common/data_utils.py index a5748d2989..ea10869d0b 100644 --- a/python/fate/arch/common/data_utils.py +++ b/python/fate/arch/common/data_utils.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os import uuid @@ -9,32 +23,22 @@ def default_output_info(task_id, task_version, output_type): return f"output_{output_type}_{task_id}_{task_version}", uuid.uuid1().hex -def default_input_fs_path( - name, namespace, prefix=None, storage_engine=StorageEngine.HDFS -): +def default_input_fs_path(name, namespace, prefix=None, storage_engine=StorageEngine.HDFS): if storage_engine == StorageEngine.HDFS: - return default_hdfs_path( - data_type="input", name=name, namespace=namespace, prefix=prefix - ) + return default_hdfs_path(data_type="input", name=name, namespace=namespace, prefix=prefix) elif storage_engine == StorageEngine.LOCALFS: return default_localfs_path(data_type="input", name=name, namespace=namespace) -def default_output_fs_path( - name, namespace, prefix=None, storage_engine=StorageEngine.HDFS -): +def default_output_fs_path(name, namespace, prefix=None, storage_engine=StorageEngine.HDFS): if storage_engine == StorageEngine.HDFS: - return default_hdfs_path( - data_type="output", name=name, namespace=namespace, prefix=prefix - ) + return default_hdfs_path(data_type="output", name=name, namespace=namespace, prefix=prefix) elif storage_engine == StorageEngine.LOCALFS: return default_localfs_path(data_type="output", name=name, namespace=namespace) def default_localfs_path(name, namespace, data_type): - return os.path.join( - get_project_base_directory(), "localfs", data_type, namespace, name - ) + return os.path.join(get_project_base_directory(), "localfs", data_type, namespace, name) def default_hdfs_path(data_type, name, namespace, prefix=None): diff --git a/python/fate/arch/computing/__init__.py b/python/fate/arch/computing/__init__.py index ac0b30f9f9..44fd114dc2 100644 --- a/python/fate/arch/computing/__init__.py +++ b/python/fate/arch/computing/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The Eggroll Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# - from ._type import ComputingEngine from ._util import is_table diff --git a/python/fate/arch/computing/_type.py b/python/fate/arch/computing/_type.py index 16d48ac831..06e29c62c8 100644 --- a/python/fate/arch/computing/_type.py +++ b/python/fate/arch/computing/_type.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The Eggroll Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# class ComputingEngine(object): diff --git a/python/fate/arch/computing/_util.py b/python/fate/arch/computing/_util.py index f315073b0a..40d0b51fa4 100644 --- a/python/fate/arch/computing/_util.py +++ b/python/fate/arch/computing/_util.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The Eggroll Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# from ..abc import CTableABC diff --git a/python/fate/arch/computing/standalone/_csession.py b/python/fate/arch/computing/standalone/_csession.py index a06943f5de..8bad05e770 100644 --- a/python/fate/arch/computing/standalone/_csession.py +++ b/python/fate/arch/computing/standalone/_csession.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The Eggroll Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# import logging from collections.abc import Iterable diff --git a/python/fate/arch/computing/standalone/_table.py b/python/fate/arch/computing/standalone/_table.py index 87803327ea..65bc92562a 100644 --- a/python/fate/arch/computing/standalone/_table.py +++ b/python/fate/arch/computing/standalone/_table.py @@ -1,5 +1,5 @@ # -# Copyright 2019 The Eggroll Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# import itertools import logging diff --git a/python/fate/arch/context/__init__.py b/python/fate/arch/context/__init__.py index abf970af7f..10f371d63e 100644 --- a/python/fate/arch/context/__init__.py +++ b/python/fate/arch/context/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._context import Context, Namespace __all__ = ["Context", "Namespace"] diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index c11abc8a2f..19db601827 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from fate.interface import CipherKit as CipherKitInterface from ..tensor._phe import PHECipher diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index d4a2db9c38..59cbe1ddff 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from contextlib import contextmanager from copy import copy from typing import Iterator, List, Optional diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index 26f3e32dd2..d33b58855a 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import io import pickle from typing import Any, List, Optional, TypeVar, Union @@ -54,12 +68,8 @@ def __call__(self, key: str) -> "_KeyedParty": def put(self, *args, **kwargs): if args: - assert len(args) == 2 and isinstance( - args[0], str - ), "invalid position parameter" - assert ( - not kwargs - ), "keywords paramters not allowed when position parameter provided" + assert len(args) == 2 and isinstance(args[0], str), "invalid position parameter" + assert not kwargs, "keywords paramters not allowed when position parameter provided" kvs = [args] else: kvs = kwargs.items() @@ -115,12 +125,8 @@ def get_local_index(self) -> Optional[int]: def put(self, *args, **kwargs): if args: - assert len(args) == 2 and isinstance( - args[0], str - ), "invalid position parameter" - assert ( - not kwargs - ), "keywords paramters not allowed when position parameter provided" + assert len(args) == 2 and isinstance(args[0], str), "invalid position parameter" + assert not kwargs, "keywords paramters not allowed when position parameter provided" kvs = [args] else: kvs = kwargs.items() @@ -156,9 +162,7 @@ def _pull( ) values = [] for party, buffers in zip(parties, raw_values): - values.append( - _TableRmotePersistentUnpickler.pull(buffers, federation, name, tag, party) - ) + values.append(_TableRmotePersistentUnpickler.pull(buffers, federation, name, tag, party)) return values diff --git a/python/fate/arch/context/_metrics.py b/python/fate/arch/context/_metrics.py index e69de29bb2..ae946a49c4 100644 --- a/python/fate/arch/context/_metrics.py +++ b/python/fate/arch/context/_metrics.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/arch/context/_mlmd.py b/python/fate/arch/context/_mlmd.py index 3711ced9ff..8fe46e2388 100644 --- a/python/fate/arch/context/_mlmd.py +++ b/python/fate/arch/context/_mlmd.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from ml_metadata import metadata_store diff --git a/python/fate/arch/context/_namespace.py b/python/fate/arch/context/_namespace.py index 97775dcf78..802a748523 100644 --- a/python/fate/arch/context/_namespace.py +++ b/python/fate/arch/context/_namespace.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from contextlib import contextmanager from typing import Generator, overload @@ -65,9 +79,7 @@ def _state_iterator() -> Generator["Namespace", None, None]: for i in range(start, stop): # the tags in the iteration need to be distinguishable template_formated = f"{prefix_name}iter_{i}" - self._namespace_state = IterationState( - prev_namespace_state.sub_namespace(template_formated) - ) + self._namespace_state = IterationState(prev_namespace_state.sub_namespace(template_formated)) yield self # with context returns iterator of Contexts diff --git a/python/fate/arch/context/_tensor.py b/python/fate/arch/context/_tensor.py index 21a58cdba8..b0ecb4e918 100644 --- a/python/fate/arch/context/_tensor.py +++ b/python/fate/arch/context/_tensor.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch from ..tensor import Tensor as FPTensor @@ -29,11 +43,7 @@ def random_tensor(self, shape, num_partition=1) -> FPTensor: else: parts.append(torch.rand((first_dim_approx, *shape[1:]))) return FPTensor( - FPTensorDistributed( - self.computing.parallelize( - parts, include_key=False, partition=num_partition - ) - ), + FPTensorDistributed(self.computing.parallelize(parts, include_key=False, partition=num_partition)), ) def create_tensor(self, tensor: torch.Tensor) -> "FPTensor": diff --git a/python/fate/arch/context/_utils.py b/python/fate/arch/context/_utils.py index 522ad63bce..1c6cbb1274 100644 --- a/python/fate/arch/context/_utils.py +++ b/python/fate/arch/context/_utils.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. def disable_inner_logs(): from ..common.log import getLogger diff --git a/python/fate/arch/context/io/data/csv.py b/python/fate/arch/context/io/data/csv.py index 858ae71fa1..68cb6afddb 100644 --- a/python/fate/arch/context/io/data/csv.py +++ b/python/fate/arch/context/io/data/csv.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ....unify import URI from .df import Dataframe diff --git a/python/fate/arch/context/io/data/dataframe.py b/python/fate/arch/context/io/data/dataframe.py index e69de29bb2..ae946a49c4 100644 --- a/python/fate/arch/context/io/data/dataframe.py +++ b/python/fate/arch/context/io/data/dataframe.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/arch/context/io/data/df.py b/python/fate/arch/context/io/data/df.py index 601b780681..2a86b33990 100644 --- a/python/fate/arch/context/io/data/df.py +++ b/python/fate/arch/context/io/data/df.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. class Dataframe: def __init__(self, frames, num_features, num_samples) -> None: self.data = frames diff --git a/python/fate/arch/context/io/data/eggroll.py b/python/fate/arch/context/io/data/eggroll.py index f96baa2b11..61c769b19c 100644 --- a/python/fate/arch/context/io/data/eggroll.py +++ b/python/fate/arch/context/io/data/eggroll.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from fate.arch.abc import CTableABC from ....unify import EggrollURI @@ -38,9 +52,10 @@ def __init__(self, ctx, uri: EggrollURI, metadata: dict) -> None: self.metadata = metadata def read_dataframe(self): - from .df import Dataframe from fate.arch import dataframe + from .df import Dataframe + table = load_table(self.ctx, self.uri, self.metadata) df = dataframe.deserialize(self.ctx, table) return Dataframe(df, df.shape[1], df.shape[0]) @@ -56,9 +71,10 @@ def __init__(self, ctx, name: str, uri: EggrollURI, metadata: dict) -> None: def read_dataframe(self): import inspect - from .df import Dataframe from fate.arch import dataframe + from .df import Dataframe + table = load_table(self.ctx, self.uri, self.metadata) meta = table.schema.get("meta", {}) diff --git a/python/fate/arch/context/io/data/file.py b/python/fate/arch/context/io/data/file.py index 355c310d70..5a3d120917 100644 --- a/python/fate/arch/context/io/data/file.py +++ b/python/fate/arch/context/io/data/file.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from ....unify import FileURI from .df import Dataframe diff --git a/python/fate/arch/context/io/kit.py b/python/fate/arch/context/io/kit.py index 1ca94536f5..531f5c4059 100644 --- a/python/fate/arch/context/io/kit.py +++ b/python/fate/arch/context/io/kit.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Protocol from fate.components import Artifact, DatasetArtifact, MetricArtifact, ModelArtifact diff --git a/python/fate/arch/context/io/metric/file.py b/python/fate/arch/context/io/metric/file.py index 7292a1aadd..a5dd7d6b38 100644 --- a/python/fate/arch/context/io/metric/file.py +++ b/python/fate/arch/context/io/metric/file.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os from typing import Union diff --git a/python/fate/arch/context/io/metric/http.py b/python/fate/arch/context/io/metric/http.py index b1fb54309a..a12fd1162a 100644 --- a/python/fate/arch/context/io/metric/http.py +++ b/python/fate/arch/context/io/metric/http.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from typing import Union diff --git a/python/fate/arch/context/io/model/file.py b/python/fate/arch/context/io/model/file.py index 7b3e238847..386f236372 100644 --- a/python/fate/arch/context/io/model/file.py +++ b/python/fate/arch/context/io/model/file.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from ....unify import URI diff --git a/python/fate/arch/context/io/model/http.py b/python/fate/arch/context/io/model/http.py index 89f5085a4b..903225bd0d 100644 --- a/python/fate/arch/context/io/model/http.py +++ b/python/fate/arch/context/io/model/http.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import requests diff --git a/python/fate/arch/context/metric/__init__.py b/python/fate/arch/context/metric/__init__.py index 1304fa2a33..400736af0a 100644 --- a/python/fate/arch/context/metric/__init__.py +++ b/python/fate/arch/context/metric/__init__.py @@ -1,2 +1,16 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._type import InCompleteMetrics, Metric, Metrics from ._wrap import MetricsWrap diff --git a/python/fate/arch/context/metric/_handler.py b/python/fate/arch/context/metric/_handler.py index 3a2d96e69b..7e59e091a8 100644 --- a/python/fate/arch/context/metric/_handler.py +++ b/python/fate/arch/context/metric/_handler.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Union from fate.interface import MetricsHandler diff --git a/python/fate/arch/context/metric/_incomplte_metrics.py b/python/fate/arch/context/metric/_incomplte_metrics.py index f2516d38b2..08d2b95607 100644 --- a/python/fate/arch/context/metric/_incomplte_metrics.py +++ b/python/fate/arch/context/metric/_incomplte_metrics.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._type import InCompleteMetrics diff --git a/python/fate/arch/context/metric/_metric.py b/python/fate/arch/context/metric/_metric.py index c676b9c188..43c6967c4a 100644 --- a/python/fate/arch/context/metric/_metric.py +++ b/python/fate/arch/context/metric/_metric.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._type import Metric diff --git a/python/fate/arch/context/metric/_metrics.py b/python/fate/arch/context/metric/_metrics.py index db248e64e9..f29c642a1f 100644 --- a/python/fate/arch/context/metric/_metrics.py +++ b/python/fate/arch/context/metric/_metrics.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Dict, Optional from ._type import Metrics diff --git a/python/fate/arch/context/metric/_type.py b/python/fate/arch/context/metric/_type.py index 7dcf7d52f6..6705d35879 100644 --- a/python/fate/arch/context/metric/_type.py +++ b/python/fate/arch/context/metric/_type.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc from typing import Dict, Optional diff --git a/python/fate/arch/context/metric/_wrap.py b/python/fate/arch/context/metric/_wrap.py index 05bdafcf48..00d33c9e5e 100644 --- a/python/fate/arch/context/metric/_wrap.py +++ b/python/fate/arch/context/metric/_wrap.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Optional, Tuple, Union from fate.interface import MetricsHandler @@ -77,4 +91,3 @@ def log_auc(self, name: str, auc: float, step=None, timestamp=None): def log_roc(self, name: str, data: List[Tuple[float, float]]): return self.log_metrics(ROCMetrics(name, data)) - diff --git a/python/fate/arch/dataframe/__init__.py b/python/fate/arch/dataframe/__init__.py index 9f1d49db61..3c766750fd 100644 --- a/python/fate/arch/dataframe/__init__.py +++ b/python/fate/arch/dataframe/__init__.py @@ -1,7 +1,26 @@ -from ._frame_reader import PandasReader, CSVReader, RawTableReader, ImageReader, TorchDataSetReader +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._frame_reader import ( + CSVReader, + ImageReader, + PandasReader, + RawTableReader, + TorchDataSetReader, +) +from .io import build_schema, deserialize, parse_schema, serialize from .utils import DataLoader -from .io import parse_schema, build_schema, serialize, deserialize - __all__ = [ "PandasReader", @@ -12,5 +31,5 @@ "parse_schema", "build_schema", "serialize", - "deserialize" + "deserialize", ] diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index 036fd7e76a..9d7cbc9e6f 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -1,21 +1,30 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import operator + import torch -from .ops import stat_method, arith_method, transform_to_predict_result -from .storage import ValueStore, Index from fate.arch.computing import is_table +from .ops import arith_method, stat_method, transform_to_predict_result +from .storage import Index, ValueStore + # TODO: record data type, support multiple data types class DataFrame(object): - def __init__(self, - ctx, - schema, - index=None, - match_id=None, - values=None, - label=None, - weight=None): + def __init__(self, ctx, schema, index=None, match_id=None, values=None, label=None, weight=None): self._ctx = ctx self._index = index self._match_id = match_id @@ -118,14 +127,9 @@ def __getattr__(self, attr): col_idx = self.schema.header.index(attr) value = self._values[:, col_idx] - schema = dict(sid=self.schema.sid, - header=[attr]) + schema = dict(sid=self.schema.sid, header=[attr]) - return DataFrame( - self._ctx, - schema=schema, - values=value - ) + return DataFrame(self._ctx, schema=schema, values=value) def __getitem__(self, items): indexes = self.__get_index_by_column_names(items) @@ -143,8 +147,9 @@ def __getitem__(self, items): new_schema["header"] = new_header new_schema["anonymous__header"] = new_anonymous_header - return DataFrame(self._ctx, index=self._index, values=ret_tensor, label=self._label, weight=self._weight, - schema=new_schema) + return DataFrame( + self._ctx, index=self._index, values=ret_tensor, label=self._label, weight=self._weight, schema=new_schema + ) def __setitem__(self, keys, item): if not isinstance(item, DataFrame): @@ -165,7 +170,7 @@ def _retrieval_attr(self) -> dict: index=self._index, values=self._values, label=self._label, - weight=self._weight + weight=self._weight, ) def __get_index_by_column_names(self, column_names): @@ -239,17 +244,15 @@ def _retrieval_func(kvs): blocks = blocks.join(agg_indexer, lambda ten, block_mapping: (ten, block_mapping)) blocks = blocks.mapReducePartitions(_retrieval_func, lambda l1, l2: l1 + l2) - blocks = blocks.mapValues(lambda block: sorted(block, key = lambda buf: buf[0])) + blocks = blocks.mapValues(lambda block: sorted(block, key=lambda buf: buf[0])) blocks = blocks.mapValues( - lambda block: torch.tensor([value[1] for value in block], dtype=getattr(torch, dtype))) + lambda block: torch.tensor([value[1] for value in block], dtype=getattr(torch, dtype)) + ) blocks = [block for pid, block in sorted(list(blocks.collect()))] from fate.arch import tensor - return tensor.distributed_tensor( - self._ctx, - blocks, - partitions=len(blocks) - ) + + return tensor.distributed_tensor(self._ctx, blocks, partitions=len(blocks)) weight = _iloc_tensor(self._weight) if self._weight else None label = _iloc_tensor(self._label) if self._label else None @@ -260,13 +263,7 @@ def _retrieval_func(kvs): raise ValueError(f"iloc function dose not support args type={type(indexes)}") return DataFrame( - self._ctx, - self._schema.dict(), - index=index, - match_id=match_id, - label=label, - weight=weight, - values=values + self._ctx, self._schema.dict(), index=index, match_id=match_id, label=label, weight=weight, values=values ) def 
to_local(self): @@ -282,11 +279,7 @@ def to_local(self): if self._match_id: ret_dict["match_id"] = self._match_id.to_local() - return DataFrame( - self._ctx, - self._schema.dict(), - **ret_dict - ) + return DataFrame(self._ctx, self._schema.dict(), **ret_dict) @property def is_local(self): @@ -301,22 +294,16 @@ def is_local(self): return False - def transform_to_predict_result(self, predict_score, data_type="train", task_type="binary", - classes=None, threshold=0.5): - """ - """ + def transform_to_predict_result( + self, predict_score, data_type="train", task_type="binary", classes=None, threshold=0.5 + ): + """ """ - ret, header = transform_to_predict_result(self._ctx, - predict_score, - data_type=data_type, - task_type=task_type, - classes=classes, - threshold=threshold) + ret, header = transform_to_predict_result( + self._ctx, predict_score, data_type=data_type, task_type=task_type, classes=classes, threshold=threshold + ) - transform_schema = { - "header": header, - "sid": self._schema.sid - } + transform_schema = {"header": header, "sid": self._schema.sid} if self._schema.match_id_name: transform_schema["match_id_name"] = self._schema.match_id_name @@ -329,7 +316,7 @@ def transform_to_predict_result(self, predict_score, data_type="train", task_typ match_id=self._match_id, label=self.label, values=ValueStore(self._ctx, ret, header), - schema=transform_schema + schema=transform_schema, ) def serialize(self): @@ -366,8 +353,9 @@ def __iter__(self): class Schema(object): - def __init__(self, sid=None, match_id_name=None, weight_name=None, - label_name=None, header=None, anonymous_header=None): + def __init__( + self, sid=None, match_id_name=None, weight_name=None, label_name=None, header=None, anonymous_header=None + ): self._sid = sid self._match_id_name = match_id_name self._weight_name = weight_name @@ -400,9 +388,7 @@ def anonymous_header(self): return self._anonymous_header def dict(self): - schema = dict( - sid=self._sid - ) + schema = dict(sid=self._sid) if self._header: schema["header"] = self._header diff --git a/python/fate/arch/dataframe/_frame_reader.py b/python/fate/arch/dataframe/_frame_reader.py index c0e29614b1..c547dec007 100644 --- a/python/fate/arch/dataframe/_frame_reader.py +++ b/python/fate/arch/dataframe/_frame_reader.py @@ -1,23 +1,39 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import functools import typing import numpy as np +import pandas as pd import torch from fate.arch import tensor -import pandas as pd -from .storage import Index from ._dataframe import DataFrame +from .storage import Index class RawTableReader(object): - def __init__(self, - delimiter: str = ",", - label_name: typing.Union[None, str] = None, - label_type: str = "int", - weight_name: typing.Union[None, str] = None, - dtype: str = "float32", - input_format: str = "dense"): + def __init__( + self, + delimiter: str = ",", + label_name: typing.Union[None, str] = None, + label_type: str = "int", + weight_name: typing.Union[None, str] = None, + dtype: str = "float32", + input_format: str = "dense", + ): self._delimiter = delimiter self._label_name = label_name self._label_type = label_type @@ -49,10 +65,12 @@ def _dense_format_to_frame(self, ctx, table): header_indexes.remove(label_idx) label_type = getattr(torch, self._label_type) label_table = table.mapValues(lambda value: [label_type(value[label_idx])]) - data_dict["label"] = _convert_to_tensor(ctx, - label_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._label_type)) + data_dict["label"] = _convert_to_tensor( + ctx, + label_table, + block_partition_mapping=_block_partition_mapping, + dtype=getattr(torch, self._label_type), + ) schema["label_name"] = self._label_name if self._weight_name: @@ -63,53 +81,51 @@ def _dense_format_to_frame(self, ctx, table): header.remove(self._weight_name) header_indexes.remove(weight_idx) weight_table = table.mapValues(lambda value: [value[weight_idx]]) - data_dict["weight"] = _convert_to_tensor(ctx, - weight_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, "float64")) + data_dict["weight"] = _convert_to_tensor( + ctx, weight_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, "float64") + ) schema["weight_name"] = self._weight_name if header_indexes: value_table = table.mapValues(lambda value: np.array(value)[header_indexes].astype(self._dtype).tolist()) - data_dict["values"] = _convert_to_tensor(ctx, - value_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._dtype)) + data_dict["values"] = _convert_to_tensor( + ctx, value_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, self._dtype) + ) schema["header"] = header - data_dict["index"] = _convert_to_index(ctx, - index_table, - block_partition_mapping=_block_partition_mapping, - global_ranks=_global_ranks) + data_dict["index"] = _convert_to_index( + ctx, index_table, block_partition_mapping=_block_partition_mapping, global_ranks=_global_ranks + ) - return DataFrame(ctx=ctx, - schema=schema, - **data_dict) + return DataFrame(ctx=ctx, schema=schema, **data_dict) class ImageReader(object): """ Image Reader now support convert image to a 3D tensor, dtype=torch.float64 """ - def __init__(self, - mode="L", - ): + def __init__( + self, + mode="L", + ): ... class CSVReader(object): # TODO: fast data read # TODO: a. support match_id, b. 
more id type - def __init__(self, - id_name: typing.Union[None, str] = None, - delimiter: str = ",", - label_name: typing.Union[None, str] = None, - label_type: str = "int", - weight_name: typing.Union[None, str] = None, - dtype: str = "float32", - partition: int = 4): + def __init__( + self, + id_name: typing.Union[None, str] = None, + delimiter: str = ",", + label_name: typing.Union[None, str] = None, + label_type: str = "int", + weight_name: typing.Union[None, str] = None, + dtype: str = "float32", + partition: int = 4, + ): self._id_name = id_name self._delimiter = delimiter self._label_name = label_name @@ -122,11 +138,13 @@ def to_frame(self, ctx, path): # TODO: use table put data instead of read all data df = pd.read_csv(path, delimiter=self._delimiter) - return PandasReader(id_name=self._id_name, - label_name=self._label_name, - label_type=self._label_type, - weight_name=self._weight_name, - partition=self._partition).to_frame(ctx, df) + return PandasReader( + id_name=self._id_name, + label_name=self._label_name, + label_type=self._label_type, + weight_name=self._weight_name, + partition=self._partition, + ).to_frame(ctx, df) class HiveReader(object): @@ -143,7 +161,9 @@ class TextReader(object): class TorchDataSetReader(object): # TODO: this is for Torch DataSet Reader, the passing object has attributes __len__ and __get_item__ - def __init__(self, ): + def __init__( + self, + ): ... def to_frame(self, ctx, dataset): @@ -151,13 +171,15 @@ def to_frame(self, ctx, dataset): class PandasReader(object): - def __init__(self, - id_name: typing.Union[None, str] = None, - label_name: str = None, - label_type: str = "int", - weight_name: typing.Union[None, str] = None, - dtype: str = "float32", - partition: int = 4): + def __init__( + self, + id_name: typing.Union[None, str] = None, + label_name: str = None, + label_type: str = "int", + weight_name: typing.Union[None, str] = None, + dtype: str = "float32", + partition: int = 4, + ): self._id_name = id_name self._label_name = label_name self._label_type = label_type @@ -165,7 +187,7 @@ def __init__(self, self._dtype = dtype self._partition = partition - def to_frame(self, ctx, df: 'pd.DataFrame'): + def to_frame(self, ctx, df: "pd.DataFrame"): schema = dict() if not self._id_name: self._id_name = df.columns[0] @@ -177,9 +199,7 @@ def to_frame(self, ctx, df: 'pd.DataFrame'): id_list = df.index.tolist() index_table = ctx.computing.parallelize( - zip(id_list, range(df.shape[0])), - include_key=True, - partition=self._partition + zip(id_list, range(df.shape[0])), include_key=True, partition=self._partition ) index_table, _block_partition_mapping, _global_ranks = _convert_to_order_indexes(index_table) @@ -188,54 +208,45 @@ def to_frame(self, ctx, df: 'pd.DataFrame'): if self._label_name: label_list = [[label] for label in df[self._label_name].tolist()] label_table = ctx.computing.parallelize( - zip(id_list, label_list), - include_key=True, - partition=self._partition + zip(id_list, label_list), include_key=True, partition=self._partition + ) + data_dict["label"] = _convert_to_tensor( + ctx, + label_table, + block_partition_mapping=_block_partition_mapping, + dtype=getattr(torch, self._label_type), ) - data_dict["label"] = _convert_to_tensor(ctx, - label_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._label_type)) df = df.drop(columns=self._label_name) schema["label_name"] = self._label_name if self._weight_name: weight_list = df[self._weight_name].tolist() weight_table = ctx.computing.parallelize( - 
zip(id_list, weight_list), - include_key=True, - partition=self._partition + zip(id_list, weight_list), include_key=True, partition=self._partition + ) + data_dict["weight"] = _convert_to_tensor( + ctx, weight_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, "float64") ) - data_dict["weight"] = _convert_to_tensor(ctx, - weight_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, "float64")) df = df.drop(columns=self._weight_name) schema["weight_name"] = self._weight_name if df.shape[1]: value_table = ctx.computing.parallelize( - zip(id_list, df.values), - include_key=True, - partition=self._partition + zip(id_list, df.values), include_key=True, partition=self._partition + ) + data_dict["values"] = _convert_to_tensor( + ctx, value_table, block_partition_mapping=_block_partition_mapping, dtype=getattr(torch, self._dtype) ) - data_dict["values"] = _convert_to_tensor(ctx, - value_table, - block_partition_mapping=_block_partition_mapping, - dtype=getattr(torch, self._dtype)) schema["header"] = df.columns.to_list() - data_dict["index"] = _convert_to_index(ctx, - index_table, - block_partition_mapping=_block_partition_mapping, - global_ranks=_global_ranks) + data_dict["index"] = _convert_to_index( + ctx, index_table, block_partition_mapping=_block_partition_mapping, global_ranks=_global_ranks + ) schema["sid"] = self._id_name - return DataFrame(ctx=ctx, - schema=schema, - **data_dict) + return DataFrame(ctx=ctx, schema=schema, **data_dict) def _convert_to_order_indexes(table): @@ -261,46 +272,31 @@ def _order_indexes(kvs, rank_dict: dict = None): block_partition_mapping = dict() global_ranks = [] for blk_key, blk_size in block_summary.items(): - block_partition_mapping[blk_key] = dict(start_index=start_index, - end_index=start_index + blk_size - 1, - block_id=block_id) + block_partition_mapping[blk_key] = dict( + start_index=start_index, end_index=start_index + blk_size - 1, block_id=block_id + ) global_ranks.append(block_partition_mapping[blk_key]) start_index += blk_size block_id += 1 - order_func = functools.partial(_order_indexes, - rank_dict=block_partition_mapping) - order_table = table.mapPartitions( - order_func, - use_previous_behavior=False - ) + order_func = functools.partial(_order_indexes, rank_dict=block_partition_mapping) + order_table = table.mapPartitions(order_func, use_previous_behavior=False) return order_table, block_partition_mapping, global_ranks def _convert_to_index(ctx, table, block_partition_mapping, global_ranks): - return Index( - ctx, - table, - block_partition_mapping=block_partition_mapping, - global_ranks=global_ranks - ) + return Index(ctx, table, block_partition_mapping=block_partition_mapping, global_ranks=global_ranks) def _convert_to_tensor(ctx, table, block_partition_mapping, dtype): # TODO: in mini-demo stage, distributed tensor only accept list, in future, replace this with distributed table. 
- convert_func = functools.partial(_convert_block, - block_partition_mapping=block_partition_mapping, - dtype=dtype) + convert_func = functools.partial(_convert_block, block_partition_mapping=block_partition_mapping, dtype=dtype) blocks_with_id = list(table.mapPartitions(convert_func, use_previous_behavior=False).collect()) blocks = [block_with_id[1] for block_with_id in sorted(blocks_with_id)] - return tensor.distributed_tensor( - ctx, - blocks, - partitions=len(blocks) - ) + return tensor.distributed_tensor(ctx, blocks, partitions=len(blocks)) def _convert_block(kvs, block_partition_mapping, dtype, convert_type="tensor"): diff --git a/python/fate/arch/dataframe/io/__init__.py b/python/fate/arch/dataframe/io/__init__.py index 0f856b39a0..2cf3a0c510 100644 --- a/python/fate/arch/dataframe/io/__init__.py +++ b/python/fate/arch/dataframe/io/__init__.py @@ -1,9 +1,18 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._json_schema import build_schema, parse_schema -from ._json_serialization import serialize, deserialize +from ._json_serialization import deserialize, serialize -__all__ = [ - "build_schema", - "parse_schema", - "serialize", - "deserialize" -] \ No newline at end of file +__all__ = ["build_schema", "parse_schema", "serialize", "deserialize"] diff --git a/python/fate/arch/dataframe/io/_json_schema.py b/python/fate/arch/dataframe/io/_json_schema.py index 54950380fa..b65c35601b 100644 --- a/python/fate/arch/dataframe/io/_json_schema.py +++ b/python/fate/arch/dataframe/io/_json_schema.py @@ -1,6 +1,20 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pandas as pd -from ..storage import ValueStore +from ..storage import ValueStore FRAME_SCHEME = "fate.dataframe" @@ -11,42 +25,18 @@ def build_schema(data): """ index, match_id, label, weight, values """ - fields.append( - dict( - type="str", - name=schema.sid, - property="index" - ) - ) + fields.append(dict(type="str", name=schema.sid, property="index")) if schema.match_id_name is not None: - fields.append( - dict( - type="str", - name=schema.match_id_name, - property="match_id" - ) - ) + fields.append(dict(type="str", name=schema.match_id_name, property="match_id")) if schema.label_name is not None: label = data.label - fields.append( - dict( - type=label.dtype.name, - name=schema.label_name, - property="label" - ) - ) + fields.append(dict(type=label.dtype.name, name=schema.label_name, property="label")) if schema.weight_name is not None: weight = data.weight - fields.append( - dict( - type=weight.dtype.name, - name=schema["weight_name"], - property="weight" - ) - ) + fields.append(dict(type=weight.dtype.name, name=schema["weight_name"], property="weight")) if schema.header is not None: values = data.values @@ -59,20 +49,13 @@ def build_schema(data): type=dtypes[col_name].name, name=col_name, property="value", - source="fate.dataframe.value_store" + source="fate.dataframe.value_store", ) ) else: for col_name in columns: - fields.append( - dict( - type=values.dtype.name, - name=col_name, - property="value", - source="fate.arch.tensor" - ) - ) + fields.append(dict(type=values.dtype.name, name=col_name, property="value", source="fate.arch.tensor")) built_schema = dict() built_schema["fields"] = fields @@ -93,35 +76,26 @@ def parse_schema(schema): for idx, field in enumerate(fields): if field["property"] == "index": recovery_schema["sid"] = field["name"] - column_info["index"] = dict(start_idx=idx, - end_idx=idx, - type=field["type"]) + column_info["index"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) elif field["property"] == "match_id": recovery_schema["match_id_name"] = field["name"] - column_info["match_id"] = dict(start_idx=idx, - end_idx=idx, - type=field["type"]) + column_info["match_id"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) elif field["property"] == "label": recovery_schema["label_name"] = field["name"] - column_info["label"] = dict(start_idx=idx, - end_idx=idx, - type=field["type"]) + column_info["label"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) elif field["property"] == "weight": recovery_schema["weight_name"] = field["name"] - column_info["weight"] = dict(start_idx=idx, - end_idx=idx, - type=field["type"]) + column_info["weight"] = dict(start_idx=idx, end_idx=idx, type=field["type"]) elif field["property"] == "value": header = [field["name"] for field in fields[idx:]] recovery_schema["header"] = header - column_info["values"] = dict(start_idx=idx, - end_idx=idx + len(header) - 1, - type=field["type"], - source=field["source"]) + column_info["values"] = dict( + start_idx=idx, end_idx=idx + len(header) - 1, type=field["type"], source=field["source"] + ) break return recovery_schema, schema["global_ranks"], schema["block_partition_mapping"], column_info diff --git a/python/fate/arch/dataframe/io/_json_serialization.py b/python/fate/arch/dataframe/io/_json_serialization.py index 65636bdde9..8c877d0e8c 100644 --- a/python/fate/arch/dataframe/io/_json_serialization.py +++ b/python/fate/arch/dataframe/io/_json_serialization.py @@ -1,13 +1,29 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import functools + import numpy as np import pandas as pd import torch -from ._json_schema import build_schema, parse_schema -from .._dataframe import DataFrame -from ..storage import Index, ValueStore from fate.arch import tensor from fate.arch.context.io.data import df +from .._dataframe import DataFrame +from ..storage import Index, ValueStore +from ._json_schema import build_schema, parse_schema + def _serialize_local(ctx, data): """ @@ -40,17 +56,14 @@ def _serialize_local(ctx, data): value_concat = tensor_concat if value_concat is not None: - tensor_concat = ctx.computing.parallelize( - [value_concat.tolist()], - include_key=False, - partition=1 - ) + tensor_concat = ctx.computing.parallelize([value_concat.tolist()], include_key=False, partition=1) """ data only has index """ if tensor_concat is None: serialize_data = data.index.mapValues(lambda pd_index: pd_index.tolist()) else: + def _flatten(index: pd.Index, t: list): index = index.tolist() # t = t.tolist() @@ -63,8 +76,7 @@ def _flatten(index: pd.Index, t: list): serialize_data = data.index.to_local().values.join(tensor_concat, _flatten) serialize_data.schema = schema - data_dict = dict(data=list(serialize_data.collect()), - schema=schema) + data_dict = dict(data=list(serialize_data.collect()), schema=schema) return data_dict @@ -95,15 +107,15 @@ def _serialize_distributed(ctx, data): if isinstance(data.values, ValueStore): value_concat = data.values.values if tensor_concat is not None: - value_concat = tensor_concat.join(value_concat, - lambda t1, t2: np.concatenate( - [t1.to_local().data.numpy(), t2.to_numpy()], axis=-1)) + value_concat = tensor_concat.join( + value_concat, lambda t1, t2: np.concatenate([t1.to_local().data.numpy(), t2.to_numpy()], axis=-1) + ) else: value_concat = data.values.storage.blocks.mapValues(lambda t: t.to_local().data) if tensor_concat is not None: - value_concat = tensor_concat.join(value_concat, - lambda t1, t2: np.concatenate( - [t1.to_local().data.numpy(), t2.numpy()], axis=-1)) + value_concat = tensor_concat.join( + value_concat, lambda t1, t2: np.concatenate([t1.to_local().data.numpy(), t2.numpy()], axis=-1) + ) else: value_concat = tensor_concat @@ -114,11 +126,12 @@ def _serialize_distributed(ctx, data): index = Index.aggregate(data.index.values) if tensor_concat is None: - """ - data only has index + """ + data only has index """ serialize_data = index else: + def _flatten(index: list, t): flatten_ret = [] for (_id, block_index), _t in zip(index, t): @@ -161,7 +174,7 @@ def _recovery_tensor(value, tensor_info=None): ret_tensor = [] for v in value: - ret_tensor.append(v[start_index: end_index + 1]) + ret_tensor.append(v[start_index : end_index + 1]) return torch.tensor(ret_tensor, dtype=getattr(torch, dtype)) @@ -178,38 +191,29 @@ def _recovery_distributed_value_store(value, value_info, header): return df def _to_distributed_tensor(tensor_list): - return tensor.distributed_tensor( - ctx, tensor_list, partitions=len(tensor_list) - ) + return 
tensor.distributed_tensor(ctx, tensor_list, partitions=len(tensor_list)) ret_dict = dict() - ret_dict["index"] = Index(ctx=ctx, - distributed_index=data.mapPartitions(_recovery_index, use_previous_behavior=False), - block_partition_mapping=block_partition_mapping, - global_ranks=global_ranks) + ret_dict["index"] = Index( + ctx=ctx, + distributed_index=data.mapPartitions(_recovery_index, use_previous_behavior=False), + block_partition_mapping=block_partition_mapping, + global_ranks=global_ranks, + ) tensor_keywords = ["weight", "label", "values"] for keyword in tensor_keywords: if keyword in column_info: if keyword == "values" and column_info["values"]["source"] == "fate.dataframe.value_store": continue - _recovery_func = functools.partial( - _recovery_tensor, - tensor_info=column_info[keyword] - ) + _recovery_func = functools.partial(_recovery_tensor, tensor_info=column_info[keyword]) tensors = [tensor for key, tensor in sorted(list(data.mapValues(_recovery_func).collect()))] ret_dict[keyword] = _to_distributed_tensor(tensors) if "values" in column_info and column_info["values"]["source"] == "fate.dataframe.value_store": _recovery_df_func = functools.partial( - _recovery_distributed_value_store, - value_info=column_info["values"], - header=recovery_schema["header"] - ) - ret_dict["values"] = ValueStore( - ctx, - data.mapValues(_recovery_df_func), - recovery_schema["header"] + _recovery_distributed_value_store, value_info=column_info["values"], header=recovery_schema["header"] ) + ret_dict["values"] = ValueStore(ctx, data.mapValues(_recovery_df_func), recovery_schema["header"]) return DataFrame(ctx, recovery_schema, **ret_dict) diff --git a/python/fate/arch/dataframe/ops/__init__.py b/python/fate/arch/dataframe/ops/__init__.py index 93b0ce48ec..063d19301e 100644 --- a/python/fate/arch/dataframe/ops/__init__.py +++ b/python/fate/arch/dataframe/ops/__init__.py @@ -1,10 +1,19 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._arithmetic import arith_method from ._predict_result_transformaton import transform_to_predict_result from ._stat import stat_method - -__all__ = [ - "arith_method", - "transform_to_predict_result", - "stat_method" -] \ No newline at end of file +__all__ = ["arith_method", "transform_to_predict_result", "stat_method"] diff --git a/python/fate/arch/dataframe/ops/_arithmetic.py b/python/fate/arch/dataframe/ops/_arithmetic.py index d4609ab6c7..6476621ef3 100644 --- a/python/fate/arch/dataframe/ops/_arithmetic.py +++ b/python/fate/arch/dataframe/ops/_arithmetic.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import numpy as np import pandas as pd import torch @@ -15,11 +29,3 @@ def arith_method(lhs, rhs, op): raise ValueError(f"{op.__name__} between DataFrame and {type(rhs)} is not supported") return op(lhs, rhs) - - - - - - - - diff --git a/python/fate/arch/dataframe/ops/_predict_result_transformaton.py b/python/fate/arch/dataframe/ops/_predict_result_transformaton.py index e1a6f44a39..a70f40ca6c 100644 --- a/python/fate/arch/dataframe/ops/_predict_result_transformaton.py +++ b/python/fate/arch/dataframe/ops/_predict_result_transformaton.py @@ -1,12 +1,27 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import functools + import pandas as pd import torch -def transform_to_predict_result(ctx, predict_score, data_type="train", task_type="binary", classes=None, - threshold=0.5): - """ - """ +def transform_to_predict_result( + ctx, predict_score, data_type="train", task_type="binary", classes=None, threshold=0.5 +): + """ """ transform_header = _predict_header_transform(task_type) if task_type == "regression": ... @@ -15,17 +30,15 @@ def transform_to_predict_result(ctx, predict_score, data_type="train", task_type predict_score = predict_score.storage.blocks.mapValues(lambda t: t.to_local().data) else: predict_score_local = predict_score.storage.data - predict_score = ctx.computing.parallelize( - [predict_score_local], - include_key=False, - partition=1 - ) - - to_predict_result_func = functools.partial(_predict_score_to_binary_result, - header=transform_header, - threshold=threshold, - classes=classes, - data_type=data_type) + predict_score = ctx.computing.parallelize([predict_score_local], include_key=False, partition=1) + + to_predict_result_func = functools.partial( + _predict_score_to_binary_result, + header=transform_header, + threshold=threshold, + classes=classes, + data_type=data_type, + ) predict_result = predict_score.mapValues(to_predict_result_func) return predict_result, transform_header @@ -36,10 +49,7 @@ def transform_to_predict_result(ctx, predict_score, data_type="train", task_type def _predict_header_transform(task_type): if task_type in ["regression", "binary", "multi"]: - return ["predict_result", - "predict_score", - "predict_detail", - "type"] + return ["predict_result", "predict_score", "predict_detail", "type"] elif task_type == "cluster": ... 
else: diff --git a/python/fate/arch/dataframe/ops/_stat.py b/python/fate/arch/dataframe/ops/_stat.py index 5025ad0a65..0615519766 100644 --- a/python/fate/arch/dataframe/ops/_stat.py +++ b/python/fate/arch/dataframe/ops/_stat.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pandas as pd diff --git a/python/fate/arch/dataframe/storage/__init__.py b/python/fate/arch/dataframe/storage/__init__.py index 0e8bf10b87..8b1973877d 100644 --- a/python/fate/arch/dataframe/storage/__init__.py +++ b/python/fate/arch/dataframe/storage/__init__.py @@ -1,2 +1,16 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._index import Index -from ._value_store import ValueStore \ No newline at end of file +from ._value_store import ValueStore diff --git a/python/fate/arch/dataframe/storage/_index.py b/python/fate/arch/dataframe/storage/_index.py index b34961e9d4..bbece4c3ff 100644 --- a/python/fate/arch/dataframe/storage/_index.py +++ b/python/fate/arch/dataframe/storage/_index.py @@ -1,5 +1,20 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import copy import functools + import pandas as pd from fate.arch.computing import is_table @@ -36,7 +51,7 @@ def __len__(self): return self.count() def tolist(self): - indexes_with_partition_id = sorted(self._index_table.collect(), key = lambda kv: kv[1]) + indexes_with_partition_id = sorted(self._index_table.collect(), key=lambda kv: kv[1]) id_list = [k for k, v in indexes_with_partition_id] return id_list @@ -46,24 +61,16 @@ def to_local(self): index_table: id, (partition_id, block_index) """ index_table = self._index_table.mapValues( - lambda order_tuple: (0, self._global_ranks[order_tuple[0]]["start_index"] + order_tuple[1])) + lambda order_tuple: (0, self._global_ranks[order_tuple[0]]["start_index"] + order_tuple[1]) + ) - global_ranks = [ - dict(start_index=0, - end_index=self.count(), - block_id=0) - ] + global_ranks = [dict(start_index=0, end_index=self.count(), block_id=0)] block_partition_mapping = copy.deepcopy(self._block_partition_mapping) for block_id in self._block_partition_mapping: if block_id != 0: block_partition_mapping.pop(block_id) - return Index( - self._ctx, - index_table, - block_partition_mapping, - global_ranks - ) + return Index(self._ctx, index_table, block_partition_mapping, global_ranks) def __getitem__(self, items): if isinstance(items, int): @@ -92,30 +99,23 @@ def _flat_partition(k, values, ranks=None): return _flat_ret - _flat_func = functools.partial(_flat_partition, - ranks=global_ranks) + _flat_func = functools.partial(_flat_partition, ranks=global_ranks) index_table = agg_table.flatMap(_flat_func) - return Index( - self._ctx, - index_table, - block_partition_mapping, - global_ranks - ) + return Index(self._ctx, index_table, block_partition_mapping, global_ranks) def get_indexer(self, ids, with_partition_id=True): if isinstance(ids, list): + def _filter_id(key, value, ids_set=None): return key in ids_set - filter_func = functools.partial(_filter_id, - ids_set=set(ids)) + filter_func = functools.partial(_filter_id, ids_set=set(ids)) indexer = self._index_table.filter(filter_func) indexer = indexer.mapValues(lambda v: [v, v]) elif is_table(ids): - """ - """ + """ """ if with_partition_id: indexer = self._index_table.join(ids, lambda v1, v2: [v1, v2]) else: @@ -131,9 +131,7 @@ def _filter(k, v, index_set=None, global_ranks=None): partition_id, block_index = v return global_ranks[partition_id]["start_index"] + block_index in index_set - filter_func = functools.partial(_filter, - index_set=set(indexes), - global_ranks=self._global_ranks) + filter_func = functools.partial(_filter, index_set=set(indexes), global_ranks=self._global_ranks) indexer = self._index_table.filter(filter_func, use_previous_behavior=False) indexer = indexer.mapValues(lambda v: [v, v]) return indexer @@ -143,6 +141,7 @@ def aggregate(cls, table): """ agg_table: key=partition_id, value=(id, block_index), block_index may be not continuous """ + def _aggregate_ids(kvs): aggregate_ret = dict() @@ -157,7 +156,8 @@ def _aggregate_ids(kvs): agg_table = table.mapReducePartitions(_aggregate_ids, lambda l1, l2: l1 + l2) agg_table = agg_table.mapValues( - lambda id_list: sorted(id_list, key = lambda block_index_with_key: block_index_with_key[1])) + lambda id_list: sorted(id_list, key=lambda block_index_with_key: block_index_with_key[1]) + ) return agg_table @@ -168,6 +168,7 @@ def aggregate_indexer(cls, indexer): => key=old_partition_id, value=[old_block_index, (new_partition_id, new_block_index)] """ + def _aggregate(kvs): aggregate_ret = dict() for k, values in kvs: @@ -196,33 +197,15 @@ def 
regenerate_global_ranks(cls, agg_table, old_global_ranks): if global_ranks and global_ranks[-1]["block_id"] + 1 != block_id: last_bid = global_ranks[-1]["block_id"] for bid in range(last_bid + 1, block_id): - global_ranks.append( - dict( - start_index=idx, - end_index=idx-1, - block_id=bid - ) - ) - - global_ranks.append( - dict( - start_index=idx, - end_index=idx + block_count - 1, - block_id=block_id - ) - ) + global_ranks.append(dict(start_index=idx, end_index=idx - 1, block_id=bid)) + + global_ranks.append(dict(start_index=idx, end_index=idx + block_count - 1, block_id=block_id)) idx += block_count if len(global_ranks) < len(old_global_ranks): last_bid = len(global_ranks) for bid in range(last_bid, len(old_global_ranks)): - global_ranks.append( - dict( - start_index=idx, - end_index=idx-1, - block_id=bid - ) - ) + global_ranks.append(dict(start_index=idx, end_index=idx - 1, block_id=bid)) return global_ranks diff --git a/python/fate/arch/dataframe/storage/_value_store.py b/python/fate/arch/dataframe/storage/_value_store.py index 3039ca331e..dfbbb5d513 100644 --- a/python/fate/arch/dataframe/storage/_value_store.py +++ b/python/fate/arch/dataframe/storage/_value_store.py @@ -1,4 +1,19 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import functools + import pandas as pd @@ -19,27 +34,15 @@ def to_local(self, keep_table=False): if not keep_table: return concat_frame else: - table = self._ctx.computing.parallelize( - [(0, concat_frame)], - include_key=True, - partition=1 - ) - - return ValueStore( - self._ctx, - table, - self._header - ) + table = self._ctx.computing.parallelize([(0, concat_frame)], include_key=True, partition=1) + + return ValueStore(self._ctx, table, self._header) def __getattr__(self, attr): if attr not in self._header: raise ValueError(f"ValueStore does not has attribute: {attr}") - return ValueStore( - self._ctx, - self._data.mapValues(lambda df: df[attr]), - [attr] - ) + return ValueStore(self._ctx, self._data.mapValues(lambda df: df[attr]), [attr]) def tolist(self): return self.to_local().tolist() diff --git a/python/fate/arch/dataframe/utils/__init__.py b/python/fate/arch/dataframe/utils/__init__.py index 0d6f6737db..9b670c6d5b 100644 --- a/python/fate/arch/dataframe/utils/__init__.py +++ b/python/fate/arch/dataframe/utils/__init__.py @@ -1 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from ._dataloader import DataLoader diff --git a/python/fate/arch/dataframe/utils/_dataloader.py b/python/fate/arch/dataframe/utils/_dataloader.py index 232dde7cec..4243109800 100644 --- a/python/fate/arch/dataframe/utils/_dataloader.py +++ b/python/fate/arch/dataframe/utils/_dataloader.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import random import numpy as np @@ -6,6 +20,7 @@ from fate.arch.context.io.data import df from fate.arch.dataframe import PandasReader, TorchDataSetReader + class DataLoader(object): def __init__( self, diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index 0f6a322059..a20d642676 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._type import FederationDataType, FederationEngine __all__ = ["FederationEngine", "FederationDataType"] diff --git a/python/fate/arch/federation/_federation.py b/python/fate/arch/federation/_federation.py index 6ce18cebae..66331110e0 100644 --- a/python/fate/arch/federation/_federation.py +++ b/python/fate/arch/federation/_federation.py @@ -1,5 +1,5 @@ # -# Copyright 2022 The FATE Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/fate/arch/federation/_nretry.py b/python/fate/arch/federation/_nretry.py index cf6fcdb318..5f5c5e88f3 100644 --- a/python/fate/arch/federation/_nretry.py +++ b/python/fate/arch/federation/_nretry.py @@ -1,5 +1,5 @@ # -# Copyright 2022 The FATE Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/fate/arch/federation/_parties.py b/python/fate/arch/federation/_parties.py index ffe0343c22..62a2908094 100644 --- a/python/fate/arch/federation/_parties.py +++ b/python/fate/arch/federation/_parties.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import enum from typing import List diff --git a/python/fate/arch/federation/osx/__init__.py b/python/fate/arch/federation/osx/__init__.py index 7fd4292e7a..bf5023f76c 100644 --- a/python/fate/arch/federation/osx/__init__.py +++ b/python/fate/arch/federation/osx/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import sys diff --git a/python/fate/arch/federation/osx/_federation.py b/python/fate/arch/federation/osx/_federation.py index 0b9a4cee8d..0d8db2c907 100644 --- a/python/fate/arch/federation/osx/_federation.py +++ b/python/fate/arch/federation/osx/_federation.py @@ -1,5 +1,5 @@ # -# Copyright 2022 The FATE Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/fate/arch/federation/osx/_mq_channel.py b/python/fate/arch/federation/osx/_mq_channel.py index 1f6abd217d..5cafe45684 100644 --- a/python/fate/arch/federation/osx/_mq_channel.py +++ b/python/fate/arch/federation/osx/_mq_channel.py @@ -1,5 +1,5 @@ # -# Copyright 2022 The FATE Authors. All Rights Reserved. +# Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/fate/arch/federation/osx/osx_pb2.py b/python/fate/arch/federation/osx/osx_pb2.py index e14dfcc3ec..b62a917a27 100644 --- a/python/fate/arch/federation/osx/osx_pb2.py +++ b/python/fate/arch/federation/osx/osx_pb2.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: osx.proto diff --git a/python/fate/arch/federation/osx/osx_pb2_grpc.py b/python/fate/arch/federation/osx/osx_pb2_grpc.py index 95b03d66b8..b88a9f479b 100644 --- a/python/fate/arch/federation/osx/osx_pb2_grpc.py +++ b/python/fate/arch/federation/osx/osx_pb2_grpc.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/python/fate/arch/federation/pulsar/__init__.py b/python/fate/arch/federation/pulsar/__init__.py index cc2905d2a1..6e3d4a1a26 100644 --- a/python/fate/arch/federation/pulsar/__init__.py +++ b/python/fate/arch/federation/pulsar/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._federation import MQ, PulsarFederation, PulsarManager __all__ = ["PulsarFederation", "MQ", "PulsarManager"] diff --git a/python/fate/arch/federation/rabbitmq/__init__.py b/python/fate/arch/federation/rabbitmq/__init__.py index 6dd16d7d83..e265c8bdaa 100644 --- a/python/fate/arch/federation/rabbitmq/__init__.py +++ b/python/fate/arch/federation/rabbitmq/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._federation import RabbitmqFederation __all__ = ["RabbitmqFederation"] diff --git a/python/fate/arch/federation/standalone/_federation.py b/python/fate/arch/federation/standalone/_federation.py index df4c3e654c..f927455b93 100644 --- a/python/fate/arch/federation/standalone/_federation.py +++ b/python/fate/arch/federation/standalone/_federation.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from typing import List, Tuple diff --git a/python/fate/arch/metastore/db_utils.py b/python/fate/arch/metastore/db_utils.py index 558aa4a716..60c0f736df 100644 --- a/python/fate/arch/metastore/db_utils.py +++ b/python/fate/arch/metastore/db_utils.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import operator from ..common.base_utils import current_timestamp @@ -18,9 +32,7 @@ def create_or_update(self): "f_connector_info": self.connector_info, "f_create_time": current_timestamp(), } - connector, status = StorageConnectorModel.get_or_create( - f_name=self.name, defaults=defaults - ) + connector, status = StorageConnectorModel.get_or_create(f_name=self.name, defaults=defaults) if status is False: for key in defaults: setattr(connector, key, defaults[key]) diff --git a/python/fate/arch/storage/__init__.py b/python/fate/arch/storage/__init__.py index 7692db3fdc..2f5c7a8015 100644 --- a/python/fate/arch/storage/__init__.py +++ b/python/fate/arch/storage/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._session import StorageSessionBase from ._table import StorageTableBase, StorageTableMeta from ._types import ( diff --git a/python/fate/arch/storage/hive/__init__.py b/python/fate/arch/storage/hive/__init__.py index 39d4b8ca70..2f9f6d92f4 100644 --- a/python/fate/arch/storage/hive/__init__.py +++ b/python/fate/arch/storage/hive/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. from ._session import StorageSession from ._table import StorageTable diff --git a/python/fate/arch/storage/linkis_hive/__init__.py b/python/fate/arch/storage/linkis_hive/__init__.py index 39d4b8ca70..2f9f6d92f4 100644 --- a/python/fate/arch/storage/linkis_hive/__init__.py +++ b/python/fate/arch/storage/linkis_hive/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._session import StorageSession from ._table import StorageTable diff --git a/python/fate/arch/storage/linkis_hive/_settings.py b/python/fate/arch/storage/linkis_hive/_settings.py index b954bc4e9a..9868bcdb85 100644 --- a/python/fate/arch/storage/linkis_hive/_settings.py +++ b/python/fate/arch/storage/linkis_hive/_settings.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # token Token_Code = "" Token_User = "fate" diff --git a/python/fate/arch/storage/localfs/__init__.py b/python/fate/arch/storage/localfs/__init__.py index 39d4b8ca70..2f9f6d92f4 100644 --- a/python/fate/arch/storage/localfs/__init__.py +++ b/python/fate/arch/storage/localfs/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._session import StorageSession from ._table import StorageTable diff --git a/python/fate/arch/tensor/__init__.py b/python/fate/arch/tensor/__init__.py index 46b8a759ad..843b821e7f 100644 --- a/python/fate/arch/tensor/__init__.py +++ b/python/fate/arch/tensor/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._tensor import distributed_tensor, randn, tensor from .ops import * from .types import * diff --git a/python/fate/arch/tensor/_exception.py b/python/fate/arch/tensor/_exception.py index 26e11900ef..13ea8599c7 100644 --- a/python/fate/arch/tensor/_exception.py +++ b/python/fate/arch/tensor/_exception.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. class OpsDispatchException(Exception): ... diff --git a/python/fate/arch/tensor/_generate.py b/python/fate/arch/tensor/_generate.py index f3c98fbefa..140ac85043 100644 --- a/python/fate/arch/tensor/_generate.py +++ b/python/fate/arch/tensor/_generate.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. if __name__ == "__main__": import pathlib diff --git a/python/fate/arch/tensor/_phe.py b/python/fate/arch/tensor/_phe.py index aa14db3400..cfb913315e 100644 --- a/python/fate/arch/tensor/_phe.py +++ b/python/fate/arch/tensor/_phe.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from enum import Enum from typing import Tuple diff --git a/python/fate/arch/tensor/_tensor.py b/python/fate/arch/tensor/_tensor.py index 0474d49014..55cfdcf562 100644 --- a/python/fate/arch/tensor/_tensor.py +++ b/python/fate/arch/tensor/_tensor.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Union import torch diff --git a/python/fate/arch/tensor/ops/__init__.py b/python/fate/arch/tensor/ops/__init__.py index a8d2664e1e..aa8e04f264 100644 --- a/python/fate/arch/tensor/ops/__init__.py +++ b/python/fate/arch/tensor/ops/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._agg_ops import * from ._binary_ops import * from ._matmul_ops import * diff --git a/python/fate/arch/tensor/ops/_agg_ops.py b/python/fate/arch/tensor/ops/_agg_ops.py index 89f198a2c4..1a32c29a60 100644 --- a/python/fate/arch/tensor/ops/_agg_ops.py +++ b/python/fate/arch/tensor/ops/_agg_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import overload from .._tensor import Tensor diff --git a/python/fate/arch/tensor/ops/_binary_ops.py b/python/fate/arch/tensor/ops/_binary_ops.py index cf0055d5d0..f546b3b142 100644 --- a/python/fate/arch/tensor/ops/_binary_ops.py +++ b/python/fate/arch/tensor/ops/_binary_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from ._ops import auto_binary_op diff --git a/python/fate/arch/tensor/ops/_matmul_ops.py b/python/fate/arch/tensor/ops/_matmul_ops.py index 65422690ce..0fa8fa3a9b 100644 --- a/python/fate/arch/tensor/ops/_matmul_ops.py +++ b/python/fate/arch/tensor/ops/_matmul_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .._tensor import DStorage, Tensor from ..types import DAxis, Shape from ._ops import _get_dispatch_info, dispatch_signature2 diff --git a/python/fate/arch/tensor/ops/_ops.py b/python/fate/arch/tensor/ops/_ops.py index d30f408fcf..e70a4ee72e 100644 --- a/python/fate/arch/tensor/ops/_ops.py +++ b/python/fate/arch/tensor/ops/_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import wraps from .._exception import OpDispatchInvalidDevice, OpsDispatchBadSignatureError diff --git a/python/fate/arch/tensor/ops/_slice_ops.py b/python/fate/arch/tensor/ops/_slice_ops.py index 00dbb53164..03f5b4ba39 100644 --- a/python/fate/arch/tensor/ops/_slice_ops.py +++ b/python/fate/arch/tensor/ops/_slice_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .._tensor import Tensor from ..types import DAxis, DStorage, Shape from ._ops import _get_dispatch_info diff --git a/python/fate/arch/tensor/ops/_unary_ops.py b/python/fate/arch/tensor/ops/_unary_ops.py index 5bb1f3360a..76b40d5eb0 100644 --- a/python/fate/arch/tensor/ops/_unary_ops.py +++ b/python/fate/arch/tensor/ops/_unary_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._ops import auto_unary_op diff --git a/python/fate/arch/tensor/storage/_helper.py b/python/fate/arch/tensor/storage/_helper.py index 53f3f0f895..a5ecb231da 100644 --- a/python/fate/arch/tensor/storage/_helper.py +++ b/python/fate/arch/tensor/storage/_helper.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. class local_ops_helper: def __init__(self, device, dtype) -> None: self.device = device diff --git a/python/fate/arch/tensor/storage/_ops.py b/python/fate/arch/tensor/storage/_ops.py index 772f9a8d65..60e389fa42 100644 --- a/python/fate/arch/tensor/storage/_ops.py +++ b/python/fate/arch/tensor/storage/_ops.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable from ..types import DStorage, Storage diff --git a/python/fate/arch/tensor/storage/distributed/agg.py b/python/fate/arch/tensor/storage/distributed/agg.py index 9f7c4cd5e4..0cb0f97b1e 100644 --- a/python/fate/arch/tensor/storage/distributed/agg.py +++ b/python/fate/arch/tensor/storage/distributed/agg.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from fate.arch.tensor.types import DStorage diff --git a/python/fate/arch/tensor/storage/local/_types.py b/python/fate/arch/tensor/storage/local/_types.py index e69de29bb2..ae946a49c4 100644 --- a/python/fate/arch/tensor/storage/local/_types.py +++ b/python/fate/arch/tensor/storage/local/_types.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/arch/tensor/storage/local/device/__init__.py b/python/fate/arch/tensor/storage/local/device/__init__.py index bd5d68015d..5d2b3a4991 100644 --- a/python/fate/arch/tensor/storage/local/device/__init__.py +++ b/python/fate/arch/tensor/storage/local/device/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable from fate.arch.unify import device diff --git a/python/fate/arch/tensor/storage/local/device/cpu/__init__.py b/python/fate/arch/tensor/storage/local/device/cpu/__init__.py index e69de29bb2..ae946a49c4 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/__init__.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_base.py b/python/fate/arch/tensor/storage/local/device/cpu/_base.py index 5d57634331..8a12448e84 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/_base.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/_base.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable from fate.arch.tensor.types import LStorage, dtype diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py b/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py index ce1435811c..61d74a0438 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/_metaclass.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pickle import numpy as np diff --git a/python/fate/arch/tensor/storage/local/device/cpu/_ops.py b/python/fate/arch/tensor/storage/local/device/cpu/_ops.py index 3e48b0e047..8f56bb7e8a 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/_ops.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/_ops.py @@ -1 +1,15 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. custom_ops = dict(slice=slice) diff --git a/python/fate/arch/tensor/storage/local/device/cpu/paillier.py b/python/fate/arch/tensor/storage/local/device/cpu/paillier.py index 6ad236b212..e4c129f342 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/paillier.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/paillier.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Any, Callable, List import torch diff --git a/python/fate/arch/tensor/storage/local/device/cpu/plain.py b/python/fate/arch/tensor/storage/local/device/cpu/plain.py index 338e6cb245..b012c462b3 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/plain.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/plain.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable, List import torch diff --git a/python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py b/python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py index e69de29bb2..ae946a49c4 100644 --- a/python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py +++ b/python/fate/arch/tensor/storage/local/device/cpu/plain_custom.py @@ -0,0 +1,14 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/fate/arch/tensor/types/__init__.py b/python/fate/arch/tensor/types/__init__.py index 97b04ca18f..4753548d9e 100644 --- a/python/fate/arch/tensor/types/__init__.py +++ b/python/fate/arch/tensor/types/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Union from ._dstorage import DStorage diff --git a/python/fate/arch/tensor/types/_dstorage.py b/python/fate/arch/tensor/types/_dstorage.py index 785241be40..953bd7c164 100644 --- a/python/fate/arch/tensor/types/_dstorage.py +++ b/python/fate/arch/tensor/types/_dstorage.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, List, Optional from fate.arch.unify import device diff --git a/python/fate/arch/tensor/types/_dtype.py b/python/fate/arch/tensor/types/_dtype.py index 830089a526..0e3fa441fe 100644 --- a/python/fate/arch/tensor/types/_dtype.py +++ b/python/fate/arch/tensor/types/_dtype.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from enum import Enum import torch diff --git a/python/fate/arch/tensor/types/_lstorage.py b/python/fate/arch/tensor/types/_lstorage.py index 44f17416c2..64446ec599 100644 --- a/python/fate/arch/tensor/types/_lstorage.py +++ b/python/fate/arch/tensor/types/_lstorage.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Protocol from fate.arch.unify import device diff --git a/python/fate/arch/tensor/types/_shape.py b/python/fate/arch/tensor/types/_shape.py index 939fa75542..8abe0ef673 100644 --- a/python/fate/arch/tensor/types/_shape.py +++ b/python/fate/arch/tensor/types/_shape.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import reduce from typing import List, Optional, overload diff --git a/python/fate/arch/unify/__init__.py b/python/fate/arch/unify/__init__.py index 7c30654d62..7e26924368 100644 --- a/python/fate/arch/unify/__init__.py +++ b/python/fate/arch/unify/__init__.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from ._infra_def import Backend, device from ._io import URI, EggrollURI, FileURI, HdfsURI, HttpsURI, HttpURI from ._uuid import generate_computing_uuid, uuid diff --git a/python/fate/arch/unify/_infra_def.py b/python/fate/arch/unify/_infra_def.py index bf3cd344fe..adbd69b987 100644 --- a/python/fate/arch/unify/_infra_def.py +++ b/python/fate/arch/unify/_infra_def.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from enum import Enum diff --git a/python/fate/arch/unify/_io.py b/python/fate/arch/unify/_io.py index 1bd81b99e6..155aaa3bb9 100644 --- a/python/fate/arch/unify/_io.py +++ b/python/fate/arch/unify/_io.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import hashlib import re from abc import ABCMeta diff --git a/python/fate/arch/unify/_uuid.py b/python/fate/arch/unify/_uuid.py index 1558f01c91..7b9c77fa98 100644 --- a/python/fate/arch/unify/_uuid.py +++ b/python/fate/arch/unify/_uuid.py @@ -1,3 +1,17 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Optional from uuid import uuid1 diff --git a/python/fate_test/fate_test/__init__.py b/python/fate_test/fate_test/__init__.py deleted file mode 100644 index 878d3a9c5d..0000000000 --- a/python/fate_test/fate_test/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/python/fate_test/fate_test/_ascii.py b/python/fate_test/fate_test/_ascii.py deleted file mode 100644 index 9f87d5bd02..0000000000 --- a/python/fate_test/fate_test/_ascii.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -HEAD = """\ - -████████╗███████╗███████╗████████╗███████╗██╗ ██╗██╗████████╗███████╗ -╚══██╔══╝██╔════╝██╔════╝╚══██╔══╝██╔════╝██║ ██║██║╚══██╔══╝██╔════╝ - ██║ █████╗ ███████╗ ██║ ███████╗██║ ██║██║ ██║ █████╗ - ██║ ██╔══╝ ╚════██║ ██║ ╚════██║██║ ██║██║ ██║ ██╔══╝ - ██║ ███████╗███████║ ██║ ███████║╚██████╔╝██║ ██║ ███████╗ - ╚═╝ ╚══════╝╚══════╝ ╚═╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚══════╝ - -""" - -BENCHMARK = """\ - -██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗███╗ ███╗ █████╗ ██████╗ ██╗ ██╗ -██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║████╗ ████║██╔══██╗██╔══██╗██║ ██╔╝ -██████╔╝█████╗ ██╔██╗ ██║██║ ███████║██╔████╔██║███████║██████╔╝█████╔╝ -██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║██║╚██╔╝██║██╔══██║██╔══██╗██╔═██╗ -██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║██║ ╚═╝ ██║██║ ██║██║ ██║██║ ██╗ -╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ -""" - - -TAIL = """\ - - ██╗ ██╗ █████╗ ██╗ ██╗███████╗ ███████╗██╗ ██╗███╗ ██╗ - ██║ ██║██╔══██╗██║ ██║██╔════╝ ██╔════╝██║ ██║████╗ ██║ - ███████║███████║██║ ██║█████╗ █████╗ ██║ ██║██╔██╗ ██║ - ██╔══██║██╔══██║╚██╗ ██╔╝██╔══╝ ██╔══╝ ██║ ██║██║╚██╗██║ - ██║ ██║██║ ██║ ╚████╔╝ ███████╗ ██║ ╚██████╔╝██║ ╚████║ - ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═══╝ ╚══════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ - -""" diff --git a/python/fate_test/fate_test/_client.py b/python/fate_test/fate_test/_client.py deleted file mode 100644 index 84d623c4c3..0000000000 --- a/python/fate_test/fate_test/_client.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sshtunnel - -from fate_test._flow_client import FLOWClient -from fate_test._io import LOGGER -from fate_test._parser import Config - - -class Clients(object): - def __init__(self, config: Config): - self._flow_clients = {} - self._tunnel_id_to_flow_clients = {} - self._role_str_to_service_id = {} - self._tunnel_id_to_tunnel = config.tunnel_id_to_tunnel - - for service_id, service in config.service_id_to_service.items(): - if isinstance(service, Config.service): - self._flow_clients[service_id] = FLOWClient( - service.address, config.data_base_dir, config.cache_directory) - - elif isinstance(service, Config.tunnel_service): - self._flow_clients[service_id] = FLOWClient(None, config.data_base_dir, config.cache_directory) - self._tunnel_id_to_flow_clients.setdefault(service.tunnel_id, []).append( - (service.index, self._flow_clients[service_id])) - - for party, service_id in config.party_to_service_id.items(): - for role_str in config.parties.party_to_role_string(party): - self._role_str_to_service_id[role_str] = service_id - - def __getitem__(self, role_str: str) -> 'FLOWClient': - if role_str not in self._role_str_to_service_id: - raise RuntimeError(f"no flow client found binding to {role_str}") - return self._flow_clients[self._role_str_to_service_id[role_str]] - - def __enter__(self): - # open ssh tunnels and create flow clients for remote - self._tunnels = [] - for tunnel_id, tunnel_conf in self._tunnel_id_to_tunnel.items(): - tunnel = sshtunnel.SSHTunnelForwarder(ssh_address_or_host=tunnel_conf.ssh_address, - ssh_username=tunnel_conf.ssh_username, - ssh_password=tunnel_conf.ssh_password, - ssh_pkey=tunnel_conf.ssh_priv_key, - remote_bind_addresses=tunnel_conf.services_address) - tunnel.start() - self._tunnels.append(tunnel) - for index, flow_client in self._tunnel_id_to_flow_clients[tunnel_id]: - flow_client.set_address(f"127.0.0.1:{tunnel.local_bind_ports[index]}") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for tunnel in self._tunnels: - try: - tunnel.stop() - except Exception as e: - LOGGER.exception(e) - - def contains(self, role_str): - return role_str in self._role_str_to_service_id - - def all_roles(self): - return sorted(self._role_str_to_service_id.keys()) diff --git a/python/fate_test/fate_test/_config.py b/python/fate_test/fate_test/_config.py deleted file mode 100644 index 1c922dd46b..0000000000 --- a/python/fate_test/fate_test/_config.py +++ /dev/null @@ -1,264 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import json -import os -import typing -from collections import namedtuple -from pathlib import Path - -from ruamel import yaml - -template = """\ -# base dir for data upload conf eg, data_base_dir={FATE} -# examples/data/breast_hetero_guest.csv -> $data_base_dir/examples/data/breast_hetero_guest.csv -data_base_dir: path(FATE) - -# directory dedicated to fate_test job file storage, default cache location={FATE}/examples/cache/ -cache_directory: examples/cache/ -# directory stores performance benchmark suites, default location={FATE}/examples/benchmark_performance -performance_template_directory: examples/benchmark_performance/ -# directory stores flow test config, default location={FATE}/examples/flow_test_template/hetero_lr/flow_test_config.yaml -flow_test_config_directory: examples/flow_test_template/hetero_lr/flow_test_config.yaml - -# directory stores testsuite file with min_test data sets to upload, -# default location={FATE}/examples/data/upload_config/min_test_data_testsuite.json -min_test_data_config: examples/data/upload_config/min_test_data_testsuite.json -# directory stores testsuite file with all example data sets to upload, -# default location={FATE}/examples/data/upload_config/all_examples_data_testsuite.json -all_examples_data_config: examples/data/upload_config/all_examples_data_testsuite.json - -# directory where FATE code locates, default installation location={FATE}/fate -# python/federatedml -> $fate_base/python/federatedml -fate_base: path(FATE)/fate - -# whether to delete data in suites after all jobs done -clean_data: true - -# participating parties' id and correponding flow service ip & port information -parties: - guest: [9999] - host: [10000, 9999] - arbiter: [10000] -services: - - flow_services: - - {address: 127.0.0.1:9380, parties: [9999, 10000]} - serving_setting: - address: 127.0.0.1:8059 - - ssh_tunnel: # optional - enable: false - ssh_address: : - ssh_username: - ssh_password: # optional - ssh_priv_key: "~/.ssh/id_rsa" - - -# what is ssh_tunnel? -# to open the ssh tunnel(s) if the remote service -# cannot be accessed directly from the location where the test suite is run! 
-# -# +---------------------+ -# | ssh address | -# | ssh username | -# | ssh password/ | -# +--------+ | ssh priv_key | +----------------+ -# |local ip+----------ssh tuunel-------------->+remote local ip | -# +--------+ | | +----------------+ -# | | -# request local ip:port +----- as if --------->request remote's local ip:port from remote side -# | | -# | | -# +---------------------+ -# - -""" - -data_base_dir = Path(__file__).resolve().parents[3] -if (data_base_dir / 'examples').is_dir(): - template = template.replace('path(FATE)', str(data_base_dir)) - -_default_config = Path(__file__).resolve().parent / 'fate_test_config.yaml' - -data_switch = None -use_local_data = 1 -data_alter = dict() -deps_alter = dict() -jobs_num = 0 -jobs_progress = 0 -non_success_jobs = [] - - -def create_config(path: Path, override=False): - if path.exists() and not override: - raise FileExistsError(f"{path} exists") - - with path.open("w") as f: - f.write(template) - - -def default_config(): - if not _default_config.exists(): - create_config(_default_config) - return _default_config - - -class Parties(object): - def __init__(self, **kwargs): - """ - mostly, accept guest, host and arbiter - """ - self._role_to_parties = kwargs - - self._party_to_role_string = {} - for role in kwargs: - parties = kwargs[role] - setattr(self, role, parties) - for i, party in enumerate(parties): - if party not in self._party_to_role_string: - self._party_to_role_string[party] = set() - self._party_to_role_string[party].add(f"{role.lower()}_{i}") - - @staticmethod - def from_dict(d: typing.MutableMapping[str, typing.List[int]]): - return Parties(**d) - - def party_to_role_string(self, party): - return self._party_to_role_string[party] - - def extract_role(self, counts: typing.MutableMapping[str, int]): - roles = {} - for role, num in counts.items(): - if role not in self._role_to_parties and num > 0: - raise ValueError(f"{role} not found in config") - else: - if len(self._role_to_parties[role]) < num: - raise ValueError(f"require {num} {role} parties, only {len(self._role_to_parties[role])} in config") - roles[role] = self._role_to_parties[role][:num] - return roles - - def extract_initiator_role(self, role): - initiator_role = role.strip() - if len(self._role_to_parties[initiator_role]) < 1: - raise ValueError(f"role {initiator_role} has empty party list") - party_id = self._role_to_parties[initiator_role][0] - return dict(role=initiator_role, party_id=party_id) - - -class Config(object): - service = namedtuple("service", ["address"]) - tunnel_service = namedtuple("tunnel_service", ["tunnel_id", "index"]) - tunnel = namedtuple("tunnel", ["ssh_address", "ssh_username", "ssh_password", "ssh_priv_key", "services_address"]) - - def __init__(self, config): - self.data_base_dir = config["data_base_dir"] - self.cache_directory = os.path.join(config["data_base_dir"], config["cache_directory"]) - self.perf_template_dir = os.path.join(config["data_base_dir"], config["performance_template_directory"]) - self.flow_test_config_dir = os.path.join(config["data_base_dir"], config["flow_test_config_directory"]) - self.min_test_data_config = os.path.join(config["data_base_dir"], config["min_test_data_config"]) - self.all_examples_data_config = os.path.join(config["data_base_dir"], config["all_examples_data_config"]) - self.fate_base = config["fate_base"] - self.clean_data = config.get("clean_data", True) - self.parties = Parties.from_dict(config["parties"]) - self.role = config["parties"] - self.serving_setting = config["services"][0] - 
self.party_to_service_id = {} - self.service_id_to_service = {} - self.tunnel_id_to_tunnel = {} - self.extend_sid = None - self.auto_increasing_sid = None - - tunnel_id = 0 - service_id = 0 - os.makedirs(os.path.dirname(self.cache_directory), exist_ok=True) - for service_config in config["services"]: - flow_services = service_config["flow_services"] - if service_config.get("ssh_tunnel", {}).get("enable", False): - tunnel_id += 1 - services_address = [] - for index, flow_service in enumerate(flow_services): - service_id += 1 - address_host, address_port = flow_service["address"].split(":") - address_port = int(address_port) - services_address.append((address_host, address_port)) - self.service_id_to_service[service_id] = self.tunnel_service(tunnel_id, index) - for party in flow_service["parties"]: - self.party_to_service_id[party] = service_id - tunnel_config = service_config["ssh_tunnel"] - ssh_address_host, ssh_address_port = tunnel_config["ssh_address"].split(":") - self.tunnel_id_to_tunnel[tunnel_id] = self.tunnel((ssh_address_host, int(ssh_address_port)), - tunnel_config["ssh_username"], - tunnel_config["ssh_password"], - tunnel_config["ssh_priv_key"], - services_address) - else: - for flow_service in flow_services: - service_id += 1 - address = flow_service["address"] - self.service_id_to_service[service_id] = self.service(address) - for party in flow_service["parties"]: - self.party_to_service_id[party] = service_id - - @staticmethod - def load(path: typing.Union[str, Path], **kwargs): - if isinstance(path, str): - path = Path(path) - config = {} - if path is not None: - with path.open("r") as f: - config.update(yaml.safe_load(f)) - - if config["data_base_dir"] == "path(FATE)": - raise ValueError("Invalid 'data_base_dir'.") - config["data_base_dir"] = path.resolve().joinpath(config["data_base_dir"]).resolve() - - config.update(kwargs) - return Config(config) - - @staticmethod - def load_from_file(path: typing.Union[str, Path]): - """ - Loads conf content from json or yaml file. Used to read in parameter configuration - Parameters - ---------- - path: str, path to conf file, should be absolute path - - Returns - ------- - dict, parameter configuration in dictionary format - - """ - if isinstance(path, str): - path = Path(path) - config = {} - if path is not None: - file_type = path.suffix - with path.open("r") as f: - if file_type == ".yaml": - config.update(yaml.safe_load(f)) - elif file_type == ".json": - config.update(json.load(f)) - else: - raise ValueError(f"Cannot load conf from file type {file_type}") - return config - - -def parse_config(config): - try: - config_inst = Config.load(config) - except Exception as e: - raise RuntimeError(f"error parse config from {config}") from e - return config_inst diff --git a/python/fate_test/fate_test/_flow_client.py b/python/fate_test/fate_test/_flow_client.py deleted file mode 100644 index a820cbec1b..0000000000 --- a/python/fate_test/fate_test/_flow_client.py +++ /dev/null @@ -1,447 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -import json -import time -import typing -from datetime import timedelta -from pathlib import Path - -import requests -from fate_test._parser import Data, Job -from flow_sdk.client import FlowClient -from fate_test import _config - - -class FLOWClient(object): - - def __init__(self, - address: typing.Optional[str], - data_base_dir: typing.Optional[Path], - cache_directory: typing.Optional[Path]): - self.address = address - self.version = "v1" - self._http = requests.Session() - self._data_base_dir = data_base_dir - self._cache_directory = cache_directory - self.data_size = 0 - - def set_address(self, address): - self.address = address - - def upload_data(self, data: Data, callback=None, output_path=None): - try: - response, data_path, bind = self._upload_data(conf=data.config, output_path=output_path, verbose=0, drop=1) - if callback is not None: - callback(response) - if not bind: - status = self._awaiting(response.job_id, "local") - status = str(status).lower() - else: - status = response["retmsg"] - except Exception as e: - raise RuntimeError(f"upload data failed") from e - return status, data_path - - def delete_data(self, data: Data): - try: - table_name = data.config['table_name'] if data.config.get( - 'table_name', None) is not None else data.config.get('name') - self._delete_data(table_name=table_name, namespace=data.config['namespace']) - except Exception as e: - raise RuntimeError(f"delete data failed") from e - - def submit_job(self, job: Job, callback=None) -> 'SubmitJobResponse': - try: - response = self._submit_job(**job.submit_params) - if callback is not None: - callback(response) - status = self._awaiting(response.job_id, "guest", callback) - response.status = status - - except Exception as e: - raise RuntimeError(f"submit job failed") from e - return response - - def deploy_model(self, model_id, model_version, dsl=None): - result = self._deploy_model(model_id=model_id, model_version=model_version, dsl=dsl) - return result - - def output_data_table(self, job_id, role, party_id, component_name): - result = self._output_data_table(job_id=job_id, role=role, party_id=party_id, component_name=component_name) - return result - - def table_info(self, table_name, namespace): - result = self._table_info(table_name=table_name, namespace=namespace) - return result - - def add_notes(self, job_id, role, party_id, notes): - self._add_notes(job_id=job_id, role=role, party_id=party_id, notes=notes) - - def check_connection(self): - try: - version = self._http.request(method="POST", url=f"{self._base}version/get", json={"module": "FATE"}, - timeout=2).json() - except Exception: - import traceback - traceback.print_exc() - raise - fate_version = version.get("data", {}).get("FATE") - if fate_version: - return fate_version, self.address - - raise EnvironmentError(f"connection not ok") - - def _awaiting(self, job_id, role, callback=None): - while True: - response = self._query_job(job_id, role=role) - if response.status.is_done(): - return response.status - if callback is not None: - callback(response) - time.sleep(1) - - def _save_json(self, file, file_name): - """ - file = json.dumps(file, indent=4) - file_path = os.path.join(str(self._cache_directory), file_name) - try: - with open(file_path, "w", encoding='utf-8') as f: - f.write(file) - except Exception as e: - raise Exception(f"write error==>{e}") - return file_path - """ - return file - - def _upload_data(self, conf, 
output_path=None, verbose=0, drop=1): - if conf.get("engine", {}) != "PATH": - if output_path is not None: - conf['file'] = os.path.join(os.path.abspath(output_path), os.path.basename(conf.get('file'))) - else: - if _config.data_switch is not None: - conf['file'] = os.path.join(str(self._cache_directory), os.path.basename(conf.get('file'))) - else: - conf['file'] = os.path.join(str(self._data_base_dir), conf.get('file')) - path = Path(conf.get('file')) - if not path.exists(): - raise Exception('The file is obtained from the fate flow client machine, but it does not exist, ' - f'please check the path: {path}') - upload_response = self.flow_client(request='data/upload', param=self._save_json(conf, 'upload_conf.json'), - verbose=verbose, drop=drop) - response = UploadDataResponse(upload_response) - return response, conf['file'], False - else: - if _config.data_switch is not None: - conf['address']['path'] = os.path.join(str(self._cache_directory), conf['address']['path']) - else: - conf['address']['path'] = os.path.join(str(self._data_base_dir), conf['address']['path']) - conf['drop'] = drop - del conf["extend_sid"] - del conf["auto_increasing_sid"] - del conf["use_local_data"] - path = Path(conf.get('address').get('path')) - self._table_bind(conf) - if not path.exists(): - raise Exception('The file is obtained from the fate flow client machine, but it does not exist, ' - f'please check the path: {path}') - response = self._table_bind(conf) - return response, None, True - - def _table_info(self, table_name, namespace): - param = { - 'table_name': table_name, - 'namespace': namespace - } - response = self.flow_client(request='table/info', param=param) - return response - - def _delete_data(self, table_name, namespace): - param = { - 'table_name': table_name, - 'namespace': namespace - } - response = self.flow_client(request='table/delete', param=param) - return response - - def _submit_job(self, conf, dsl): - param = { - 'job_dsl': self._save_json(dsl, 'submit_dsl.json'), - 'job_runtime_conf': self._save_json(conf, 'submit_conf.json') - } - response = SubmitJobResponse(self.flow_client(request='job/submit', param=param)) - return response - - def _deploy_model(self, model_id, model_version, dsl=None): - post_data = {'model_id': model_id, - 'model_version': model_version, - 'predict_dsl': dsl} - response = self.flow_client(request='model/deploy', param=post_data) - result = {} - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"deploy model error: {response}") - result["model_id"] = response["data"]["model_id"] - result["model_version"] = response["data"]["model_version"] - except Exception as e: - raise RuntimeError(f"deploy model error: {response}") from e - - return result - - def _output_data_table(self, job_id, role, party_id, component_name): - post_data = {'job_id': job_id, - 'role': role, - 'party_id': party_id, - 'component_name': component_name} - response = self.flow_client(request='component/output_data_table', param=post_data) - result = {} - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"deploy model error: {response}") - result["name"] = response["data"][0]["table_name"] - result["namespace"] = response["data"][0]["table_namespace"] - except Exception as e: - raise RuntimeError(f"output data table error: {response}") from e - return result - - def _get_summary(self, job_id, role, party_id, component_name): - 
post_data = {'job_id': job_id, - 'role': role, - 'party_id': party_id, - 'component_name': component_name} - response = self.flow_client(request='component/get_summary', param=post_data) - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - result = {} - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"deploy model error: {response}") - result["summary_dir"] = retmsg # 获取summary文件位置 - except Exception as e: - raise RuntimeError(f"output data table error: {response}") from e - return result - - def _query_job(self, job_id, role): - param = { - 'job_id': job_id, - 'role': role - } - response = QueryJobResponse(self.flow_client(request='job/query', param=param)) - return response - - def get_version(self): - response = self._post(url='version/get', json={"module": "FATE"}) - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"get version error: {response}") - fate_version = response["data"]["FATE"] - except Exception as e: - raise RuntimeError(f"get version error: {response}") from e - return fate_version - - def _add_notes(self, job_id, role, party_id, notes): - data = dict(job_id=job_id, role=role, party_id=party_id, notes=notes) - response = AddNotesResponse(self._post(url='job/update', json=data)) - return response - - def _table_bind(self, data): - response = self._post(url='table/bind', json=data) - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"table bind error: {response}") - except Exception as e: - raise RuntimeError(f"table bind error: {response}") from e - return response - - @property - def _base(self): - return f"http://{self.address}/{self.version}/" - - def _post(self, url, **kwargs) -> dict: - request_url = self._base + url - try: - response = self._http.request(method='post', url=request_url, **kwargs) - except Exception as e: - raise RuntimeError(f"post {url} with {kwargs} failed") from e - - try: - if isinstance(response, requests.models.Response): - response = response.json() - else: - try: - response = json.loads(response.content.decode('utf-8', 'ignore'), strict=False) - except (TypeError, ValueError): - return response - except json.decoder.JSONDecodeError: - response = {'retcode': 100, - 'retmsg': "Internal server error. Nothing in response. 
You may check out the configuration in " - "'FATE/conf/service_conf.yaml' and restart fate flow server."} - return response - - def flow_client(self, request, param, verbose=0, drop=0): - client = FlowClient(self.address.split(':')[0], self.address.split(':')[1], self.version) - if request == 'data/upload': - stdout = client.data.upload(config_data=param, verbose=verbose, drop=drop) - elif request == 'table/delete': - stdout = client.table.delete(table_name=param['table_name'], namespace=param['namespace']) - elif request == 'table/info': - stdout = client.table.info(table_name=param['table_name'], namespace=param['namespace']) - elif request == 'job/submit': - stdout = client.job.submit(config_data=param['job_runtime_conf'], dsl_data=param['job_dsl']) - elif request == 'job/query': - stdout = client.job.query(job_id=param['job_id'], role=param['role']) - elif request == 'model/deploy': - stdout = client.model.deploy(model_id=param['model_id'], model_version=param['model_version'], - predict_dsl=param['predict_dsl']) - elif request == 'component/output_data_table': - stdout = client.component.output_data_table(job_id=param['job_id'], role=param['role'], - party_id=param['party_id'], - component_name=param['component_name']) - elif request == 'component/get_summary': - stdout = client.component.get_summary(job_id=param['job_id'], role=param['role'], - party_id=param['party_id'], - component_name=param['component_name']) - - else: - stdout = {"retcode": None} - - status = stdout["retcode"] - if status != 0: - if request == 'table/delete' and stdout["retmsg"] == "no find table": - return stdout - raise ValueError({'retcode': 100, 'retmsg': stdout["retmsg"]}) - - return stdout - - -class Status(object): - def __init__(self, status: str): - self.status = status - - def is_done(self): - return self.status.lower() in ['complete', 'success', 'canceled', 'failed', "timeout"] - - def is_success(self): - return self.status.lower() in ['complete', 'success'] - - def __str__(self): - return self.status - - def __repr__(self): - return self.__str__() - - -class QueryJobResponse(object): - def __init__(self, response: dict): - try: - status = Status(response.get('data')[0]["f_status"]) - progress = response.get('data')[0]['f_progress'] - except Exception as e: - raise RuntimeError(f"query job error, response: {response}") from e - self.status = status - self.progress = progress - - -class UploadDataResponse(object): - def __init__(self, response: dict): - try: - self.job_id = response["jobId"] - except Exception as e: - raise RuntimeError(f"upload error, response: {response}") from e - self.status: typing.Optional[Status] = None - - -class AddNotesResponse(object): - def __init__(self, response: dict): - try: - retcode = response['retcode'] - retmsg = response['retmsg'] - if retcode != 0 or retmsg != 'success': - raise RuntimeError(f"add notes error: {response}") - except Exception as e: - raise RuntimeError(f"add notes error: {response}") from e - - -class SubmitJobResponse(object): - def __init__(self, response: dict): - try: - self.job_id = response["jobId"] - self.model_info = response["data"]["model_info"] - except Exception as e: - raise RuntimeError(f"submit job error, response: {response}") from e - self.status: typing.Optional[Status] = None - - -class DataProgress(object): - def __init__(self, role_str): - self.role_str = role_str - self.start = time.time() - self.show_str = f"[{self.elapse()}] {self.role_str}" - self.job_id = "" - - def elapse(self): - return 
f"{timedelta(seconds=int(time.time() - self.start))}" - - def submitted(self, job_id): - self.job_id = job_id - self.show_str = f"[{self.elapse()}]{self.job_id} {self.role_str}" - - def update(self): - self.show_str = f"[{self.elapse()}]{self.job_id} {self.role_str}" - - def show(self): - return self.show_str - - -class JobProgress(object): - def __init__(self, name): - self.name = name - self.start = time.time() - self.show_str = f"[{self.elapse()}] {self.name}" - self.job_id = "" - self.progress_tracking = "" - - def elapse(self): - return f"{timedelta(seconds=int(time.time() - self.start))}" - - def set_progress_tracking(self, progress_tracking): - self.progress_tracking = progress_tracking + " " - - def submitted(self, job_id): - self.job_id = job_id - self.show_str = f"{self.progress_tracking}[{self.elapse()}]{self.job_id} submitted {self.name}" - - def running(self, status, progress): - if progress is None: - progress = 0 - self.show_str = f"{self.progress_tracking}[{self.elapse()}]{self.job_id} {status} {progress:3}% {self.name}" - - def exception(self, exception_id): - self.show_str = f"{self.progress_tracking}[{self.elapse()}]{self.name} exception({exception_id}): {self.job_id}" - - def final(self, status): - self.show_str = f"{self.progress_tracking}[{self.elapse()}]{self.job_id} {status} {self.name}" - - def show(self): - return self.show_str diff --git a/python/fate_test/fate_test/_io.py b/python/fate_test/fate_test/_io.py deleted file mode 100644 index edfaeee964..0000000000 --- a/python/fate_test/fate_test/_io.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import click -import loguru - -from fate_test._ascii import HEAD, TAIL, BENCHMARK - - -# noinspection PyPep8Naming -class echo(object): - _file = None - - @classmethod - def set_file(cls, file): - cls._file = file - - @classmethod - def echo(cls, message, **kwargs): - click.secho(message, **kwargs) - click.secho(message, file=cls._file, **kwargs) - - @classmethod - def file(cls, message, **kwargs): - click.secho(message, file=cls._file, **kwargs) - - @classmethod - def stdout(cls, message, **kwargs): - click.secho(message, **kwargs) - - @classmethod - def stdout_newline(cls): - click.secho("") - - @classmethod - def welcome(cls, banner_type="testsuite"): - if banner_type == "testsuite": - cls.echo(HEAD) - elif banner_type == "benchmark": - cls.echo(BENCHMARK) - - @classmethod - def farewell(cls): - cls.echo(TAIL) - - @classmethod - def flush(cls): - import sys - sys.stdout.flush() - - -def set_logger(name): - loguru.logger.remove() - loguru.logger.add(name, level='ERROR', delay=True) - return loguru.logger - - -LOGGER = loguru.logger diff --git a/python/fate_test/fate_test/_parser.py b/python/fate_test/fate_test/_parser.py deleted file mode 100644 index 01276fa932..0000000000 --- a/python/fate_test/fate_test/_parser.py +++ /dev/null @@ -1,577 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import json -import typing -from collections import deque -from pathlib import Path -import click -import prettytable - -from fate_test import _config -from fate_test._io import echo -from fate_test._config import Parties, Config -from fate_test.utils import TxtStyle - - -# noinspection PyPep8Naming -class chain_hook(object): - def __init__(self): - self._hooks = [] - - def add_hook(self, hook): - self._hooks.append(hook) - return self - - def add_extend_namespace_hook(self, namespace): - self.add_hook(_namespace_hook(namespace)) - return self - - def add_replace_hook(self, mapping): - self.add_hook(_replace_hook(mapping)) - - def hook(self, d): - return self._chain_hooks(self._hooks, d) - - @staticmethod - def _chain_hooks(hook_funcs, d): - for hook_func in hook_funcs: - if d is None: - return - d = hook_func(d) - return d - - -DATA_JSON_HOOK = chain_hook() -CONF_JSON_HOOK = chain_hook() -DSL_JSON_HOOK = chain_hook() - - -class Data(object): - def __init__(self, config: dict, role_str: str): - self.config = config - self.role_str = role_str - - @staticmethod - def load(config, path: Path): - kwargs = {} - for field_name in config.keys(): - if field_name not in ["file", "role"]: - kwargs[field_name] = config[field_name] - if config.get("engine", {}) != "PATH": - file_path = path.parent.joinpath(config["file"]).resolve() - if not file_path.exists(): - kwargs["file"] = config["file"] - else: - kwargs["file"] = file_path - role_str = config.get("role") if config.get("role") != "guest" else "guest_0" - return Data(config=kwargs, role_str=role_str) - - def update(self, config: Config): - self.config.update(dict(extend_sid=config.extend_sid, - auto_increasing_sid=config.auto_increasing_sid)) - - -class JobConf(object): - def __init__(self, initiator: dict, role: dict, job_parameters=None, **kwargs): - self.initiator = initiator - self.role = role - self.job_parameters = job_parameters if job_parameters else {} - self.others_kwargs = kwargs - - def as_dict(self): - return dict( - initiator=self.initiator, - role=self.role, - job_parameters=self.job_parameters, - **self.others_kwargs, - ) - - @staticmethod - def load(path: Path): - with path.open("r") as f: - kwargs = json.load(f, object_hook=CONF_JSON_HOOK.hook) - return JobConf(**kwargs) - - @property - def dsl_version(self): - return self.others_kwargs.get("dsl_version", 1) - - def update( - self, - parties: Parties, - timeout, - job_parameters, - component_parameters, - ): - self.initiator = parties.extract_initiator_role(self.initiator["role"]) - self.role = parties.extract_role( - {role: len(parties) for role, parties in self.role.items()} - ) - if timeout > 0: - self.update_job_common_parameters(timeout=timeout) - - if timeout > 0: - self.update_job_common_parameters(timeout=timeout) - - for key, value in job_parameters.items(): - self.update_parameters(parameters=self.job_parameters, key=key, value=value) - for key, value in component_parameters.items(): - if self.dsl_version == 1: - 
self.update_parameters( - parameters=self.others_kwargs.get("algorithm_parameters"), - key=key, - value=value, - ) - else: - self.update_parameters( - parameters=self.others_kwargs.get("component_parameters"), - key=key, - value=value, - ) - - def update_parameters(self, parameters, key, value): - if isinstance(parameters, dict): - for keys in parameters: - if keys == key: - parameters.get(key).update(value), - elif isinstance(parameters[keys], dict): - self.update_parameters(parameters[keys], key, value) - - def update_job_common_parameters(self, **kwargs): - if self.dsl_version == 1: - self.job_parameters.update(**kwargs) - else: - self.job_parameters.setdefault("common", {}).update(**kwargs) - - def update_job_type(self, job_type="predict"): - if self.dsl_version == 1: - if self.job_parameters.get("job_type", None) is None: - self.job_parameters.update({"job_type": job_type}) - else: - if self.job_parameters.setdefault("common", {}).get("job_type", None) is None: - self.job_parameters.setdefault("common", {}).update({"job_type": job_type}) - - def update_component_parameters(self, key, value, parameters=None): - if parameters is None: - if self.dsl_version == 1: - parameters = self.others_kwargs.get("algorithm_parameters") - else: - parameters = self.others_kwargs.get("component_parameters") - if isinstance(parameters, dict): - for keys in parameters: - if keys == key: - if isinstance(value, dict): - parameters[keys].update(value) - else: - parameters.update({key: value}) - elif ( - isinstance(parameters[keys], dict) and parameters[keys] is not None - ): - self.update_component_parameters(key, value, parameters[keys]) - - def get_component_parameters(self, keys): - if len(keys) == 0: - return self.others_kwargs.get("component_parameters") if self.dsl_version == 2 else self.others_kwargs.get( - "role_parameters") - if self.dsl_version == 1: - parameters = self.others_kwargs.get("role_parameters") - else: - parameters = self.others_kwargs.get("component_parameters").get("role") - - for key in keys: - parameters = parameters[key] - return parameters - - -class JobDSL(object): - def __init__(self, components: dict, provider=None): - self.components = components - self.provider = provider - - @staticmethod - def load(path: Path, provider): - with path.open("r") as f: - kwargs = json.load(f, object_hook=DSL_JSON_HOOK.hook) - if provider is not None: - kwargs["provider"] = provider - return JobDSL(**kwargs) - - def as_dict(self): - if self.provider is None: - return dict(components=self.components) - else: - return dict(components=self.components, provider=self.provider) - - -class Job(object): - def __init__( - self, - job_name: str, - job_conf: JobConf, - job_dsl: typing.Optional[JobDSL], - pre_works: list, - ): - self.job_name = job_name - self.job_conf = job_conf - self.job_dsl = job_dsl - self.pre_works = pre_works - - @classmethod - def load(cls, job_name, job_configs, base: Path, provider): - job_conf = JobConf.load(base.joinpath(job_configs.get("conf")).resolve()) - job_dsl = job_configs.get("dsl", None) - if job_dsl is not None: - job_dsl = JobDSL.load(base.joinpath(job_dsl).resolve(), provider) - - pre_works = [] - pre_works_value = {} - deps_dict = {} - - if job_configs.get("model_deps", None): - pre_works.append(job_configs["model_deps"]) - deps_dict["model_deps"] = {'name': job_configs["model_deps"]} - elif job_configs.get("deps", None): - pre_works.append(job_configs["deps"]) - deps_dict["model_deps"] = {'name': job_configs["deps"]} - if job_configs.get("data_deps", None): - 
deps_dict["data_deps"] = {'data': job_configs["data_deps"]} - pre_works.append(list(job_configs["data_deps"].keys())[0]) - deps_dict["data_deps"].update({'name': list(job_configs["data_deps"].keys())}) - if job_configs.get("cache_deps", None): - pre_works.append(job_configs["cache_deps"]) - deps_dict["cache_deps"] = {'name': job_configs["cache_deps"]} - if job_configs.get("model_loader_deps", None): - pre_works.append(job_configs["model_loader_deps"]) - deps_dict["model_loader_deps"] = {'name': job_configs["model_loader_deps"]} - - pre_works_value.update(deps_dict) - _config.deps_alter[job_name] = pre_works_value - - return Job( - job_name=job_name, job_conf=job_conf, job_dsl=job_dsl, pre_works=pre_works - ) - - @property - def submit_params(self): - return dict( - conf=self.job_conf.as_dict(), - dsl=self.job_dsl.as_dict() if self.job_dsl else None, - ) - - def set_pre_work(self, name, **kwargs): - self.job_conf.update_job_common_parameters(**kwargs) - self.job_conf.update_job_type("predict") - - def set_input_data(self, hierarchys, table_info): - for table_name, hierarchy in zip(table_info, hierarchys): - key = list(table_name.keys())[0] - value = table_name[key] - self.job_conf.update_component_parameters( - key=key, - value=value, - parameters=self.job_conf.get_component_parameters(hierarchy), - ) - - def is_submit_ready(self): - return len(self.pre_works) == 0 - - -class PipelineJob(object): - def __init__(self, job_name: str, script_path: Path): - self.job_name = job_name - self.script_path = script_path - - -class Testsuite(object): - def __init__( - self, - dataset: typing.List[Data], - jobs: typing.List[Job], - pipeline_jobs: typing.List[PipelineJob], - path: Path, - ): - self.dataset = dataset - self.jobs = jobs - self.pipeline_jobs = pipeline_jobs - self.path = path - self.suite_name = Path(self.path).stem - - self._dependency: typing.MutableMapping[str, typing.List[Job]] = {} - self._final_status: typing.MutableMapping[str, FinalStatus] = {} - self._ready_jobs = deque() - for job in self.jobs: - for name in job.pre_works: - self._dependency.setdefault(name, []).append(job) - - self._final_status[job.job_name] = FinalStatus(job.job_name) - if job.is_submit_ready(): - self._ready_jobs.appendleft(job) - - for job in self.pipeline_jobs: - self._final_status[job.job_name] = FinalStatus(job.job_name) - - @staticmethod - def load(path: Path, provider): - with path.open("r") as f: - testsuite_config = json.load(f, object_hook=DATA_JSON_HOOK.hook) - - dataset = [] - for d in testsuite_config.get("data"): - if "use_local_data" not in d: - d.update({"use_local_data": _config.use_local_data}) - dataset.append(Data.load(d, path)) - jobs = [] - for job_name, job_configs in testsuite_config.get("tasks", {}).items(): - jobs.append( - Job.load(job_name=job_name, job_configs=job_configs, base=path.parent, provider=provider) - ) - - pipeline_jobs = [] - if testsuite_config.get("pipeline_tasks", None) is not None and provider is not None: - echo.echo('[Warning] Pipeline does not support parameter: provider-> {}'.format(provider)) - for job_name, job_configs in testsuite_config.get("pipeline_tasks", {}).items(): - script_path = path.parent.joinpath(job_configs["script"]).resolve() - pipeline_jobs.append(PipelineJob(job_name, script_path)) - - testsuite = Testsuite(dataset, jobs, pipeline_jobs, path) - return testsuite - - def jobs_iter(self) -> typing.Generator[Job, None, None]: - while self._ready_jobs: - yield self._ready_jobs.pop() - - @staticmethod - def style_table(txt): - colored_txt = 
txt.replace("success", f"{TxtStyle.TRUE_VAL}success{TxtStyle.END}") - colored_txt = colored_txt.replace("failed", f"{TxtStyle.FALSE_VAL}failed{TxtStyle.END}") - colored_txt = colored_txt.replace("not submitted", f"{TxtStyle.FALSE_VAL}not submitted{TxtStyle.END}") - return colored_txt - - def pretty_final_summary(self, time_consuming, suite_file=None): - """table = prettytable.PrettyTable( - ["job_name", "job_id", "status", "time_consuming", "exception_id", "rest_dependency"] - )""" - table = prettytable.PrettyTable() - table.set_style(prettytable.ORGMODE) - field_names = ["job_name", "job_id", "status", "time_consuming", "exception_id", "rest_dependency"] - table.field_names = field_names - for status in self.get_final_status().values(): - if status.status != "success": - status.suite_file = suite_file - _config.non_success_jobs.append(status) - if status.exception_id != "-": - exception_id_txt = f"{TxtStyle.FALSE_VAL}{status.exception_id}{TxtStyle.END}" - else: - exception_id_txt = f"{TxtStyle.FIELD_VAL}{status.exception_id}{TxtStyle.END}" - table.add_row( - [ - f"{TxtStyle.FIELD_VAL}{status.name}{TxtStyle.END}", - f"{TxtStyle.FIELD_VAL}{status.job_id}{TxtStyle.END}", - self.style_table(status.status), - f"{TxtStyle.FIELD_VAL}{time_consuming.pop(0) if status.job_id != '-' else '-'}{TxtStyle.END}", - f"{exception_id_txt}", - f"{TxtStyle.FIELD_VAL}{','.join(status.rest_dependency)}{TxtStyle.END}", - ] - ) - - return table.get_string(title=f"{TxtStyle.TITLE}Testsuite Summary: {self.suite_name}{TxtStyle.END}") - - def model_in_dep(self, name): - return name in self._dependency - - def get_dependent_jobs(self, name): - return self._dependency[name] - - def remove_dependency(self, name): - del self._dependency[name] - - def feed_dep_info(self, job, name, model_info=None, table_info=None, cache_info=None, model_loader_info=None): - if model_info is not None: - job.set_pre_work(name, **model_info) - if table_info is not None: - job.set_input_data(table_info["hierarchy"], table_info["table_info"]) - if cache_info is not None: - job.set_input_data(cache_info["hierarchy"], cache_info["cache_info"]) - if model_loader_info is not None: - job.set_input_data(model_loader_info["hierarchy"], model_loader_info["model_loader_info"]) - if name in job.pre_works: - job.pre_works.remove(name) - if job.is_submit_ready(): - self._ready_jobs.appendleft(job) - - def reflash_configs(self, config: Config): - failed = [] - for job in self.jobs: - try: - job.job_conf.update( - config.parties, None, {}, {} - ) - except ValueError as e: - failed.append((job, e)) - return failed - - def update_status( - self, job_name, job_id: str = None, status: str = None, exception_id: str = None - ): - for k, v in locals().items(): - if k != "job_name" and v is not None: - setattr(self._final_status[job_name], k, v) - - def get_final_status(self): - for name, jobs in self._dependency.items(): - for job in jobs: - self._final_status[job.job_name].rest_dependency.append(name) - return self._final_status - - -class FinalStatus(object): - def __init__( - self, - name: str, - job_id: str = "-", - status: str = "not submitted", - exception_id: str = "-", - rest_dependency: typing.List[str] = None, - ): - self.name = name - self.job_id = job_id - self.status = status - self.exception_id = exception_id - self.rest_dependency = rest_dependency or [] - self.suite_file = None - - -class BenchmarkJob(object): - def __init__(self, job_name: str, script_path: Path, conf_path: Path): - self.job_name = job_name - self.script_path = script_path - 
self.conf_path = conf_path - - -class BenchmarkPair(object): - def __init__( - self, pair_name: str, jobs: typing.List[BenchmarkJob], compare_setting: dict - ): - self.pair_name = pair_name - self.jobs = jobs - self.compare_setting = compare_setting - - -class BenchmarkSuite(object): - def __init__( - self, dataset: typing.List[Data], pairs: typing.List[BenchmarkPair], path: Path - ): - self.dataset = dataset - self.pairs = pairs - self.path = path - - @staticmethod - def load(path: Path): - with path.open("r") as f: - testsuite_config = json.load(f, object_hook=DATA_JSON_HOOK.hook) - - dataset = [] - for d in testsuite_config.get("data"): - dataset.append(Data.load(d, path)) - - pairs = [] - for pair_name, pair_configs in testsuite_config.items(): - if pair_name == "data": - continue - jobs = [] - for job_name, job_configs in pair_configs.items(): - if job_name == "compare_setting": - continue - script_path = path.parent.joinpath(job_configs["script"]).resolve() - if job_configs.get("conf"): - conf_path = path.parent.joinpath(job_configs["conf"]).resolve() - else: - conf_path = "" - jobs.append( - BenchmarkJob( - job_name=job_name, script_path=script_path, conf_path=conf_path - ) - ) - compare_setting = pair_configs.get("compare_setting") - if compare_setting and not isinstance(compare_setting, dict): - raise ValueError( - f"expected 'compare_setting' type is dict, received {type(compare_setting)} instead." - ) - pairs.append( - BenchmarkPair( - pair_name=pair_name, jobs=jobs, compare_setting=compare_setting - ) - ) - suite = BenchmarkSuite(dataset=dataset, pairs=pairs, path=path) - return suite - - -def non_success_summary(): - status = {} - for job in _config.non_success_jobs: - if job.status not in status.keys(): - status[job.status] = prettytable.PrettyTable( - ["testsuite_name", "job_name", "job_id", "status", "exception_id", "rest_dependency"] - ) - - status[job.status].add_row( - [ - job.suite_file, - job.name, - job.job_id, - job.status, - job.exception_id, - ",".join(job.rest_dependency), - ] - ) - for k, v in status.items(): - echo.echo("\n" + "#" * 60) - echo.echo(v.get_string(title=f"{k} job record"), fg='red') - - -def _namespace_hook(namespace): - def _hook(d): - if d is None: - return d - if "namespace" in d and namespace: - d["namespace"] = f"{d['namespace']}_{namespace}" - return d - - return _hook - - -def _replace_hook(mapping: dict): - def _hook(d): - for k, v in mapping.items(): - if k in d: - d[k] = v - return d - - return _hook - - -class JsonParamType(click.ParamType): - name = "json_string" - - def convert(self, value, param, ctx): - try: - return json.loads(value) - except ValueError: - self.fail(f"{value} is not a valid json string", param, ctx) - - -JSON_STRING = JsonParamType() diff --git a/python/fate_test/fate_test/fate_test_config.yaml b/python/fate_test/fate_test/fate_test_config.yaml deleted file mode 100644 index f589259581..0000000000 --- a/python/fate_test/fate_test/fate_test_config.yaml +++ /dev/null @@ -1,61 +0,0 @@ -# base dir for data upload conf eg, data_base_dir={FATE} -# examples/data/breast_hetero_guest.csv -> $data_base_dir/examples/data/breast_hetero_guest.csv -data_base_dir: path(FATE) - -# directory dedicated to fate_test job file storage, default cache location={FATE}/examples/cache/ -cache_directory: examples/cache/ -# directory stores performance benchmark suites, default location={FATE}/examples/benchmark_performance -performance_template_directory: examples/benchmark_performance/ -# directory stores flow test config, default 
location={FATE}/examples/flow_test_template/hetero_lr/flow_test_config.yaml -flow_test_config_directory: examples/flow_test_template/hetero_lr/flow_test_config.yaml - -# directory stores testsuite file with min_test data sets to upload, -# default location={FATE}/examples/data/upload_config/min_test_data_testsuite.json -min_test_data_config: examples/data/upload_config/min_test_data_testsuite.json -# directory stores testsuite file with all example data sets to upload, -# default location={FATE}/examples/data/upload_config/all_examples_data_testsuite.json -all_examples_data_config: examples/data/upload_config/all_examples_data_testsuite.json - -# directory where FATE code locates, default installation location={FATE}/fate -# python/federatedml -> $fate_base/python/federatedml -fate_base: path(FATE)/fate - -# whether to delete data in suites after all jobs done -clean_data: true - -# participating parties' id and correponding flow service ip & port information -parties: - guest: [9999] - host: [10000, 9999] - arbiter: [10000] -services: - - flow_services: - - {address: 127.0.0.1:9380, parties: [9999, 10000]} - serving_setting: - address: 127.0.0.1:8059 - - ssh_tunnel: # optional - enable: false - ssh_address: : - ssh_username: - ssh_password: # optional - ssh_priv_key: "~/.ssh/id_rsa" - - -# what is ssh_tunnel? -# to open the ssh tunnel(s) if the remote service -# cannot be accessed directly from the location where the test suite is run! -# -# +---------------------+ -# | ssh address | -# | ssh username | -# | ssh password/ | -# +--------+ | ssh priv_key | +----------------+ -# |local ip+----------ssh tuunel-------------->+remote local ip | -# +--------+ | | +----------------+ -# | | -# request local ip:port +----- as if --------->request remote's local ip:port from remote side -# | | -# | | -# +---------------------+ -# diff --git a/python/fate_test/fate_test/flow_test/__init__.py b/python/fate_test/fate_test/flow_test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate_test/fate_test/flow_test/flow_cli_api.py b/python/fate_test/fate_test/flow_test/flow_cli_api.py deleted file mode 100644 index fe6d17eb11..0000000000 --- a/python/fate_test/fate_test/flow_test/flow_cli_api.py +++ /dev/null @@ -1,668 +0,0 @@ -import json -import os -import sys -import shutil -import time -import subprocess -import numpy as np -from pathlib import Path - -from prettytable import PrettyTable, ORGMODE -from fate_test.flow_test.flow_process import get_dict_from_file, serving_connect - - -class TestModel(object): - def __init__(self, data_base_dir, fate_flow_path, component_name, namespace): - self.conf_path = None - self.dsl_path = None - self.job_id = None - self.model_id = None - self.model_version = None - self.guest_party_id = None - self.host_party_id = None - self.arbiter_party_id = None - self.output_path = None - self.cache_directory = None - - self.data_base_dir = data_base_dir - self.fate_flow_path = fate_flow_path - self.component_name = component_name - - self.python_bin = sys.executable or 'python3' - - self.request_api_info_path = f'./logs/{namespace}/cli_exception.log' - os.makedirs(os.path.dirname(self.request_api_info_path), exist_ok=True) - - def error_log(self, retmsg): - if retmsg is None: - return os.path.abspath(self.request_api_info_path) - with open(self.request_api_info_path, "a") as f: - f.write(retmsg) - - def submit_job(self, stop=True): - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "submit_job", "-d", 
self.dsl_path, - "-c", self.conf_path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('job submit: {}'.format(stdout.get('retmsg')) + '\n') - self.job_id = stdout.get("jobId") - self.model_id = stdout.get("data").get("model_info").get("model_id") - self.model_version = stdout.get("data").get("model_info").get("model_version") - if stop: - return - return self.query_status() - except Exception: - return - - def job_api(self, command): - if command == 'stop_job': - self.submit_job() - time.sleep(5) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('job stop: {}'.format(stdout.get('retmsg')) + '\n') - if self.query_job() == "canceled": - return stdout.get('retcode') - except Exception: - return - - elif command == 'job_log_download': - log_file_dir = os.path.join(self.output_path, 'job_{}_log'.format(self.job_id)) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, "-o", - log_file_dir], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('job log: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'data_view_query': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, - "-r", "guest"], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('data view queue: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get("data")) == len(list(get_dict_from_file(self.dsl_path)['components'].keys())) - 1: - return stdout.get('retcode') - except Exception: - return - - elif command == 'clean_job': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('clean job: {}'.format(stdout.get('retmsg')) + '\n') - subp = subprocess.Popen([self.python_bin, - self.fate_flow_path, - "-f", - "component_metrics", - "-j", - self.job_id, - "-r", - "guest", - "-p", - str(self.guest_party_id[0]), - "-cpn", - 'evaluation_0'], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - metric, stderr = subp.communicate() - metric = json.loads(metric.decode("utf-8")) - if not metric.get('data'): - return stdout.get('retcode') - except Exception: - return - - elif command == 'clean_queue': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('clean queue: {}'.format(stdout.get('retmsg')) + '\n') - if not self.query_job(queue=True): - return stdout.get('retcode') - except Exception: - return - - def query_job(self, job_id=None, queue=False): - if job_id is None: - job_id = self.job_id - time.sleep(1) 
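The TestModel methods in this removed flow_cli_api.py all follow the same pattern: shell out to fate_flow_client.py with subprocess, decode the JSON printed on stdout, and treat a non-zero "retcode" as failure. A condensed sketch of that pattern is shown below; the helper name run_flow_cli is illustrative and does not exist in the original module.

import json
import subprocess
import sys


def run_flow_cli(fate_flow_path, *args):
    """Invoke fate_flow_client.py and return its decoded JSON response."""
    cmd = [sys.executable or "python3", fate_flow_path, *args]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False)
    response = json.loads(proc.stdout.decode("utf-8"))
    if response.get("retcode"):
        # The original methods record retmsg via error_log() instead of raising.
        raise RuntimeError(response.get("retmsg"))
    return response

# Example (hypothetical job id): roughly the equivalent of query_job() above.
# run_flow_cli("fate_flow_client.py", "-f", "query_job", "-j", "20220101000000000000")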
- try: - if not queue: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "query_job", "-j", job_id], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if not stdout.get('retcode'): - return stdout.get("data")[0].get("f_status") - else: - self.error_log('query job: {}'.format(stdout.get('retmsg')) + '\n') - else: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "query_job", "-j", job_id, "-s", - "waiting"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if not stdout.get('retcode'): - return len(stdout.get("data")) - except Exception: - return - - def job_config(self, max_iter): - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "job_config", "-j", self.job_id, "-r", - "guest", "-p", str(self.guest_party_id[0]), "-o", self.output_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('job config: {}'.format(stdout.get('retmsg')) + '\n') - job_conf_path = stdout.get('directory') + '/runtime_conf.json' - job_conf = get_dict_from_file(job_conf_path) - if max_iter == job_conf['component_parameters']['common'][self.component_name]['max_iter']: - return stdout.get('retcode') - - except Exception: - return - - def query_task(self): - try: - subp = subprocess.Popen( - [self.python_bin, self.fate_flow_path, "-f", "query_task", "-j", self.job_id, "-r", "guest", - "-p", str(self.guest_party_id[0]), "-cpn", self.component_name], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('task query: {}'.format(stdout.get('retmsg')) + '\n') - status = stdout.get("data")[0].get("f_status") - if status == "success": - return stdout.get('retcode') - except Exception: - return - - def component_api(self, command, max_iter=None): - component_output_path = os.path.join(self.output_path, 'job_{}_output_data'.format(self.job_id)) - if command == 'component_output_data': - try: - subp = subprocess.Popen( - [self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, "-r", - "guest", "-p", str(self.guest_party_id[0]), "-cpn", self.component_name, "-o", - component_output_path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component output data: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'component_output_data_table': - try: - subp = subprocess.Popen( - [self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, "-r", - "guest", "-p", str(self.guest_party_id[0]), "-cpn", self.component_name], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component output data table: {}'.format(stdout.get('retmsg')) + '\n') - table = {'table_name': stdout.get("data")[0].get("table_name"), - 'namespace': stdout.get("data")[0].get("namespace")} - if not self.table_api('table_info', table): - return stdout.get('retcode') - except Exception: - return - - elif command == 
'component_output_model': - try: - subp = subprocess.Popen([self.python_bin, - self.fate_flow_path, - "-f", - command, - "-r", - "guest", - "-j", - self.job_id, - "-p", - str(self.guest_party_id[0]), - "-cpn", - self.component_name], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component output model: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - return stdout.get('retcode') - except Exception: - return - - elif command == 'component_parameters': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, - "-r", "guest", "-p", str(self.guest_party_id[0]), "-cpn", self.component_name], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component parameters: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data', {}).get('ComponentParam', {}).get('max_iter', {}) == max_iter: - return stdout.get('retcode') - except Exception: - return - - elif command == 'component_metrics': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, - "-r", "guest", "-p", str(self.guest_party_id[0]), "-cpn", 'evaluation_0'], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component metrics: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - metrics_file = self.output_path + '{}_metrics.json'.format(self.job_id) - with open(metrics_file, 'w') as fp: - json.dump(stdout.get("data"), fp) - return stdout.get('retcode') - except Exception: - return - - elif command == 'component_metric_all': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", self.job_id, - "-r", "guest", "-p", str(self.guest_party_id[0]), "-cpn", 'evaluation_0'], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component metric all: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - metric_all_file = self.output_path + '{}_metric_all.json'.format(self.job_id) - with open(metric_all_file, 'w') as fp: - json.dump(stdout.get("data"), fp) - return stdout.get('retcode') - except Exception: - return - - elif command == 'component_metric_delete': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-j", - self.job_id, "-r", "guest", "-p", str(self.guest_party_id[0]), "-cpn", - 'evaluation_0'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('component metric delete: {}'.format(stdout.get('retmsg')) + '\n') - subp = subprocess.Popen([self.python_bin, - self.fate_flow_path, - "-f", - "component_metrics", - "-j", - self.job_id, - "-r", - "guest", - "-p", - str(self.guest_party_id[0]), - "-cpn", - 'evaluation_0'], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - metric, stderr = subp.communicate() - metric = json.loads(metric.decode("utf-8")) - if not metric.get('data'): - return stdout.get('retcode') - except Exception: - return - - def table_api(self, command, 
table_name): - if command == 'table_info': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-t", - table_name['table_name'], "-n", table_name['namespace']], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('table info: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data')['namespace'] == table_name['namespace'] and \ - stdout.get('data')['table_name'] == table_name['table_name']: - return stdout.get('retcode') - except Exception: - return - - elif command == 'table_delete': - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-t", - table_name['table_name'], "-n", table_name['namespace']], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('table delete: {}'.format(stdout.get('retmsg')) + '\n') - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "table_delete", "-t", - table_name['table_name'], "-n", table_name['namespace']], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - return 0 - except Exception: - return - - def data_upload(self, upload_path, table_index=None): - upload_file = get_dict_from_file(upload_path) - upload_file['file'] = str(self.data_base_dir.joinpath(upload_file['file']).resolve()) - upload_file['drop'] = 1 - upload_file['use_local_data'] = 0 - if table_index is not None: - upload_file['table_name'] = f'{upload_file["file"]}_{table_index}' - - upload_path = self.cache_directory + 'upload_file.json' - with open(upload_path, 'w') as fp: - json.dump(upload_file, fp) - - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "upload", "-c", - upload_path, "-drop", "1"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('data upload: {}'.format(stdout.get('retmsg')) + '\n') - return self.query_status(stdout.get("jobId")) - except Exception: - return - - def data_download(self, table_name, output_path): - download_config = { - "table_name": table_name['table_name'], - "namespace": table_name['namespace'], - "output_path": output_path + '{}download.csv'.format(self.job_id) - } - config_file_path = self.cache_directory + 'download_config.json' - with open(config_file_path, 'w') as fp: - json.dump(download_config, fp) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "download", "-c", config_file_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('data download: {}'.format(stdout.get('retmsg')) + '\n') - return self.query_status(stdout.get("jobId")) - except Exception: - return - - def data_upload_history(self, conf_file): - self.data_upload(conf_file, table_index=1) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", "upload_history", "-limit", "2"], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('data upload history: 
{}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get('data')) == 2: - return stdout.get('retcode') - except Exception: - return - - def model_api(self, command, remove_path=None, model_path=None, model_load_conf=None, servings=None): - if model_load_conf is not None: - model_load_conf["job_parameters"].update({"model_id": self.model_id, - "model_version": self.model_version}) - - if command == 'load': - model_load_path = self.cache_directory + 'model_load_file.json' - with open(model_load_path, 'w') as fp: - json.dump(model_load_conf, fp) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-c", model_load_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('model load: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'bind': - service_id = "".join([str(i) for i in np.random.randint(9, size=8)]) - model_load_conf.update({"service_id": service_id, "servings": [servings]}) - model_bind_path = self.cache_directory + 'model_load_file.json' - with open(model_bind_path, 'w') as fp: - json.dump(model_load_conf, fp) - try: - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-c", model_bind_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('model bind: {}'.format(stdout.get('retmsg')) + '\n') - else: - return stdout.get('retcode') - except Exception: - return - - elif command == 'import': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "file": model_path, - "force_update": 1, - } - - config_file_path = self.cache_directory + 'model_import.json' - with open(config_file_path, 'w') as fp: - json.dump(config_data, fp) - try: - remove_path = Path(remove_path + self.model_version) - if os.path.isdir(remove_path): - shutil.rmtree(remove_path) - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-c", config_file_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if not stdout.get('retcode') and os.path.isdir(remove_path): - return 0 - else: - self.error_log('model import: {}'.format(stdout.get('retmsg')) + '\n') - except Exception: - return - - elif command == 'export': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0] - } - config_file_path = self.cache_directory + 'model_export.json' - with open(config_file_path, 'w') as fp: - json.dump(config_data, fp) - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-c", config_file_path, "-o", - self.output_path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('model export: {}'.format(stdout.get('retmsg')) + '\n') - else: - export_model_path = stdout.get('file') - return stdout.get('retcode'), export_model_path - - elif command in ['store', 'restore']: - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0] 
- } - config_file_path = self.cache_directory + 'model_store.json' - with open(config_file_path, 'w') as fp: - json.dump(config_data, fp) - - subp = subprocess.Popen([self.python_bin, self.fate_flow_path, "-f", command, "-c", config_file_path], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = json.loads(stdout.decode("utf-8")) - if stdout.get('retcode'): - self.error_log('model {}: {}'.format(command, stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - - def query_status(self, job_id=None): - while True: - time.sleep(5) - status = self.query_job(job_id=job_id) - if status and status in ["waiting", "running", "success"]: - if status and status == "success": - return 0 - else: - return - - def set_config(self, guest_party_id, host_party_id, arbiter_party_id, path, component_name): - config = get_dict_from_file(path) - config["initiator"]["party_id"] = guest_party_id[0] - config["role"]["guest"] = guest_party_id - config["role"]["host"] = host_party_id - if "arbiter" in config["role"]: - config["role"]["arbiter"] = arbiter_party_id - self.guest_party_id = guest_party_id - self.host_party_id = host_party_id - self.arbiter_party_id = arbiter_party_id - conf_file_path = self.cache_directory + 'conf_file.json' - with open(conf_file_path, 'w') as fp: - json.dump(config, fp) - self.conf_path = conf_file_path - return config['component_parameters']['common'][component_name]['max_iter'] - - -def judging_state(retcode): - if not retcode and retcode is not None: - return 'success' - else: - return 'failed' - - -def run_test_api(config_json, namespace): - output_path = './output/flow_test_data/' - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - fate_flow_path = config_json['data_base_dir'] / 'fateflow' / 'python' / 'fate_flow' / 'fate_flow_client.py' - if not fate_flow_path.exists(): - raise FileNotFoundError(f'fate_flow not found. 
filepath: {fate_flow_path}') - test_api = TestModel(config_json['data_base_dir'], str(fate_flow_path), config_json['component_name'], namespace) - test_api.dsl_path = config_json['train_dsl_path'] - test_api.cache_directory = config_json['cache_directory'] - test_api.output_path = str(os.path.abspath(output_path)) + '/' - - conf_path = config_json['train_conf_path'] - guest_party_id = config_json['guest_party_id'] - host_party_id = config_json['host_party_id'] - arbiter_party_id = config_json['arbiter_party_id'] - upload_file_path = config_json['upload_file_path'] - model_file_path = config_json['model_file_path'] - conf_file = get_dict_from_file(upload_file_path) - - serving_connect_bool = serving_connect(config_json['serving_setting']) - remove_path = str(config_json['data_base_dir']).split("python")[ - 0] + '/fateflow/model_local_cache/guest#{}#arbiter-{}#guest-{}#host-{}#model/'.format( - guest_party_id[0], arbiter_party_id[0], guest_party_id[0], host_party_id[0]) - max_iter = test_api.set_config(guest_party_id, host_party_id, arbiter_party_id, conf_path, - config_json['component_name']) - - data = PrettyTable() - data.set_style(ORGMODE) - data.field_names = ['data api name', 'status'] - data.add_row(['data upload', judging_state(test_api.data_upload(upload_file_path))]) - data.add_row(['data download', judging_state(test_api.data_download(conf_file, output_path))]) - data.add_row( - ['data upload history', judging_state(test_api.data_upload_history(upload_file_path))]) - print(data.get_string(title="data api")) - - table = PrettyTable() - table.set_style(ORGMODE) - table.field_names = ['table api name', 'status'] - table.add_row(['table info', judging_state(test_api.table_api('table_info', conf_file))]) - table.add_row(['delete table', judging_state(test_api.table_api('table_delete', conf_file))]) - print(table.get_string(title="table api")) - - job = PrettyTable() - job.set_style(ORGMODE) - job.field_names = ['job api name', 'status'] - job.add_row(['job stop', judging_state(test_api.job_api('stop_job'))]) - job.add_row(['job submit', judging_state(test_api.submit_job(stop=False))]) - job.add_row(['job query', judging_state(False if test_api.query_job() == "success" else True)]) - job.add_row(['job data view', judging_state(test_api.job_api('data_view_query'))]) - job.add_row(['job config', judging_state(test_api.job_config(max_iter=max_iter))]) - job.add_row(['job log', judging_state(test_api.job_api('job_log_download'))]) - - task = PrettyTable() - task.set_style(ORGMODE) - task.field_names = ['task api name', 'status'] - task.add_row(['task query', judging_state(test_api.query_task())]) - print(task.get_string(title="task api")) - - component = PrettyTable() - component.set_style(ORGMODE) - component.field_names = ['component api name', 'status'] - component.add_row(['output data', judging_state(test_api.component_api('component_output_data'))]) - component.add_row(['output table', judging_state(test_api.component_api('component_output_data_table'))]) - component.add_row(['output model', judging_state(test_api.component_api('component_output_model'))]) - component.add_row( - ['component parameters', judging_state(test_api.component_api('component_parameters', max_iter=max_iter))]) - component.add_row(['metrics', judging_state(test_api.component_api('component_metrics'))]) - component.add_row(['metrics all', judging_state(test_api.component_api('component_metric_all'))]) - - model = PrettyTable() - model.set_style(ORGMODE) - model.field_names = ['model api name', 'status'] - if not 
config_json.get('component_is_homo') and serving_connect_bool: - model_load_conf = get_dict_from_file(model_file_path) - model_load_conf["initiator"]["party_id"] = guest_party_id - model_load_conf["role"].update( - {"guest": [guest_party_id], "host": [host_party_id], "arbiter": [arbiter_party_id]}) - model.add_row(['model load', judging_state(test_api.model_api('load', model_load_conf=model_load_conf))]) - model.add_row(['model bind', judging_state( - test_api.model_api('bind', model_load_conf=model_load_conf, servings=config_json['serving_setting']))]) - - status, model_path = test_api.model_api('export') - model.add_row(['model export', judging_state(status)]) - model.add_row(['model import', (judging_state( - test_api.model_api('import', remove_path=remove_path, model_path=model_path)))]) - model.add_row(['model store', (judging_state(test_api.model_api('store')))]) - model.add_row(['model restore', (judging_state(test_api.model_api('restore')))]) - print(model.get_string(title="model api")) - - component.add_row(['metrics delete', judging_state(test_api.component_api('component_metric_delete'))]) - print(component.get_string(title="component api")) - - test_api.submit_job() - test_api.submit_job() - test_api.submit_job() - - job.add_row(['clean job', judging_state(test_api.job_api('clean_job'))]) - job.add_row(['clean queue', judging_state(test_api.job_api('clean_queue'))]) - print(job.get_string(title="job api")) - print('Please check the error content: {}'.format(test_api.error_log(None))) diff --git a/python/fate_test/fate_test/flow_test/flow_process.py b/python/fate_test/fate_test/flow_test/flow_process.py deleted file mode 100644 index 7c2edc2464..0000000000 --- a/python/fate_test/fate_test/flow_test/flow_process.py +++ /dev/null @@ -1,404 +0,0 @@ -import json -import os -import tarfile -import time -import subprocess -from contextlib import closing -from datetime import datetime - -import requests - - -def get_dict_from_file(file_name): - with open(file_name, 'r', encoding='utf-8') as f: - json_info = json.load(f) - return json_info - - -def serving_connect(serving_setting): - subp = subprocess.Popen([f'echo "" | telnet {serving_setting.split(":")[0]} {serving_setting.split(":")[1]}'], - shell=True, stdout=subprocess.PIPE) - stdout, stderr = subp.communicate() - stdout = stdout.decode("utf-8") - return True if f'Connected to {serving_setting.split(":")[0]}' in stdout else False - - -class Base(object): - def __init__(self, data_base_dir, server_url, component_name): - self.config = None - self.dsl = None - self.guest_party_id = None - self.host_party_id = None - self.job_id = None - self.model_id = None - self.model_version = None - - self.data_base_dir = data_base_dir - self.server_url = server_url - self.component_name = component_name - - def set_config(self, guest_party_id, host_party_id, arbiter_party_id, path): - self.config = get_dict_from_file(path) - self.config["initiator"]["party_id"] = guest_party_id[0] - self.config["role"]["guest"] = guest_party_id - self.config["role"]["host"] = host_party_id - if "arbiter" in self.config["role"]: - self.config["role"]["arbiter"] = arbiter_party_id - self.guest_party_id = guest_party_id - self.host_party_id = host_party_id - return self.config - - def set_dsl(self, path): - self.dsl = get_dict_from_file(path) - return self.dsl - - def submit(self): - post_data = {'job_runtime_conf': self.config, 'job_dsl': self.dsl} - print(f"start submit job, data:{post_data}") - response = requests.post("/".join([self.server_url, "job", 
"submit"]), json=post_data) - if response.status_code == 200 and not response.json().get('retcode'): - self.job_id = response.json().get("jobId") - print(f"submit job success: {response.json()}") - self.model_id = response.json().get("data").get("model_info").get("model_id") - self.model_version = response.json().get("data").get("model_info").get("model_version") - return True - else: - print(f"submit job failed: {response.text}") - return False - - def query_job(self): - post_data = {'job_id': self.job_id} - response = requests.post("/".join([self.server_url, "job", "query"]), json=post_data) - if response.status_code == 200: - if response.json().get("data"): - return response.json().get("data")[0].get("f_status") - return False - - def wait_success(self, timeout=60 * 10): - for i in range(timeout // 10): - time.sleep(10) - status = self.query_job() - print("job {} status is {}".format(self.job_id, status)) - if status and status == "success": - return True - if status and status in ["canceled", "timeout", "failed"]: - return False - return False - - def get_component_output_data(self, output_path=None): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - if not output_path: - output_path = './output/data' - os.makedirs(os.path.dirname(output_path), exist_ok=True) - tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(post_data['job_id'], post_data['component_name'], - post_data['role'], post_data['party_id']) - extract_dir = os.path.join(output_path, tar_file_name.replace('.tar.gz', '')) - print("start get component output dat") - - with closing( - requests.get("/".join([self.server_url, "tracking", "component/output/data/download"]), json=post_data, - stream=True)) as response: - if response.status_code == 200: - try: - download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir) - print(f'get component output path {extract_dir}') - except BaseException: - print(f"get component output data failed") - return False - - def get_output_data_table(self): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - response = requests.post("/".join([self.server_url, "tracking", "component/output/data/table"]), json=post_data) - result = {} - try: - if response.status_code == 200: - result["name"] = response.json().get("data")[0].get("table_name") - result["namespace"] = response.json().get("data")[0].get("namespace") - except Exception as e: - raise RuntimeError(f"output data table error: {response}") from e - return result - - def get_table_info(self, table_name): - post_data = { - "name": table_name['name'], - "namespace": table_name['namespace'] - } - response = requests.post("/".join([self.server_url, "table", "table_info"]), json=post_data) - try: - if response.status_code == 200: - table_count = response.json().get("data").get("count") - else: - raise RuntimeError(f"get table info failed: {response}") - except Exception as e: - raise RuntimeError(f"get table count error: {response}") from e - return table_count - - def get_auc(self): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": "evaluation_0" - } - response = requests.post("/".join([self.server_url, "tracking", "component/metric/all"]), json=post_data) - try: - if response.status_code == 200: - auc = 
response.json().get("data").get("train").get(self.component_name).get("data")[0][1] - else: - raise RuntimeError(f"get metrics failed: {response}") - except Exception as e: - raise RuntimeError(f"get table count error: {response}") from e - return auc - - -class TrainLRModel(Base): - def get_component_metrics(self, metric_output_path, file=None): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": "evaluation_0" - } - response = requests.post("/".join([self.server_url, "tracking", "component/metric/all"]), json=post_data) - if response.status_code == 200: - if response.json().get("data"): - if not file: - file = metric_output_path.format(self.job_id) - os.makedirs(os.path.dirname(file), exist_ok=True) - with open(file, 'w') as fp: - json.dump(response.json().get("data"), fp) - print(f"save component metrics success, path is:{os.path.abspath(file)}") - else: - print(f"get component metrics:{response.json()}") - return False - - def get_component_output_model(self, model_output_path, file=None): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - print(f"request component output model: {post_data}") - response = requests.post("/".join([self.server_url, "tracking", "component/output/model"]), json=post_data) - if response.status_code == 200: - if response.json().get("data"): - if not file: - file = model_output_path.format(self.job_id) - os.makedirs(os.path.dirname(file), exist_ok=True) - with open(file, 'w') as fp: - json.dump(response.json().get("data"), fp) - print(f"save component output model success, path is:{os.path.abspath(file)}") - else: - print(f"get component output model:{response.json()}") - return False - - -class PredictLRMode(Base): - def set_predict(self, guest_party_id, host_party_id, arbiter_party_id, model_id, model_version, path): - self.set_config(guest_party_id, host_party_id, arbiter_party_id, path) - if self.config["job_parameters"].get("common"): - self.config["job_parameters"]["common"]["model_id"] = model_id - self.config["job_parameters"]["common"]["model_version"] = model_version - else: - self.config["job_parameters"]["model_id"] = model_id - self.config["job_parameters"]["model_version"] = model_version - - -def download_from_request(http_response, tar_file_name, extract_dir): - with open(tar_file_name, 'wb') as fw: - for chunk in http_response.iter_content(1024): - if chunk: - fw.write(chunk) - tar = tarfile.open(tar_file_name, "r:gz") - file_names = tar.getnames() - for file_name in file_names: - tar.extract(file_name, extract_dir) - tar.close() - os.remove(tar_file_name) - - -def train_job(data_base_dir, guest_party_id, host_party_id, arbiter_party_id, train_conf_path, train_dsl_path, - server_url, component_name, metric_output_path, model_output_path, constant_auc): - train = TrainLRModel(data_base_dir, server_url, component_name) - train.set_config(guest_party_id, host_party_id, arbiter_party_id, train_conf_path) - train.set_dsl(train_dsl_path) - status = train.submit() - if status: - is_success = train.wait_success(timeout=600) - if is_success: - train.get_component_metrics(metric_output_path) - train.get_component_output_model(model_output_path) - train.get_component_output_data() - train_auc = train.get_auc() - assert abs(constant_auc - train_auc) <= 1e-4, 'The training result is wrong, auc: {}'.format(train_auc) - train_data_count = train.get_table_info(train.get_output_data_table()) - return 
train, train_data_count - return False - - -def predict_job(data_base_dir, guest_party_id, host_party_id, arbiter_party_id, predict_conf_path, predict_dsl_path, - model_id, model_version, server_url, component_name): - predict = PredictLRMode(data_base_dir, server_url, component_name) - predict.set_predict(guest_party_id, host_party_id, arbiter_party_id, model_id, model_version, predict_conf_path) - predict.set_dsl(predict_dsl_path) - status = predict.submit() - if status: - is_success = predict.wait_success(timeout=600) - if is_success: - predict.get_component_output_data() - predict_data_count = predict.get_table_info(predict.get_output_data_table()) - return predict, predict_data_count - return False - - -class UtilizeModel: - def __init__(self, model_id, model_version, server_url): - self.model_id = model_id - self.model_version = model_version - self.deployed_model_version = None - self.service_id = None - self.server_url = server_url - - def deploy_model(self): - post_data = { - "model_id": self.model_id, - "model_version": self.model_version - } - response = requests.post("/".join([self.server_url, "model", "deploy"]), json=post_data) - print(f'Request data of deploy model request: {json.dumps(post_data, indent=4)}') - if response.status_code == 200: - resp_data = response.json() - print(f'Response of model deploy request: {json.dumps(resp_data, indent=4)}') - if resp_data.get("retcode", 100) == 0: - self.deployed_model_version = resp_data.get("data", {}).get("model_version") - else: - raise Exception(f"Model {self.model_id} {self.model_version} deploy failed, " - f"details: {resp_data.get('retmsg')}") - else: - raise Exception(f"Request model deploy api failed, status code: {response.status_code}") - - def load_model(self): - post_data = { - "job_id": self.deployed_model_version - } - response = requests.post("/".join([self.server_url, "model", "load"]), json=post_data) - print(f'Request data of load model request: {json.dumps(post_data, indent=4)}') - if response.status_code == 200: - resp_data = response.json() - print(f'Response of load model request: {json.dumps(resp_data, indent=4)}') - if not resp_data.get('retcode'): - return True - raise Exception(f"Load model {self.model_id} {self.deployed_model_version} failed, " - f"details: {resp_data.get('retmsg')}") - raise Exception(f"Request model load api failed, status code: {response.status_code}") - - def bind_model(self): - post_data = { - "job_id": self.deployed_model_version, - "service_id": f"auto_test_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}" - } - response = requests.post("/".join([self.server_url, "model", "bind"]), json=post_data) - print(f'Request data of bind model request: {json.dumps(post_data, indent=4)}') - if response.status_code == 200: - resp_data = response.json() - print(f'Response data of bind model request: {json.dumps(resp_data, indent=4)}') - if not resp_data.get('retcode'): - self.service_id = post_data.get('service_id') - return True - raise Exception(f"Bind model {self.model_id} {self.deployed_model_version} failed, " - f"details: {resp_data.get('retmsg')}") - raise Exception(f"Request model bind api failed, status code: {response.status_code}") - - def online_predict(self, online_serving, phone_num): - serving_url = f"http://{online_serving}/federation/1.0/inference" - post_data = { - "head": { - "serviceId": self.service_id - }, - "body": { - "featureData": { - "phone_num": phone_num, - }, - "sendToRemoteFeatureData": { - "device_type": "imei", - "phone_num": phone_num, - "encrypt_type": 
"raw" - } - } - } - headers = {"Content-Type": "application/json"} - response = requests.post(serving_url, json=post_data, headers=headers) - print(f"Request data of online predict request: {json.dumps(post_data, indent=4)}") - if response.status_code == 200: - print(f"Online predict successfully, response: {json.dumps(response.json(), indent=4)}") - else: - print(f"Online predict successfully, details: {response.text}") - - -def run_fate_flow_test(config_json): - data_base_dir = config_json['data_base_dir'] - guest_party_id = config_json['guest_party_id'] - host_party_id = config_json['host_party_id'] - arbiter_party_id = config_json['arbiter_party_id'] - train_conf_path = config_json['train_conf_path'] - train_dsl_path = config_json['train_dsl_path'] - server_url = config_json['server_url'] - online_serving = config_json['online_serving'] - constant_auc = config_json['train_auc'] - component_name = config_json['component_name'] - metric_output_path = config_json['metric_output_path'] - model_output_path = config_json['model_output_path'] - serving_connect_bool = serving_connect(config_json['serving_setting']) - phone_num = config_json['phone_num'] - - print('submit train job') - # train - train, train_count = train_job(data_base_dir, guest_party_id, host_party_id, arbiter_party_id, train_conf_path, - train_dsl_path, server_url, component_name, metric_output_path, model_output_path, constant_auc) - if not train: - print('train job run failed') - return False - print('train job success') - - # deploy - print('start deploy model') - utilize = UtilizeModel(train.model_id, train.model_version, server_url) - utilize.deploy_model() - print('deploy model success') - - # predict - predict_conf_path = config_json['predict_conf_path'] - predict_dsl_path = config_json['predict_dsl_path'] - model_id = train.model_id - model_version = utilize.deployed_model_version - print('start submit predict job') - predict, predict_count = predict_job(data_base_dir, guest_party_id, host_party_id, arbiter_party_id, predict_conf_path, - predict_dsl_path, model_id, model_version, server_url, component_name) - if not predict: - print('predict job run failed') - return False - if train_count != predict_count: - print('Loss of forecast data') - return False - print('predict job success') - - if not config_json.get('component_is_homo') and serving_connect_bool: - # load model - utilize.load_model() - - # bind model - utilize.bind_model() - - # online predict - utilize.online_predict(online_serving=online_serving, phone_num=phone_num) diff --git a/python/fate_test/fate_test/flow_test/flow_rest_api.py b/python/fate_test/fate_test/flow_test/flow_rest_api.py deleted file mode 100644 index e7a9c14ab1..0000000000 --- a/python/fate_test/fate_test/flow_test/flow_rest_api.py +++ /dev/null @@ -1,905 +0,0 @@ -import json -import os -import shutil -import time -import numpy as np -from pathlib import Path - -import requests -from contextlib import closing -from prettytable import PrettyTable, ORGMODE -from fate_test.flow_test.flow_process import Base, get_dict_from_file, download_from_request, serving_connect - - -class TestModel(Base): - def __init__(self, data_base_dir, server_url, component_name, namespace): - super().__init__(data_base_dir, server_url, component_name) - self.request_api_info_path = f'./logs/{namespace}/cli_exception.log' - os.makedirs(os.path.dirname(self.request_api_info_path), exist_ok=True) - - def error_log(self, retmsg): - if retmsg is None: - return os.path.abspath(self.request_api_info_path) - with 
open(self.request_api_info_path, "a") as f: - f.write(retmsg) - - def submit_job(self, stop=True): - post_data = {'job_runtime_conf': self.config, 'job_dsl': self.dsl} - try: - response = requests.post("/".join([self.server_url, "job", "submit"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job submit: {}'.format(response.json().get('retmsg')) + '\n') - self.job_id = response.json().get("jobId") - self.model_id = response.json().get("data").get("model_info").get("model_id") - self.model_version = response.json().get("data").get("model_info").get("model_version") - if stop: - return - return self.query_status(self.job_id) - except Exception: - return - - def job_dsl_generate(self): - post_data = { - 'train_dsl': '{"components": {"data_transform_0": {"module": "DataTransform", "input": {"data": {"data": []}},' - '"output": {"data": ["train"], "model": ["data_transform"]}}}}', - 'cpn_str': 'data_transform_0'} - try: - response = requests.post("/".join([self.server_url, "job", "dsl/generate"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job dsl generate: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data')['components']['data_transform_0']['input']['model'][ - 0] == 'pipeline.data_transform_0.data_transform': - return response.json().get('retcode') - except Exception: - return - - def job_api(self, command, output_path=None): - post_data = {'job_id': self.job_id, "role": "guest"} - if command == 'rerun': - try: - response = requests.post("/".join([self.server_url, "job", command]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job rerun: {}'.format(response.json().get('retmsg')) + '\n') - return self.query_status(self.job_id) - except Exception: - return - - elif command == 'stop': - self.submit_job() - time.sleep(5) - try: - response = requests.post("/".join([self.server_url, "job", command]), json={'job_id': self.job_id}) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job stop: {}'.format(response.json().get('retmsg')) + '\n') - if self.query_job() == "canceled": - return response.json().get('retcode') - except Exception: - return - - elif command == 'data/view/query': - try: - response = requests.post("/".join([self.server_url, "job", command]), json=post_data) - if response.json().get('retcode'): - self.error_log('data view query: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get("data")) == len(self.dsl['components'].keys()) - 1: - return response.json().get('retcode') - except Exception: - return - - elif command == 'list/job': - post_data = {'limit': 3} - try: - response = requests.post("/".join([self.server_url, "job", "list/job"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job list: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get('data', {}).get('jobs', [])) == post_data["limit"]: - return response.json().get('retcode') - except Exception: - return - - elif command == 'log/download': - post_data = {'job_id': self.job_id} - tar_file_name = 'job_{}_log.tar.gz'.format(post_data['job_id']) - extract_dir = os.path.join(output_path, tar_file_name.replace('.tar.gz', '')) - with closing(requests.post("/".join([self.server_url, "job", command]), json=post_data, stream=True)) as response: - if response.status_code == 200: - 
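The flow_rest_api.py module removed here exercises the same operations over FATE Flow's HTTP API: each check POSTs a JSON payload to the flow server and inspects the "retcode" field of the JSON reply. A minimal sketch of that request-and-verify step; the helper name and the example server address are illustrative only.

import requests


def post_flow_api(server_url, endpoint, payload):
    """POST a JSON payload to a FATE Flow endpoint and return the decoded reply."""
    response = requests.post("/".join([server_url, endpoint]), json=payload)
    response.raise_for_status()
    body = response.json()
    if body.get("retcode"):
        # The original methods record retmsg via error_log() instead of raising.
        raise RuntimeError(body.get("retmsg"))
    return body

# Example (hypothetical server_url): roughly the equivalent of query_job() above.
# post_flow_api("http://127.0.0.1:9380/v1", "job/query", {"job_id": "2022..."})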
try: - download_from_request(http_response=response, tar_file_name=tar_file_name, - extract_dir=extract_dir) - return 0 - except Exception as e: - self.error_log('job log: {}'.format(e) + '\n') - return - - elif command == 'clean/queue': - try: - response = requests.post("/".join([self.server_url, "job", command])) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('clean queue: {}'.format(response.json().get('retmsg')) + '\n') - if not self.query_job(queue=True): - return response.json().get('retcode') - except Exception: - return - - def query_job(self, job_id=None, queue=False): - if job_id is None: - job_id = self.job_id - time.sleep(1) - try: - if not queue: - response = requests.post("/".join([self.server_url, "job", "query"]), json={'job_id': job_id}) - if response.status_code == 200 and response.json().get("data"): - status = response.json().get("data")[0].get("f_status") - return status - else: - self.error_log('query job: {}'.format(response.json().get('retmsg')) + '\n') - else: - response = requests.post("/".join([self.server_url, "job", "query"]), json={'status': 'waiting'}) - if response.status_code == 200 and response.json().get("data"): - return len(response.json().get("data")) - - except Exception: - return - - def job_config(self, max_iter, output_path): - post_data = { - 'job_id': self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "output_path": output_path - } - try: - response = requests.post("/".join([self.server_url, "job", "config"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('job config: {}'.format(response.json().get('retmsg')) + '\n') - job_conf = response.json().get('data')['runtime_conf'] - if max_iter == job_conf['component_parameters']['common'][self.component_name]['max_iter']: - return response.json().get('retcode') - - except Exception: - return - - def query_task(self): - post_data = { - 'job_id': self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - try: - response = requests.post("/".join([self.server_url, "job", "task/query"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('task query: {}'.format(response.json().get('retmsg')) + '\n') - status = response.json().get("data")[0].get("f_status") - if status == "success": - return response.json().get('retcode') - except Exception: - return - - def list_task(self): - post_data = {'limit': 3} - try: - response = requests.post("/".join([self.server_url, "job", "list/task"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('list task: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get('data', {}).get('tasks', [])) == post_data["limit"]: - return response.json().get('retcode') - except Exception: - return - - def component_api(self, command, output_path=None, max_iter=None): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - if command == 'output/data': - tar_file_name = 'job_{}_{}_output_data.tar.gz'.format(post_data['job_id'], post_data['component_name']) - extract_dir = os.path.join(output_path, tar_file_name.replace('.tar.gz', '')) - with closing(requests.get("/".join([self.server_url, "tracking", "component/output/data/download"]), - json=post_data, stream=True)) as response: - if 
response.status_code == 200: - try: - download_from_request(http_response=response, tar_file_name=tar_file_name, - extract_dir=extract_dir) - return 0 - except Exception as e: - self.error_log('component output data: {}'.format(e) + '\n') - return - - elif command == 'output/data/table': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/output/data/table"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log( - 'component output data table: {}'.format(response.json().get('retmsg')) + '\n') - table = {'table_name': response.json().get("data")[0].get("table_name"), - 'namespace': response.json().get("data")[0].get("namespace")} - if not self.table_api('table_info', table): - return response.json().get('retcode') - except Exception: - return - - elif command == 'output/model': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/output/model"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component output model: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get("data"): - return response.json().get('retcode') - except Exception: - return - - elif command == 'parameters': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/parameters"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component parameters: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data', {}).get('ComponentParam', {}).get('max_iter', {}) == max_iter: - return response.json().get('retcode') - except Exception: - return - - elif command == 'summary/download': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/summary/download"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log( - 'component summary download: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get("data"): - file = output_path + '{}_summary.json'.format(self.job_id) - os.makedirs(os.path.dirname(file), exist_ok=True) - with open(file, 'w') as fp: - json.dump(response.json().get("data"), fp) - return response.json().get('retcode') - except Exception: - return - - def component_metric(self, command, output_path=None): - post_data = { - "job_id": self.job_id, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": 'evaluation_0' - } - if command == 'metrics': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/metrics"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component metrics: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get("data"): - file = output_path + '{}_metrics.json'.format(self.job_id) - os.makedirs(os.path.dirname(file), exist_ok=True) - with open(file, 'w') as fp: - json.dump(response.json().get("data"), fp) - return response.json().get('retcode') - except Exception: - return - - elif command == 'metric/all': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/metric/all"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component metric all: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get("data"): - file = output_path + '{}_metric_all.json'.format(self.job_id) 
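The 'output/data' and 'log/download' branches nearby stream a tar.gz archive from the flow server and unpack it locally (the removed flow_process.py provided download_from_request for this). A compact sketch of that download-and-extract step, assuming a streaming HTTP response; the helper name is illustrative.

import os
import tarfile
from contextlib import closing

import requests


def fetch_tarball(url, payload, tar_path, extract_dir):
    """Stream a tar.gz HTTP response to disk, then extract it and remove the archive."""
    with closing(requests.post(url, json=payload, stream=True)) as response:
        response.raise_for_status()
        with open(tar_path, "wb") as fw:
            for chunk in response.iter_content(1024):
                if chunk:
                    fw.write(chunk)
    os.makedirs(extract_dir, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(extract_dir)
    os.remove(tar_path)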
- os.makedirs(os.path.dirname(file), exist_ok=True) - with open(file, 'w') as fp: - json.dump(response.json().get("data"), fp) - return response.json().get('retcode') - except Exception: - return - - elif command == 'metric/delete': - try: - response = requests.post("/".join([self.server_url, "tracking", "component/metric/delete"]), - json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component metric delete: {}'.format(response.json().get('retmsg')) + '\n') - response = requests.post("/".join([self.server_url, "tracking", "component/metrics"]), - json=post_data) - if response.status_code == 200: - if not response.json().get("data"): - return response.json().get('retcode') - except Exception: - return - - def component_list(self): - post_data = {'job_id': self.job_id} - try: - response = requests.post("/".join([self.server_url, "tracking", "component/list"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('component list: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get('data')['components']) == len(list(self.dsl['components'].keys())): - return response.json().get('retcode') - except Exception: - raise - - def table_api(self, command, table_name): - post_data = { - "table_name": table_name['table_name'], - "namespace": table_name['namespace'] - } - if command == 'table/info': - try: - response = requests.post("/".join([self.server_url, "table", "table_info"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('table info: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data')['namespace'] == table_name['namespace'] and \ - response.json().get('data')['table_name'] == table_name['table_name']: - return response.json().get('retcode') - - except Exception: - return - - elif command == 'table/delete': - try: - response = requests.post("/".join([self.server_url, "table", "delete"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('table delete: {}'.format(response.json().get('retmsg')) + '\n') - response = requests.post("/".join([self.server_url, "table", "delete"]), json=post_data) - if response.status_code == 200 and response.json().get('retcode'): - return 0 - except Exception: - return - - def data_upload(self, post_data, table_index=None): - post_data['file'] = str(self.data_base_dir.joinpath(post_data['file']).resolve()) - post_data['drop'] = 1 - post_data['use_local_data'] = 0 - if table_index is not None: - post_data['table_name'] = f'{post_data["file"]}_{table_index}' - - try: - response = requests.post("/".join([self.server_url, "data", "upload"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('data upload: {}'.format(response.json().get('retmsg')) + '\n') - return self.query_status(response.json().get("jobId")) - except Exception: - return - - def data_download(self, table_name, output_path): - post_data = { - "table_name": table_name['table_name'], - "namespace": table_name['namespace'], - "output_path": output_path + '{}download.csv'.format(self.job_id) - } - try: - response = requests.post("/".join([self.server_url, "data", "download"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('data download: {}'.format(response.json().get('retmsg')) + '\n') - return 
self.query_status(response.json().get("jobId")) - except Exception: - return - - def data_upload_history(self, conf_file): - self.data_upload(conf_file, table_index=1) - post_data = {"limit": 2} - try: - response = requests.post("/".join([self.server_url, "data", "upload/history"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('data upload history: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get('data')) == post_data["limit"]: - return response.json().get('retcode') - except Exception: - return - - def tag_api(self, command, tag_name=None, new_tag_name=None): - post_data = { - "tag_name": tag_name - } - if command == 'tag/retrieve': - try: - response = requests.post("/".join([self.server_url, "model", "tag/retrieve"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('tag retrieve: {}'.format(response.json().get('retmsg')) + '\n') - if not response.json().get('retcode'): - return response.json().get('data')['tags'][0]['name'] - except Exception: - return - - elif command == 'tag/create': - try: - response = requests.post("/".join([self.server_url, "model", "tag/create"]), json=post_data) - if response.status_code == 200: - self.error_log('tag create: {}'.format(response.json().get('retmsg')) + '\n') - if self.tag_api('tag/retrieve', tag_name=tag_name) == tag_name: - return 0 - except Exception: - return - - elif command == 'tag/destroy': - try: - response = requests.post("/".join([self.server_url, "model", "tag/destroy"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('tag destroy: {}'.format(response.json().get('retmsg')) + '\n') - if not self.tag_api('tag/retrieve', tag_name=tag_name): - return 0 - except Exception: - return - - elif command == 'tag/update': - post_data = { - "tag_name": tag_name, - "new_tag_name": new_tag_name - } - try: - response = requests.post("/".join([self.server_url, "model", "tag/update"]), json=post_data) - if response.status_code == 200: - self.error_log('tag update: {}'.format(response.json().get('retmsg')) + '\n') - if self.tag_api('tag/retrieve', tag_name=new_tag_name) == new_tag_name: - return 0 - except Exception: - return - - elif command == 'tag/list': - post_data = {"limit": 1} - try: - response = requests.post("/".join([self.server_url, "model", "tag/list"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('tag list: {}'.format(response.json().get('retmsg')) + '\n') - if len(response.json().get('data')['tags']) == post_data['limit']: - return response.json().get('retcode') - except Exception: - return - - def model_api( - self, - command, - output_path=None, - remove_path=None, - model_path=None, - homo_deploy_path=None, - homo_deploy_kube_config_path=None, - arbiter_party_id=None, - tag_name=None, - model_load_conf=None, - servings=None): - if model_load_conf is not None: - model_load_conf["job_parameters"].update({"model_id": self.model_id, - "model_version": self.model_version}) - if command == 'model/load': - try: - response = requests.post("/".join([self.server_url, "model", "load"]), json=model_load_conf) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model load: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get('retcode') - except Exception: - return - - elif command == 'model/bind': - service_id = "".join([str(i) for i 
in np.random.randint(9, size=8)]) - post_data = model_load_conf.update({"service_id": service_id, "servings": [servings]}) - try: - response = requests.post("/".join([self.server_url, "model", "bind"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model bind: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get('retcode') - except Exception: - return - - elif command == 'model/import': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "file": model_path, - "force_update": 1, - } - - try: - remove_path = Path(remove_path + self.model_version) - if os.path.exists(model_path): - files = {'file': open(model_path, 'rb')} - else: - return - if os.path.isdir(remove_path): - shutil.rmtree(remove_path) - response = requests.post("/".join([self.server_url, "model", "import"]), data=config_data, files=files) - if response.status_code == 200: - if os.path.isdir(remove_path): - return 0 - except Exception: - return - - elif command == 'model/export': - post_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - } - tar_file_name = '{}_{}_model_export.zip'.format(post_data['model_id'], post_data['model_version']) - archive_file_path = os.path.join(output_path, tar_file_name) - with closing(requests.get("/".join([self.server_url, "model", "export"]), json=post_data, - stream=True)) as response: - if response.status_code == 200: - try: - with open(archive_file_path, 'wb') as fw: - for chunk in response.iter_content(1024): - if chunk: - fw.write(chunk) - except Exception: - return - return 0, archive_file_path - - elif command == 'model/migrate': - post_data = { - "job_parameters": { - "federated_mode": "MULTIPLE" - }, - "migrate_initiator": { - "role": "guest", - "party_id": self.guest_party_id[0] - }, - "role": { - "guest": self.guest_party_id, - "arbiter": arbiter_party_id, - "host": self.host_party_id - }, - "migrate_role": { - "guest": self.guest_party_id, - "arbiter": arbiter_party_id, - "host": self.host_party_id - }, - "execute_party": { - "guest": self.guest_party_id, - "arbiter": arbiter_party_id, - "host": self.host_party_id - }, - "model_id": self.model_id, - "model_version": self.model_version, - "unify_model_version": self.job_id + '_01' - } - try: - response = requests.post("/".join([self.server_url, "model", "migrate"]), json=post_data) - if response.status_code == 200: - self.error_log('model migrate: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get("retcode") - except Exception: - return - - elif command == 'model/homo/convert': - post_data = { - 'model_id': self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - } - try: - response = requests.post("/".join([self.server_url, "model", "homo/convert"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model homo convert: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get("retcode") - except Exception: - return - - elif command == 'model/homo/deploy': - job_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - config_data = get_dict_from_file(homo_deploy_path) - config_data.update(job_data) - 
if homo_deploy_kube_config_path: - with open(homo_deploy_kube_config_path, 'r') as fp: - config_data['deployment_parameters']['config_file_content'] = fp.read() - config_data['deployment_parameters'].pop('config_file', None) - try: - response = requests.post("/".join([self.server_url, "model", "homo/deploy"]), json=config_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model homo deploy: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get("retcode") - except Exception: - return - - elif command == 'model_tag/create': - post_data = { - "job_id": self.job_id, - "tag_name": tag_name - } - try: - response = requests.post("/".join([self.server_url, "model", "model_tag/create"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model tag create: {}'.format(response.json().get('retmsg')) + '\n') - if self.model_api('model_tag/retrieve')[0].get('name') == post_data['tag_name']: - return 0 - except Exception: - return - - elif command == 'model_tag/remove': - post_data = { - "job_id": self.job_id, - "tag_name": tag_name - } - try: - response = requests.post("/".join([self.server_url, "model", "model_tag/remove"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model tag remove: {}'.format(response.json().get('retmsg')) + '\n') - if not len(self.model_api('model_tag/retrieve')): - return 0 - except Exception: - return - - elif command == 'model_tag/retrieve': - post_data = { - "job_id": self.job_id - } - try: - response = requests.post("/".join([self.server_url, "model", "model_tag/retrieve"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model tag retrieve: {}'.format(response.json().get('retmsg')) + '\n') - return response.json().get('data')['tags'] - except Exception: - return - - elif command == 'model/deploy': - post_data = { - "model_id": self.model_id, - "model_version": self.model_version - } - try: - response = requests.post("/".join([self.server_url, "model", "deploy"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model deploy: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data')['model_id'] == self.model_id and \ - response.json().get('data')['model_version'] != self.model_version: - self.model_id = response.json().get('data')['model_id'] - self.model_version = response.json().get('data')['model_version'] - self.job_id = response.json().get('data')['model_version'] - return response.json().get('retcode') - except Exception: - return - - elif command == 'model/conf': - post_data = { - "model_id": self.model_id, - "model_version": self.model_version - } - try: - response = requests.post("/".join([self.server_url, "model", "get/predict/conf"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model conf: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data'): - if response.json().get('data')['job_parameters']['common']['model_id'] == post_data['model_id']\ - and response.json().get('data')['job_parameters']['common']['model_version'] == \ - post_data['model_version'] and response.json().get('data')['initiator']['party_id'] == \ - self.guest_party_id[0] and response.json().get('data')['initiator']['role'] == 'guest': - return response.json().get('retcode') - - except Exception: 
- return - - elif command == 'model/dsl': - post_data = { - "model_id": self.model_id, - "model_version": self.model_version - } - try: - response = requests.post("/".join([self.server_url, "model", "get/predict/dsl"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model dsl: {}'.format(response.json().get('retmsg')) + '\n') - model_dsl_cpn = list(response.json().get('data')['components'].keys()) - train_dsl_cpn = list(self.dsl['components'].keys()) - if len([k for k in model_dsl_cpn if k in train_dsl_cpn]) == len(train_dsl_cpn): - return response.json().get('retcode') - except Exception: - return - - elif command == 'model/query': - post_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0] - } - try: - response = requests.post("/".join([self.server_url, "model", "query"]), json=post_data) - if response.status_code == 200: - if response.json().get('retcode'): - self.error_log('model query: {}'.format(response.json().get('retmsg')) + '\n') - if response.json().get('data')[0].get('f_model_id') == post_data['model_id'] and \ - response.json().get('data')[0].get('f_model_version') == post_data['model_version'] and \ - response.json().get('data')[0].get('f_role') == post_data['role'] and \ - response.json().get('data')[0].get('f_party_id') == str(post_data['party_id']): - return response.json().get('retcode') - except Exception: - return - - def query_status(self, job_id): - while True: - time.sleep(5) - status = self.query_job(job_id=job_id) - if status and status in ["waiting", "running", "success"]: - if status and status == "success": - return 0 - else: - return - - -def judging_state(retcode): - if not retcode and retcode is not None: - return 'success' - else: - return 'failed' - - -def run_test_api(config_json, namespace): - output_path = './output/flow_test_data/' - os.makedirs(os.path.dirname(output_path), exist_ok=True) - output_path = str(os.path.abspath(output_path)) + '/' - guest_party_id = config_json['guest_party_id'] - host_party_id = config_json['host_party_id'] - arbiter_party_id = config_json['arbiter_party_id'] - train_conf_path = config_json['train_conf_path'] - train_dsl_path = config_json['train_dsl_path'] - upload_file_path = config_json['upload_file_path'] - model_file_path = config_json['model_file_path'] - remove_path = str(config_json['data_base_dir']).split("python")[ - 0] + '/fateflow/model_local_cache/guest#{}#arbiter-{}#guest-{}#host-{}#model/'.format( - guest_party_id[0], arbiter_party_id[0], guest_party_id[0], host_party_id[0]) - - serving_connect_bool = serving_connect(config_json['serving_setting']) - test_api = TestModel(config_json['data_base_dir'], config_json['server_url'], - component_name=config_json['component_name'], namespace=namespace) - job_conf = test_api.set_config(guest_party_id, host_party_id, arbiter_party_id, train_conf_path) - max_iter = job_conf['component_parameters']['common'][config_json['component_name']]['max_iter'] - test_api.set_dsl(train_dsl_path) - conf_file = get_dict_from_file(upload_file_path) - - data = PrettyTable() - data.set_style(ORGMODE) - data.field_names = ['data api name', 'status'] - data.add_row(['data upload', judging_state(test_api.data_upload(conf_file))]) - data.add_row(['data download', judging_state(test_api.data_download(conf_file, output_path))]) - data.add_row(['data upload history', judging_state(test_api.data_upload_history(conf_file))]) - 
print(data.get_string(title="data api")) - - table = PrettyTable() - table.set_style(ORGMODE) - table.field_names = ['table api name', 'status'] - table.add_row(['table info', judging_state(test_api.table_api('table/info', conf_file))]) - table.add_row(['delete table', judging_state(test_api.table_api('table/delete', conf_file))]) - print(table.get_string(title="table api")) - - job = PrettyTable() - job.set_style(ORGMODE) - job.field_names = ['job api name', 'status'] - job.add_row(['job stop', judging_state(test_api.job_api('stop'))]) - job.add_row(['job rerun', judging_state(test_api.job_api('rerun'))]) - job.add_row(['job submit', judging_state(test_api.submit_job(stop=False))]) - job.add_row(['job query', judging_state(False if test_api.query_job() == "success" else True)]) - job.add_row(['job data view', judging_state(test_api.job_api('data/view/query'))]) - job.add_row(['job list', judging_state(test_api.job_api('list/job'))]) - job.add_row(['job config', judging_state(test_api.job_config(max_iter=max_iter, output_path=output_path))]) - job.add_row(['job log', judging_state(test_api.job_api('log/download', output_path))]) - job.add_row(['job dsl generate', judging_state(test_api.job_dsl_generate())]) - print(job.get_string(title="job api")) - - task = PrettyTable() - task.set_style(ORGMODE) - task.field_names = ['task api name', 'status'] - task.add_row(['task list', judging_state(test_api.list_task())]) - task.add_row(['task query', judging_state(test_api.query_task())]) - print(task.get_string(title="task api")) - - tag = PrettyTable() - tag.set_style(ORGMODE) - tag.field_names = ['tag api name', 'status'] - tag.add_row(['create tag', judging_state(test_api.tag_api('tag/create', 'create_job_tag'))]) - tag.add_row(['update tag', judging_state(test_api.tag_api('tag/update', 'create_job_tag', 'update_job_tag'))]) - tag.add_row(['list tag', judging_state(test_api.tag_api('tag/list'))]) - tag.add_row( - ['retrieve tag', judging_state(not test_api.tag_api('tag/retrieve', 'update_job_tag') == 'update_job_tag')]) - tag.add_row(['destroy tag', judging_state(test_api.tag_api('tag/destroy', 'update_job_tag'))]) - print(tag.get_string(title="tag api")) - - component = PrettyTable() - component.set_style(ORGMODE) - component.field_names = ['component api name', 'status'] - component.add_row(['output data', judging_state(test_api.component_api('output/data', output_path=output_path))]) - component.add_row(['output table', judging_state(test_api.component_api('output/data/table'))]) - component.add_row(['output model', judging_state(test_api.component_api('output/model'))]) - component.add_row(['component parameters', judging_state(test_api.component_api('parameters', max_iter=max_iter))]) - component.add_row( - ['component summary', judging_state(test_api.component_api('summary/download', output_path=output_path))]) - component.add_row(['component list', judging_state(test_api.component_list())]) - component.add_row(['metrics', judging_state( - test_api.component_metric('metrics', output_path=output_path))]) - component.add_row(['metrics all', judging_state( - test_api.component_metric('metric/all', output_path=output_path))]) - - model = PrettyTable() - model.set_style(ORGMODE) - model.field_names = ['model api name', 'status'] - if config_json.get('component_is_homo'): - homo_deploy_path = config_json.get('homo_deploy_path') - homo_deploy_kube_config_path = config_json.get('homo_deploy_kube_config_path') - model.add_row(['model homo convert', 
judging_state(test_api.model_api('model/homo/convert'))]) - model.add_row(['model homo deploy', - judging_state(test_api.model_api('model/homo/deploy', - homo_deploy_path=homo_deploy_path, - homo_deploy_kube_config_path=homo_deploy_kube_config_path))]) - if not config_json.get('component_is_homo') and serving_connect_bool: - model_load_conf = get_dict_from_file(model_file_path) - model_load_conf["initiator"]["party_id"] = guest_party_id - model_load_conf["role"].update( - {"guest": [guest_party_id], "host": [host_party_id], "arbiter": [arbiter_party_id]}) - model.add_row(['model load', judging_state(test_api.model_api('model/load', model_load_conf=model_load_conf))]) - model.add_row(['model bind', judging_state(test_api.model_api('model/bind', model_load_conf=model_load_conf, - servings=config_json['serving_setting']))]) - status, model_path = test_api.model_api('model/export', output_path=output_path) - model.add_row(['model export', judging_state(status)]) - model.add_row(['model import', (judging_state( - test_api.model_api('model/import', remove_path=remove_path, model_path=model_path)))]) - model.add_row( - ['model_tag create', judging_state(test_api.model_api('model_tag/create', tag_name='model_tag_create'))]) - model.add_row( - ['model_tag remove', judging_state(test_api.model_api('model_tag/remove', tag_name='model_tag_create'))]) - model.add_row(['model_tag retrieve', judging_state(len(test_api.model_api('model_tag/retrieve')))]) - if serving_connect_bool: - model.add_row( - ['model migrate', judging_state(test_api.model_api('model/migrate', arbiter_party_id=arbiter_party_id))]) - model.add_row(['model query', judging_state(test_api.model_api('model/query'))]) - model.add_row(['model deploy', judging_state(test_api.model_api('model/deploy'))]) - model.add_row(['model conf', judging_state(test_api.model_api('model/conf'))]) - model.add_row(['model dsl', judging_state(test_api.model_api('model/dsl'))]) - print(model.get_string(title="model api")) - component.add_row(['metrics delete', judging_state( - test_api.component_metric('metric/delete', output_path=output_path))]) - print(component.get_string(title="component api")) - - queue = PrettyTable() - queue.set_style(ORGMODE) - queue.field_names = ['api name', 'status'] - test_api.submit_job() - test_api.submit_job() - test_api.submit_job() - queue.add_row(['clean/queue', judging_state(test_api.job_api('clean/queue'))]) - print(queue.get_string(title="queue job")) - print('Please check the error content: {}'.format(test_api.error_log(None))) diff --git a/python/fate_test/fate_test/flow_test/flow_sdk_api.py b/python/fate_test/fate_test/flow_test/flow_sdk_api.py deleted file mode 100644 index 243c35d93f..0000000000 --- a/python/fate_test/fate_test/flow_test/flow_sdk_api.py +++ /dev/null @@ -1,801 +0,0 @@ -import json -import os -import shutil -import time -import numpy as np -from pathlib import Path - -from flow_sdk.client import FlowClient -from prettytable import PrettyTable, ORGMODE -from fate_test.flow_test.flow_process import get_dict_from_file, serving_connect - - -class TestModel(object): - def __init__(self, data_base_dir, server_url, component_name, namespace): - self.conf_path = None - self.dsl_path = None - self.job_id = None - self.model_id = None - self.model_version = None - self.guest_party_id = None - self.host_party_id = None - self.arbiter_party_id = None - self.output_path = None - self.cache_directory = None - - self.data_base_dir = data_base_dir - self.component_name = component_name - self.client = 
FlowClient(server_url.split(':')[0], server_url.split(':')[1].split('/')[0], - server_url.split(':')[1].split('/')[1]) - self.request_api_info_path = f'./logs/{namespace}/sdk_exception.log' - os.makedirs(os.path.dirname(self.request_api_info_path), exist_ok=True) - - def error_log(self, retmsg): - if retmsg is None: - return os.path.abspath(self.request_api_info_path) - with open(self.request_api_info_path, "a") as f: - f.write(retmsg) - - def submit_job(self, stop=True): - try: - stdout = self.client.job.submit(config_data=get_dict_from_file(self.conf_path), - dsl_data=get_dict_from_file(self.dsl_path)) - if stdout.get('retcode'): - self.error_log('job submit: {}'.format(stdout.get('retmsg')) + '\n') - self.job_id = stdout.get("jobId") - self.model_id = stdout.get("data").get("model_info").get("model_id") - self.model_version = stdout.get("data").get("model_info").get("model_version") - if stop: - return - return self.query_status() - except Exception: - return - - def job_dsl_generate(self): - train_dsl = {"components": {"data_transform_0": {"module": "DataTransform", "input": {"data": {"data": []}}, - "output": {"data": ["train"], "model": ["data_transform"]}}}} - train_dsl_path = self.cache_directory + 'generate_dsl_file.json' - with open(train_dsl_path, 'w') as fp: - json.dump(train_dsl, fp) - try: - stdout = self.client.job.generate_dsl(train_dsl=get_dict_from_file(train_dsl_path), - cpn=['data_transform_0']) - if stdout.get('retcode'): - self.error_log('job dsl generate: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data')['components']['data_transform_0']['input']['model'][ - 0] == 'pipeline.data_transform_0.data_transform': - return stdout.get('retcode') - except Exception: - return - - def job_api(self, command): - if command == 'stop': - self.submit_job() - time.sleep(5) - try: - stdout = self.client.job.stop(job_id=self.job_id) - if stdout.get('retcode'): - self.error_log('job stop: {}'.format(stdout.get('retmsg')) + '\n') - if self.query_job() == "canceled": - return stdout.get('retcode') - except Exception: - return - - elif command == 'list/job': - try: - stdout = self.client.job.list(limit=3) - if stdout.get('retcode'): - self.error_log('job list: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get('data', {}).get('jobs', [])) == 3: - return stdout.get('retcode') - except Exception: - return - - elif command == 'view': - try: - stdout = self.client.job.view(job_id=self.job_id, role="guest") - if stdout.get('retcode'): - self.error_log('job view: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get("data")) == len(list(get_dict_from_file(self.dsl_path)['components'].keys())) - 1: - return stdout.get('retcode') - except Exception: - return - - elif command == 'log': - log_file_dir = os.path.join(self.output_path, 'job_{}_log'.format(self.job_id)) - try: - stdout = self.client.job.log(job_id=self.job_id, output_path=log_file_dir) - if stdout.get('retcode'): - self.error_log('job log: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'clean/queue': - try: - stdout = self.client.queue.clean() - if stdout.get('retcode'): - self.error_log('clean queue: {}'.format(stdout.get('retmsg')) + '\n') - if not self.query_job(queue=True): - return stdout.get('retcode') - except Exception: - return - - def query_job(self, job_id=None, queue=False): - if job_id is None: - job_id = self.job_id - time.sleep(1) - try: - if not queue: - stdout = self.client.job.query(job_id=job_id) - if not 
stdout.get('retcode'): - return stdout.get("data")[0].get("f_status") - else: - self.error_log('query job: {}'.format(stdout.get('retmsg')) + '\n') - else: - stdout = self.client.job.query(job_id=job_id, status='waiting') - if not stdout.get('retcode'): - return len(stdout.get("data")) - except Exception: - return - - def job_config(self, max_iter): - try: - stdout = self.client.job.config(job_id=self.job_id, role="guest", party_id=self.guest_party_id[0], - output_path=self.output_path) - if stdout.get('retcode'): - self.error_log('job config: {}'.format(stdout.get('retmsg')) + '\n') - job_conf_path = stdout.get('directory') + '/runtime_conf.json' - job_conf = get_dict_from_file(job_conf_path) - if max_iter == job_conf['component_parameters']['common'][self.component_name]['max_iter']: - return stdout.get('retcode') - - except Exception: - return - - def query_task(self): - try: - stdout = self.client.task.query(job_id=self.job_id, role="guest", party_id=self.guest_party_id[0], - component_name=self.component_name) - if stdout.get('retcode'): - self.error_log('task query: {}'.format(stdout.get('retmsg')) + '\n') - status = stdout.get("data")[0].get("f_status") - if status == "success": - return stdout.get('retcode') - except Exception: - return - - def list_task(self): - try: - stdout = self.client.task.list(limit=3) - if stdout.get('retcode'): - self.error_log('list task: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get('data', {}).get('tasks', [])) == 3: - return stdout.get('retcode') - except Exception: - return - - def component_api(self, command, max_iter=None): - component_output_path = os.path.join(self.output_path, 'job_{}_output_data'.format(self.job_id)) - if command == 'output/data': - try: - stdout = self.client.component.output_data(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name=self.component_name, - output_path=component_output_path) - if stdout.get('retcode'): - self.error_log('component output data: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'output/data/table': - try: - stdout = self.client.component.output_data_table(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name=self.component_name) - if stdout.get('retcode'): - self.error_log('component output data table: {}'.format(stdout.get('retmsg')) + '\n') - table = {'table_name': stdout.get("data")[0].get("table_name"), - 'namespace': stdout.get("data")[0].get("namespace")} - if not self.table_api('table_info', table): - return stdout.get('retcode') - except Exception: - return - - elif command == 'output/model': - try: - stdout = self.client.component.output_model(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name=self.component_name) - if stdout.get('retcode'): - self.error_log('component output model: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - return stdout.get('retcode') - except Exception: - return - - elif command == 'parameters': - try: - stdout = self.client.component.parameters(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name=self.component_name) - if stdout.get('retcode'): - self.error_log('component parameters: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data', {}).get('ComponentParam', {}).get('max_iter', {}) == max_iter: - return stdout.get('retcode') - except Exception: - return - - elif command == 'summary': - try: - stdout = 
self.client.component.get_summary(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name=self.component_name) - if stdout.get('retcode'): - self.error_log('component summary download: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - summary_file = self.output_path + '{}_summary.json'.format(self.job_id) - with open(summary_file, 'w') as fp: - json.dump(stdout.get("data"), fp) - return stdout.get('retcode') - except Exception: - return - - elif command == 'metrics': - try: - stdout = self.client.component.metrics(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name='evaluation_0') - if stdout.get('retcode'): - self.error_log('component metrics: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - metrics_file = self.output_path + '{}_metrics.json'.format(self.job_id) - with open(metrics_file, 'w') as fp: - json.dump(stdout.get("data"), fp) - return stdout.get('retcode') - except Exception: - return - - elif command == 'metric/all': - try: - stdout = self.client.component.metric_all(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name='evaluation_0') - if stdout.get('retcode'): - self.error_log('component metric all: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get("data"): - metric_all_file = self.output_path + '{}_metric_all.json'.format(self.job_id) - with open(metric_all_file, 'w') as fp: - json.dump(stdout.get("data"), fp) - return stdout.get('retcode') - except Exception: - return - - elif command == 'metric/delete': - try: - stdout = self.client.component.metric_delete(job_id=self.job_id, date=str(time.strftime("%Y%m%d"))) - if stdout.get('retcode'): - self.error_log('component metric delete: {}'.format(stdout.get('retmsg')) + '\n') - metric = self.client.component.metrics(job_id=self.job_id, role="guest", - party_id=self.guest_party_id[0], - component_name='evaluation_0') - if not metric.get('data'): - return stdout.get('retcode') - except Exception: - return - - def component_list(self): - try: - stdout = self.client.component.list(job_id=self.job_id) - if stdout.get('retcode'): - self.error_log('component list: {}'.format(stdout.get('retmsg')) + '\n') - dsl_json = get_dict_from_file(self.dsl_path) - if len(stdout.get('data')['components']) == len(list(dsl_json['components'].keys())): - return stdout.get('retcode') - except Exception: - raise - - def table_api(self, command, table_name): - if command == 'table/info': - try: - stdout = self.client.table.info(table_name=table_name['table_name'], namespace=table_name['namespace']) - if stdout.get('retcode'): - self.error_log('table info: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data')['namespace'] == table_name['namespace'] and \ - stdout.get('data')['table_name'] == table_name['table_name']: - return stdout.get('retcode') - except Exception: - return - - elif command == 'table/delete': - try: - stdout = self.client.table.delete(table_name=table_name['table_name'], - namespace=table_name['namespace']) - - if stdout.get('retcode'): - self.error_log('table delete: {}'.format(stdout.get('retmsg')) + '\n') - stdout = self.client.table.delete(table_name=table_name['table_name'], - namespace=table_name['namespace']) - if stdout.get('retcode'): - return 0 - except Exception: - return - - def data_upload(self, upload_path, table_index=None): - upload_file = get_dict_from_file(upload_path) - upload_file['file'] = str(self.data_base_dir.joinpath(upload_file['file']).resolve()) - 
upload_file['drop'] = 1 - upload_file['use_local_data'] = 0 - if table_index is not None: - upload_file['table_name'] = f'{upload_file["file"]}_{table_index}' - # upload_path = self.cache_directory + 'upload_file.json' - # with open(upload_path, 'w') as fp: - # json.dump(upload_file, fp) - try: - stdout = self.client.data.upload(config_data=upload_file, drop=1) - if stdout.get('retcode'): - self.error_log('data upload: {}'.format(stdout.get('retmsg')) + '\n') - return self.query_status(stdout.get("jobId")) - except Exception: - return - - def data_download(self, table_name): - download_config = { - "table_name": table_name['table_name'], - "namespace": table_name['namespace'], - "output_path": 'download.csv', - } - try: - stdout = self.client.data.download(config_data=download_config) - if stdout.get('retcode'): - self.error_log('data download: {}'.format(stdout.get('retmsg')) + '\n') - return self.query_status(stdout.get("jobId")) - except Exception: - return - - def data_upload_history(self, conf_file): - self.data_upload(conf_file, table_index=1) - try: - stdout = self.client.data.upload_history(limit=2) - if stdout.get('retcode'): - self.error_log('data upload history: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get('data')) == 2: - return stdout.get('retcode') - except Exception: - return - - def tag_api(self, command, tag_name=None, new_tag_name=None): - if command == 'tag/query': - try: - stdout = self.client.tag.query(tag_name=tag_name) - if stdout.get('retcode'): - self.error_log('tag query: {}'.format(stdout.get('retmsg')) + '\n') - if not stdout.get('retcode'): - return stdout.get('data')['tags'][0]['name'] - except Exception: - return - - elif command == 'tag/create': - try: - stdout = self.client.tag.create(tag_name=tag_name) - self.error_log('tag create: {}'.format(stdout.get('retmsg')) + '\n') - if self.tag_api('tag/query', tag_name=tag_name) == tag_name: - return 0 - except Exception: - return - - elif command == 'tag/delete': - try: - stdout = self.client.tag.delete(tag_name=tag_name) - if stdout.get('retcode'): - self.error_log('tag delete: {}'.format(stdout.get('retmsg')) + '\n') - if not self.tag_api('tag/query', tag_name=tag_name): - return 0 - except Exception: - return - - elif command == 'tag/update': - try: - stdout = self.client.tag.update(tag_name=tag_name, new_tag_name=new_tag_name) - self.error_log('tag update: {}'.format(stdout.get('retmsg')) + '\n') - if self.tag_api('tag/query', tag_name=new_tag_name) == new_tag_name: - return 0 - except Exception: - return - - elif command == 'tag/list': - try: - stdout = self.client.tag.list(limit=1) - if stdout.get('retcode'): - self.error_log('tag list: {}'.format(stdout.get('retmsg')) + '\n') - if len(stdout.get('data')['tags']) == 1: - return stdout.get('retcode') - except Exception: - return - - def model_api(self, command, remove_path=None, model_path=None, tag_name=None, homo_deploy_path=None, - homo_deploy_kube_config_path=None, remove=False, model_load_conf=None, servings=None): - if model_load_conf is not None: - model_load_conf["job_parameters"].update({"model_id": self.model_id, - "model_version": self.model_version}) - - if command == 'model/load': - try: - stdout = self.client.model.load(config_data=model_load_conf) - if stdout.get('retcode'): - self.error_log('model load: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/bind': - service_id = "".join([str(i) for i in np.random.randint(9, size=8)]) - 
model_load_conf.update({"service_id": service_id, "servings": [servings]}) - try: - stdout = self.client.model.bind(config_data=model_load_conf) - if stdout.get('retcode'): - self.error_log('model bind: {}'.format(stdout.get('retmsg')) + '\n') - else: - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/import': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "file": model_path, - "force_update": 1, - } - - try: - remove_path = Path(remove_path + self.model_version) - if os.path.isdir(remove_path): - shutil.rmtree(remove_path) - stdout = self.client.model.import_model(config_data=config_data) - if not stdout.get('retcode') and os.path.isdir(remove_path): - return 0 - else: - self.error_log('model import: {}'.format(stdout.get('retmsg')) + '\n') - except Exception: - return - - elif command == 'model/export': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "output_path": self.output_path - } - # config_file_path = self.cache_directory + 'model_export.json' - # with open(config_file_path, 'w') as fp: - # json.dump(config_data, fp) - stdout = self.client.model.export_model(config_data=config_data) - if stdout.get('retcode'): - self.error_log('model export: {}'.format(stdout.get('retmsg')) + '\n') - else: - export_model_path = stdout.get('file') - return stdout.get('retcode'), export_model_path - - elif command == 'model/migrate': - config_data = { - "job_parameters": { - "federated_mode": "MULTIPLE" - }, - "migrate_initiator": { - "role": "guest", - "party_id": self.guest_party_id[0] - }, - "role": { - "guest": self.guest_party_id, - "arbiter": self.arbiter_party_id, - "host": self.host_party_id - }, - "migrate_role": { - "guest": self.guest_party_id, - "arbiter": self.arbiter_party_id, - "host": self.host_party_id - }, - "execute_party": { - "guest": self.guest_party_id, - "arbiter": self.arbiter_party_id, - "host": self.host_party_id - }, - "model_id": self.model_id, - "model_version": self.model_version, - "unify_model_version": self.job_id + '_01' - } - # config_file_path = self.cache_directory + 'model_migrate.json' - # with open(config_file_path, 'w') as fp: - # json.dump(config_data, fp) - try: - stdout = self.client.model.migrate(config_data=config_data) - if stdout.get('retcode'): - self.error_log('model migrate: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/homo/convert': - config_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - } - config_file_path = self.cache_directory + 'model_homo_convert.json' - with open(config_file_path, 'w') as fp: - json.dump(config_data, fp) - try: - stdout = self.client.model.homo_convert(conf_path=config_file_path) - if stdout.get('retcode'): - self.error_log('model homo convert: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/homo/deploy': - job_data = { - "model_id": self.model_id, - "model_version": self.model_version, - "role": "guest", - "party_id": self.guest_party_id[0], - "component_name": self.component_name - } - config_data = get_dict_from_file(homo_deploy_path) - config_data.update(job_data) - if homo_deploy_kube_config_path: - 
config_data['deployment_parameters']['config_file'] = homo_deploy_kube_config_path - config_file_path = self.cache_directory + 'model_homo_deploy.json' - with open(config_file_path, 'w') as fp: - json.dump(config_data, fp) - try: - stdout = self.client.model.homo_deploy(conf_path=config_file_path) - if stdout.get('retcode'): - self.error_log('model homo deploy: {}'.format(stdout.get('retmsg')) + '\n') - return stdout.get('retcode') - except Exception: - return - - elif command == 'model_tag/model': - try: - stdout = self.client.model.tag_model(job_id=self.job_id, tag_name=tag_name, remove=remove) - if stdout.get('retcode'): - self.error_log('model tag model: {}'.format(stdout.get('retmsg')) + '\n') - return self.model_api('model_tag/list', tag_name=tag_name, remove=True) - except Exception: - return - - elif command == 'model_tag/list': - try: - stdout = self.client.model.tag_list(job_id=self.job_id) - if stdout.get('retcode'): - self.error_log('model tag retrieve: {}'.format(stdout.get('retmsg')) + '\n') - if remove and len(stdout.get('data').get('tags')) == 0: - return stdout.get('retcode') - if stdout.get('data').get('tags')[0].get('name') == tag_name: - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/deploy': - try: - stdout = self.client.model.deploy(model_id=self.model_id, model_version=self.model_version) - if stdout.get('retcode'): - self.error_log('model deploy: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data')['model_id'] == self.model_id and\ - stdout.get('data')['model_version'] != self.model_version: - self.model_id = stdout.get('data')['model_id'] - self.model_version = stdout.get('data')['model_version'] - self.job_id = stdout.get('data')['model_version'] - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/conf': - try: - stdout = self.client.model.get_predict_conf(model_id=self.model_id, model_version=self.model_version) - if stdout.get('retcode'): - self.error_log('model conf: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data'): - if stdout.get('data')['job_parameters']['common']['model_id'] == self.model_id \ - and stdout.get('data')['job_parameters']['common']['model_version'] == \ - self.model_version and stdout.get('data')['initiator']['party_id'] == \ - self.guest_party_id[0] and stdout.get('data')['initiator']['role'] == 'guest': - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/dsl': - try: - stdout = self.client.model.get_predict_dsl(model_id=self.model_id, model_version=self.model_version) - if stdout.get('retcode'): - self.error_log('model dsl: {}'.format(stdout.get('retmsg')) + '\n') - model_dsl_cpn = list(stdout.get('data')['components'].keys()) - train_dsl_cpn = list(get_dict_from_file(self.dsl_path)['components'].keys()) - if len([k for k in model_dsl_cpn if k in train_dsl_cpn]) == len(train_dsl_cpn): - return stdout.get('retcode') - except Exception: - return - - elif command == 'model/query': - try: - stdout = self.client.model.get_model_info(model_id=self.model_id, model_version=self.model_version, - role="guest", party_id=self.guest_party_id[0]) - if stdout.get('retcode'): - self.error_log('model query: {}'.format(stdout.get('retmsg')) + '\n') - if stdout.get('data')[0].get('f_model_id') == self.model_id and \ - stdout.get('data')[0].get('f_model_version') == self.model_version and \ - stdout.get('data')[0].get('f_role') == "guest" and \ - stdout.get('data')[0].get('f_party_id') == str(self.guest_party_id[0]): - return 
stdout.get('retcode') - except Exception: - return - - def query_status(self, job_id=None): - while True: - time.sleep(5) - status = self.query_job(job_id=job_id) - if status and status in ["waiting", "running", "success"]: - if status and status == "success": - return 0 - else: - return - - def set_config(self, guest_party_id, host_party_id, arbiter_party_id, path, component_name): - config = get_dict_from_file(path) - config["initiator"]["party_id"] = guest_party_id[0] - config["role"]["guest"] = guest_party_id - config["role"]["host"] = host_party_id - if "arbiter" in config["role"]: - config["role"]["arbiter"] = arbiter_party_id - self.guest_party_id = guest_party_id - self.host_party_id = host_party_id - self.arbiter_party_id = arbiter_party_id - conf_file_path = self.cache_directory + 'conf_file.json' - with open(conf_file_path, 'w') as fp: - json.dump(config, fp) - self.conf_path = conf_file_path - return config['component_parameters']['common'][component_name]['max_iter'] - - -def judging_state(retcode): - if not retcode and retcode is not None: - return 'success' - else: - return 'failed' - - -def run_test_api(config_json, namespace): - output_path = './output/flow_test_data/' - os.makedirs(os.path.dirname(output_path), exist_ok=True) - test_api = TestModel(config_json['data_base_dir'], config_json['server_url'].split('//')[1], - config_json['component_name'], namespace) - test_api.dsl_path = config_json['train_dsl_path'] - test_api.cache_directory = config_json['cache_directory'] - test_api.output_path = str(os.path.abspath(output_path)) + '/' - conf_path = config_json['train_conf_path'] - guest_party_id = config_json['guest_party_id'] - host_party_id = config_json['host_party_id'] - arbiter_party_id = config_json['arbiter_party_id'] - upload_file_path = config_json['upload_file_path'] - model_file_path = config_json['model_file_path'] - conf_file = get_dict_from_file(upload_file_path) - serving_connect_bool = serving_connect(config_json['serving_setting']) - remove_path = str(config_json['data_base_dir']).split("python")[ - 0] + '/fateflow/model_local_cache/guest#{}#arbiter-{}#guest-{}#host-{}#model/'.format( - guest_party_id[0], arbiter_party_id[0], guest_party_id[0], host_party_id[0]) - max_iter = test_api.set_config(guest_party_id, host_party_id, arbiter_party_id, conf_path, - config_json['component_name']) - - data = PrettyTable() - data.set_style(ORGMODE) - data.field_names = ['data api name', 'status'] - data.add_row(['data upload', judging_state(test_api.data_upload(upload_file_path))]) - data.add_row(['data download', judging_state(test_api.data_download(conf_file))]) - data.add_row( - ['data upload history', judging_state(test_api.data_upload_history(upload_file_path))]) - print(data.get_string(title="data api")) - - table = PrettyTable() - table.set_style(ORGMODE) - table.field_names = ['table api name', 'status'] - table.add_row(['table info', judging_state(test_api.table_api('table/info', conf_file))]) - table.add_row(['delete table', judging_state(test_api.table_api('table/delete', conf_file))]) - print(table.get_string(title="table api")) - - job = PrettyTable() - job.set_style(ORGMODE) - job.field_names = ['job api name', 'status'] - job.add_row(['job stop', judging_state(test_api.job_api('stop'))]) - job.add_row(['job submit', judging_state(test_api.submit_job(stop=False))]) - job.add_row(['job query', judging_state(False if test_api.query_job() == "success" else True)]) - job.add_row(['job view', judging_state(test_api.job_api('view'))]) - job.add_row(['job 
list', judging_state(test_api.job_api('list/job'))]) - job.add_row(['job config', judging_state(test_api.job_config(max_iter=max_iter))]) - job.add_row(['job log', judging_state(test_api.job_api('log'))]) - job.add_row(['job dsl generate', judging_state(test_api.job_dsl_generate())]) - print(job.get_string(title="job api")) - - task = PrettyTable() - task.set_style(ORGMODE) - task.field_names = ['task api name', 'status'] - task.add_row(['task list', judging_state(test_api.list_task())]) - task.add_row(['task query', judging_state(test_api.query_task())]) - print(task.get_string(title="task api")) - - tag = PrettyTable() - tag.set_style(ORGMODE) - tag.field_names = ['tag api name', 'status'] - tag.add_row(['create tag', judging_state(test_api.tag_api('tag/create', 'create_job_tag'))]) - tag.add_row(['update tag', judging_state(test_api.tag_api('tag/update', 'create_job_tag', 'update_job_tag'))]) - tag.add_row(['list tag', judging_state(test_api.tag_api('tag/list'))]) - tag.add_row( - ['query tag', judging_state(not test_api.tag_api('tag/query', 'update_job_tag') == 'update_job_tag')]) - tag.add_row(['delete tag', judging_state(test_api.tag_api('tag/delete', 'update_job_tag'))]) - print(tag.get_string(title="tag api")) - - component = PrettyTable() - component.set_style(ORGMODE) - component.field_names = ['component api name', 'status'] - component.add_row(['output data', judging_state(test_api.component_api('output/data'))]) - component.add_row(['output table', judging_state(test_api.component_api('output/data/table'))]) - component.add_row(['output model', judging_state(test_api.component_api('output/model'))]) - component.add_row(['component parameters', judging_state(test_api.component_api('parameters', max_iter=max_iter))]) - component.add_row(['component summary', judging_state(test_api.component_api('summary'))]) - component.add_row(['component list', judging_state(test_api.component_list())]) - component.add_row(['metrics', judging_state(test_api.component_api('metrics'))]) - component.add_row(['metrics all', judging_state(test_api.component_api('metric/all'))]) - - model = PrettyTable() - model.set_style(ORGMODE) - model.field_names = ['model api name', 'status'] - if config_json.get('component_is_homo'): - homo_deploy_path = config_json.get('homo_deploy_path') - homo_deploy_kube_config_path = config_json.get('homo_deploy_kube_config_path') - model.add_row(['model homo convert', judging_state(test_api.model_api('model/homo/convert'))]) - model.add_row(['model homo deploy', - judging_state(test_api.model_api('model/homo/deploy', - homo_deploy_path=homo_deploy_path, - homo_deploy_kube_config_path=homo_deploy_kube_config_path))]) - if not config_json.get('component_is_homo') and serving_connect_bool: - model_load_conf = get_dict_from_file(model_file_path) - model_load_conf["initiator"]["party_id"] = guest_party_id - model_load_conf["role"].update( - {"guest": [guest_party_id], "host": [host_party_id], "arbiter": [arbiter_party_id]}) - model.add_row(['model load', judging_state(test_api.model_api('model/load', model_load_conf=model_load_conf))]) - model.add_row(['model bind', judging_state(test_api.model_api('model/bind', model_load_conf=model_load_conf, - servings=config_json['serving_setting']))]) - status, model_path = test_api.model_api('model/export') - model.add_row(['model export', judging_state(status)]) - model.add_row(['model import', (judging_state( - test_api.model_api('model/import', remove_path=remove_path, model_path=model_path)))]) - model.add_row(['tag model', 
judging_state(test_api.model_api('model_tag/model', tag_name='model_tag_create'))]) - model.add_row(['tag list', judging_state(test_api.model_api('model_tag/list', tag_name='model_tag_create'))]) - model.add_row( - ['tag remove', judging_state(test_api.model_api('model_tag/model', tag_name='model_tag_create', remove=True))]) - if serving_connect_bool: - model.add_row( - ['model migrate', judging_state(test_api.model_api('model/migrate'))]) - model.add_row(['model query', judging_state(test_api.model_api('model/query'))]) - if not config_json.get('component_is_homo') and serving_connect_bool: - model.add_row(['model deploy', judging_state(test_api.model_api('model/deploy'))]) - model.add_row(['model conf', judging_state(test_api.model_api('model/conf'))]) - model.add_row(['model dsl', judging_state(test_api.model_api('model/dsl'))]) - print(model.get_string(title="model api")) - component.add_row(['metrics delete', judging_state(test_api.component_api('metric/delete'))]) - print(component.get_string(title="component api")) - - queue = PrettyTable() - queue.set_style(ORGMODE) - queue.field_names = ['api name', 'status'] - test_api.submit_job() - test_api.submit_job() - test_api.submit_job() - queue.add_row(['clean/queue', judging_state(test_api.job_api('clean/queue'))]) - print(queue.get_string(title="queue job")) - print('Please check the error content: {}'.format(test_api.error_log(None))) diff --git a/python/fate_test/fate_test/scripts/__init__.py b/python/fate_test/fate_test/scripts/__init__.py deleted file mode 100644 index 878d3a9c5d..0000000000 --- a/python/fate_test/fate_test/scripts/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# diff --git a/python/fate_test/fate_test/scripts/_options.py b/python/fate_test/fate_test/scripts/_options.py deleted file mode 100644 index 7ea0c54d18..0000000000 --- a/python/fate_test/fate_test/scripts/_options.py +++ /dev/null @@ -1,65 +0,0 @@ -import time - -import click -from fate_test._config import parse_config, default_config -from fate_test.scripts._utils import _set_namespace - - -class SharedOptions(object): - _options = { - "config": (('-c', '--config'), - dict(type=click.Path(exists=True), help=f"Manual specify config file", default=None), - default_config().__str__()), - "namespace": (('-n', '--namespace'), - dict(type=str, help=f"Manual specify fate_test namespace", default=None), - time.strftime('%Y%m%d%H%M%S')), - "namespace_mangling": (('-nm', '--namespace-mangling',), - dict(type=bool, is_flag=True, help="Mangling data namespace", default=None), - False), - "yes": (('-y', '--yes',), dict(type=bool, is_flag=True, help="Skip double check", default=None), - False), - "extend_sid": (('--extend_sid', ), - dict(type=bool, is_flag=True, help="whether to append uuid as sid when uploading data", - default=None), False), - "auto_increasing_sid": (('--auto_increasing_sid', ), - dict(type=bool, is_flag=True, help="whether to generate sid value starting at 0", - default=None), False), - } - - def __init__(self): - self._options_kwargs = {} - - def __getitem__(self, item): - return self._options_kwargs[item] - - def get(self, k, default=None): - v = self._options_kwargs.get(k, default) - if v is None and k in self._options: - v = self._options[k][2] - return v - - def update(self, **kwargs): - for k, v in kwargs.items(): - if v is not None: - self._options_kwargs[k] = v - - def post_process(self): - # add defaults here - for k, v in self._options.items(): - if self._options_kwargs.get(k, None) is None: - self._options_kwargs[k] = v[2] - - # update config - config = parse_config(self._options_kwargs['config']) - self._options_kwargs['config'] = config - - _set_namespace(self._options_kwargs['namespace_mangling'], self._options_kwargs['namespace']) - - @classmethod - def get_shared_options(cls, hidden=False): - def shared_options(f): - for name, option in cls._options.items(): - f = click.option(*option[0], **dict(option[1], hidden=hidden))(f) - return f - - return shared_options diff --git a/python/fate_test/fate_test/scripts/_utils.py b/python/fate_test/fate_test/scripts/_utils.py deleted file mode 100644 index 520478c3f9..0000000000 --- a/python/fate_test/fate_test/scripts/_utils.py +++ /dev/null @@ -1,187 +0,0 @@ -import importlib -import os -import time -import uuid -import glob as glob_ -from pathlib import Path - -import click -from fate_test import _config -from fate_test._client import Clients -from fate_test._config import Config -from fate_test._flow_client import DataProgress, UploadDataResponse, QueryJobResponse -from fate_test._io import echo, LOGGER, set_logger -from fate_test._parser import Testsuite, BenchmarkSuite, DATA_JSON_HOOK, CONF_JSON_HOOK, DSL_JSON_HOOK - - -def _big_data_task(includes, guest_data_size, host_data_size, guest_feature_num, host_feature_num, host_data_type, - config_inst, encryption_type, match_rate, sparsity, force, split_host, output_path, parallelize): - from fate_test.scripts import generate_mock_data - - def _find_testsuite_files(path): - suffix = ["testsuite.json", "benchmark.json"] - if isinstance(path, str): - path = Path(path) - if path.is_file(): - if path.name.endswith(suffix[0]) or path.name.endswith(suffix[1]): - paths = [path] - 
else: - LOGGER.warning(f"{path} is file, but not end with `{suffix}`, skip") - paths = [] - return [p.resolve() for p in paths] - else: - os.path.abspath(path) - paths = glob_.glob(f"{path}/*{suffix[0]}") + glob_.glob(f"{path}/*{suffix[1]}") - return [Path(p) for p in paths] - - for include in includes: - if isinstance(include, str): - include_paths = Path(include) - include_paths = _find_testsuite_files(include_paths) - for include_path in include_paths: - generate_mock_data.get_big_data(guest_data_size, host_data_size, guest_feature_num, host_feature_num, - include_path, host_data_type, config_inst, encryption_type, - match_rate, sparsity, force, split_host, output_path, parallelize) - - -def _load_testsuites(includes, excludes, glob, provider=None, suffix="testsuite.json", suite_type="testsuite"): - def _find_testsuite_files(path): - if isinstance(path, str): - path = Path(path) - if path.is_file(): - if path.name.endswith(suffix): - paths = [path] - else: - LOGGER.warning(f"{path} is file, but not end with `{suffix}`, skip") - paths = [] - else: - paths = path.glob(f"**/*{suffix}") - return [p.resolve() for p in paths] - - excludes_set = set() - for exclude in excludes: - excludes_set.update(_find_testsuite_files(exclude)) - - suite_paths = set() - for include in includes: - if isinstance(include, str): - include = Path(include) - - # glob - if glob is not None and include.is_dir(): - include_list = include.glob(glob) - else: - include_list = [include] - for include_path in include_list: - for suite_path in _find_testsuite_files(include_path): - if suite_path not in excludes_set: - suite_paths.add(suite_path) - suites = [] - for suite_path in suite_paths: - try: - if suite_type == "testsuite": - suite = Testsuite.load(suite_path.resolve(), provider) - elif suite_type == "benchmark": - suite = BenchmarkSuite.load(suite_path.resolve()) - else: - raise ValueError(f"Unsupported suite type: {suite_type}. 
Only accept type 'testsuite' or 'benchmark'.") - except Exception as e: - echo.stdout(f"load suite {suite_path} failed: {e}") - else: - suites.append(suite) - return suites - - -@LOGGER.catch -def _upload_data(clients: Clients, suite, config: Config, output_path=None): - with click.progressbar(length=len(suite.dataset), - label="dataset", - show_eta=False, - show_pos=True, - width=24) as bar: - for i, data in enumerate(suite.dataset): - data.update(config) - table_name = data.config['table_name'] if data.config.get( - 'table_name', None) is not None else data.config.get('name') - data_progress = DataProgress(f"{data.role_str}<-{data.config['namespace']}.{table_name}") - - def update_bar(n_step): - bar.item_show_func = lambda x: data_progress.show() - time.sleep(0.1) - bar.update(n_step) - - def _call_back(resp): - if isinstance(resp, UploadDataResponse): - data_progress.submitted(resp.job_id) - echo.file(f"[dataset]{resp.job_id}") - if isinstance(resp, QueryJobResponse): - data_progress.update() - update_bar(0) - - try: - echo.stdout_newline() - status, data_path = clients[data.role_str].upload_data(data, _call_back, output_path) - time.sleep(1) - data_progress.update() - if status != 'success': - raise RuntimeError(f"uploading {i + 1}th data for {suite.path} {status}") - bar.update(1) - if _config.data_switch: - from fate_test.scripts import generate_mock_data - - generate_mock_data.remove_file(data_path) - except Exception: - exception_id = str(uuid.uuid1()) - echo.file(f"exception({exception_id})") - LOGGER.exception(f"exception id: {exception_id}") - echo.echo(f"upload {i + 1}th data {data.config} to {data.role_str} fail, exception_id: {exception_id}") - # raise RuntimeError(f"exception uploading {i + 1}th data") from e - - -def _delete_data(clients: Clients, suite: Testsuite): - with click.progressbar(length=len(suite.dataset), - label="delete ", - show_eta=False, - show_pos=True, - width=24) as bar: - for data in suite.dataset: - # noinspection PyBroadException - try: - table_name = data.config['table_name'] if data.config.get( - 'table_name', None) is not None else data.config.get('name') - bar.item_show_func = \ - lambda x: f"delete table: name={table_name}, namespace={data.config['namespace']}" - clients[data.role_str].delete_data(data) - except Exception: - LOGGER.exception( - f"delete failed: name={table_name}, namespace={data.config['namespace']}") - - time.sleep(0.5) - bar.update(1) - echo.stdout_newline() - - -def _load_module_from_script(script_path): - module_name = str(script_path).split("/", -1)[-1].split(".")[0] - loader = importlib.machinery.SourceFileLoader(module_name, str(script_path)) - spec = importlib.util.spec_from_loader(loader.name, loader) - mod = importlib.util.module_from_spec(spec) - loader.exec_module(mod) - return mod - - -def _set_namespace(data_namespace_mangling, namespace): - Path(f"logs/{namespace}").mkdir(exist_ok=True, parents=True) - set_logger(f"logs/{namespace}/exception.log") - echo.set_file(click.open_file(f'logs/{namespace}/stdout', "a")) - - if data_namespace_mangling: - echo.echo(f"add data_namespace_mangling: _{namespace}") - DATA_JSON_HOOK.add_extend_namespace_hook(namespace) - CONF_JSON_HOOK.add_extend_namespace_hook(namespace) - - -def _add_replace_hook(replace): - DATA_JSON_HOOK.add_replace_hook(replace) - CONF_JSON_HOOK.add_replace_hook(replace) - DSL_JSON_HOOK.add_replace_hook(replace) diff --git a/python/fate_test/fate_test/scripts/benchmark_cli.py b/python/fate_test/fate_test/scripts/benchmark_cli.py deleted file mode 100644 index 
f814cc5876..0000000000 --- a/python/fate_test/fate_test/scripts/benchmark_cli.py +++ /dev/null @@ -1,149 +0,0 @@ -import os -import re -import time -import uuid -from datetime import timedelta -from inspect import signature - -import click -from fate_test._client import Clients -from fate_test._config import Config -from fate_test._io import LOGGER, echo -from fate_test._parser import BenchmarkSuite -from fate_test.scripts._options import SharedOptions -from fate_test.scripts._utils import _upload_data, _delete_data, _load_testsuites, _load_module_from_script -from fate_test.utils import show_data, match_metrics - -DATA_DISPLAY_PATTERN = re.compile("^FATE") - - -@click.command(name="benchmark-quality") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), multiple=True, metavar="", - help="include *benchmark.json under these paths") -@click.option('-e', '--exclude', type=click.Path(exists=True), multiple=True, - help="exclude *benchmark.json under these paths") -@click.option('-g', '--glob', type=str, - help="glob string to filter sub-directory of path specified by ") -@click.option('-t', '--tol', type=float, - help="tolerance (absolute error) for metrics to be considered almost equal. " - "Comparison is done by evaluating abs(a-b) <= max(relative_tol * max(abs(a), abs(b)), absolute_tol)") -@click.option('-s', '--storage-tag', type=str, - help="tag for storing metrics, for future metrics info comparison") -@click.option('-v', '--history-tag', type=str, multiple=True, - help="Extract metrics info from history tags for comparison") -@click.option('-d', '--match-details', type=click.Choice(['all', 'relative', 'absolute', 'none']), - default="all", help="Error value display in algorithm comparison") -@click.option('--skip-data', is_flag=True, default=False, - help="skip uploading data specified in benchmark conf") -@click.option("--disable-clean-data", "clean_data", flag_value=False, default=None) -@click.option("--enable-clean-data", "clean_data", flag_value=True, default=None) -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def run_benchmark(ctx, include, exclude, glob, skip_data, tol, clean_data, storage_tag, history_tag, match_details, - **kwargs): - """ - process benchmark suite, alias: bq - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - config_inst.extend_sid = ctx.obj["extend_sid"] - config_inst.auto_increasing_sid = ctx.obj["auto_increasing_sid"] - if clean_data is None: - clean_data = config_inst.clean_data - data_namespace_mangling = ctx.obj["namespace_mangling"] - yes = ctx.obj["yes"] - - echo.welcome("benchmark") - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo("loading testsuites:") - suites = _load_testsuites(includes=include, excludes=exclude, glob=glob, - suffix="benchmark.json", suite_type="benchmark") - for suite in suites: - echo.echo(f"\tdataset({len(suite.dataset)}) benchmark groups({len(suite.pairs)}) {suite.path}") - if not yes and not click.confirm("running?"): - return - with Clients(config_inst) as client: - fate_version = client["guest_0"].get_version() - for i, suite in enumerate(suites): - # noinspection PyBroadException - try: - start = time.time() - echo.echo(f"[{i + 1}/{len(suites)}]start at {time.strftime('%Y-%m-%d %X')} {suite.path}", fg='red') - if not skip_data: - try: - _upload_data(client, suite, config_inst) - except Exception as e: - raise RuntimeError(f"exception occur while uploading data for 
{suite.path}") from e - try: - _run_benchmark_pairs(config_inst, suite, tol, namespace, data_namespace_mangling, storage_tag, - history_tag, fate_version, match_details) - except Exception as e: - raise RuntimeError(f"exception occur while running benchmark jobs for {suite.path}") from e - - if not skip_data and clean_data: - _delete_data(client, suite) - echo.echo(f"[{i + 1}/{len(suites)}]elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception in {suite.path}, exception_id={exception_id}", err=True, fg='red') - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@LOGGER.catch -def _run_benchmark_pairs(config: Config, suite: BenchmarkSuite, tol: float, namespace: str, - data_namespace_mangling: bool, storage_tag, history_tag, fate_version, match_details): - # pipeline demo goes here - pair_n = len(suite.pairs) - fate_base = config.fate_base - PYTHONPATH = os.environ.get('PYTHONPATH') + ":" + os.path.join(fate_base, "python") - os.environ['PYTHONPATH'] = PYTHONPATH - for i, pair in enumerate(suite.pairs): - echo.echo(f"Running [{i + 1}/{pair_n}] group: {pair.pair_name}") - results = {} - # data_summary = None - job_n = len(pair.jobs) - for j, job in enumerate(pair.jobs): - try: - echo.echo(f"Running [{j + 1}/{job_n}] job: {job.job_name}") - job_name, script_path, conf_path = job.job_name, job.script_path, job.conf_path - param = Config.load_from_file(conf_path) - mod = _load_module_from_script(script_path) - input_params = signature(mod.main).parameters - # local script - if len(input_params) == 1: - data, metric = mod.main(param=param) - elif len(input_params) == 2: - data, metric = mod.main(config=config, param=param) - # pipeline script - elif len(input_params) == 3: - if data_namespace_mangling: - data, metric = mod.main(config=config, param=param, namespace=f"_{namespace}") - else: - data, metric = mod.main(config=config, param=param) - else: - data, metric = mod.main() - results[job_name] = metric - echo.echo(f"[{j + 1}/{job_n}] job: {job.job_name} Success!\n") - if data and DATA_DISPLAY_PATTERN.match(job_name): - # data_summary = data - show_data(data) - # if data_summary is None: - # data_summary = data - except Exception as e: - exception_id = uuid.uuid1() - echo.echo(f"exception while running [{j + 1}/{job_n}] job, exception_id={exception_id}", err=True, - fg='red') - LOGGER.exception(f"exception id: {exception_id}, error message: \n{e}") - continue - rel_tol = pair.compare_setting.get("relative_tol") - # show_data(data_summary) - match_metrics(evaluate=True, group_name=pair.pair_name, abs_tol=tol, rel_tol=rel_tol, - storage_tag=storage_tag, history_tag=history_tag, fate_version=fate_version, - cache_directory=config.cache_directory, match_details=match_details, **results) diff --git a/python/fate_test/fate_test/scripts/cli.py b/python/fate_test/fate_test/scripts/cli.py deleted file mode 100644 index 194c464885..0000000000 --- a/python/fate_test/fate_test/scripts/cli.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import click -from fate_test.scripts._options import SharedOptions -from fate_test.scripts.benchmark_cli import run_benchmark -from fate_test.scripts.config_cli import config_group -from fate_test.scripts.data_cli import data_group -from fate_test.scripts.testsuite_cli import run_suite -from fate_test.scripts.performance_cli import run_task -from fate_test.scripts.flow_test_cli import flow_group -from fate_test.scripts.quick_test_cli import unittest_group -from fate_test.scripts.secure_protocol_cli import secure_protocol_group -from fate_test.scripts.pipeline_conversion_cli import convert_group - -commands = { - "config": config_group, - "suite": run_suite, - "performance": run_task, - "benchmark-quality": run_benchmark, - "data": data_group, - "flow-test": flow_group, - "unittest": unittest_group, - "convert": convert_group, - "op-test": secure_protocol_group -} - -commands_alias = { - "bq": "benchmark-quality", - "bp": "performance" -} - - -class MultiCLI(click.MultiCommand): - - def list_commands(self, ctx): - return list(commands) - - def get_command(self, ctx, name): - if name not in commands and name in commands_alias: - name = commands_alias[name] - if name not in commands: - ctx.fail("No such command '{}'.".format(name)) - return commands[name] - - -@click.command(cls=MultiCLI, help="A collection of useful tools to running FATE's test.", - context_settings=dict(help_option_names=["-h", "--help"])) -@SharedOptions.get_shared_options() -@click.pass_context -def cli(ctx, **kwargs): - ctx.ensure_object(SharedOptions) - ctx.obj.update(**kwargs) - - -if __name__ == '__main__': - cli(obj=SharedOptions()) diff --git a/python/fate_test/fate_test/scripts/config_cli.py b/python/fate_test/fate_test/scripts/config_cli.py deleted file mode 100644 index 55f0b4c61a..0000000000 --- a/python/fate_test/fate_test/scripts/config_cli.py +++ /dev/null @@ -1,79 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from pathlib import Path - -import click -from fate_test._client import Clients -from fate_test._config import create_config, default_config, parse_config -from fate_test.scripts._options import SharedOptions - - -@click.group("config", help="fate_test config") -def config_group(): - """ - config fate_test - """ - pass - - -@config_group.command(name="new") -def _new(): - """ - create new fate_test config temperate - """ - create_config(Path("fate_test_config.yaml")) - click.echo(f"create config file: fate_test_config.yaml") - - -@config_group.command(name="edit") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def _edit(ctx, **kwargs): - """ - edit fate_test config file - """ - ctx.obj.update(**kwargs) - config = ctx.obj.get("config") - click.edit(filename=config) - - -@config_group.command(name="show") -def _show(): - """ - show fate_test default config path - """ - click.echo(f"default config path is {default_config()}") - - -@config_group.command(name="check") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def _config(ctx, **kwargs): - """ - check connection - """ - ctx.obj.update(**kwargs) - config_inst = parse_config(ctx.obj.get("config")) - with Clients(config_inst) as clients: - roles = clients.all_roles() - for r in roles: - try: - version, address = clients[r].check_connection() - except Exception as e: - click.echo(f"[X]connection fail, role is {r}, exception is {e.args}") - else: - click.echo(f"[✓]connection {address} ok, fate version is {version}, role is {r}") diff --git a/python/fate_test/fate_test/scripts/data_cli.py b/python/fate_test/fate_test/scripts/data_cli.py deleted file mode 100644 index fe6f381d93..0000000000 --- a/python/fate_test/fate_test/scripts/data_cli.py +++ /dev/null @@ -1,426 +0,0 @@ -import os -import re -import sys -import time -import uuid -import json -from datetime import timedelta - -import click -from pathlib import Path -from ruamel import yaml - -from fate_test import _config -from fate_test._config import Config -from fate_test._client import Clients -from fate_test._io import LOGGER, echo -from fate_test.scripts._options import SharedOptions -from fate_test.scripts._utils import _upload_data, _load_testsuites, _delete_data, _big_data_task - - -@click.group(name="data") -def data_group(): - """ - upload or delete data in suite config files - """ - ... - - -@data_group.command("upload") -@click.option('-i', '--include', required=False, type=click.Path(exists=True), multiple=True, metavar="", - help="include *benchmark.json under these paths") -@click.option('-e', '--exclude', type=click.Path(exists=True), multiple=True, - help="exclude *benchmark.json under these paths") -@click.option("-t", "--config-type", type=click.Choice(["min_test", "all_examples"]), default="min_test", - help="config file") -@click.option('-g', '--glob', type=str, - help="glob string to filter sub-directory of path specified by ") -@click.option('-s', '--suite-type', required=False, type=click.Choice(["testsuite", "benchmark"]), default="testsuite", - help="suite type") -@click.option('-r', '--role', type=str, default='all', help="role to process, default to `all`. 
" - "use option likes: `guest_0`, `host_0`, `host`") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def upload(ctx, include, exclude, glob, suite_type, role, config_type, **kwargs): - """ - upload data defined in suite config files - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - config_inst.extend_sid = ctx.obj["extend_sid"] - config_inst.auto_increasing_sid = ctx.obj["auto_increasing_sid"] - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - if len(include) != 0: - echo.echo("loading testsuites:") - suffix = "benchmark.json" if suite_type == "benchmark" else "testsuite.json" - suites = _load_testsuites(includes=include, excludes=exclude, glob=glob, - suffix=suffix, suite_type=suite_type) - for suite in suites: - if role != "all": - suite.dataset = [d for d in suite.dataset if re.match(d.role_str, role)] - echo.echo(f"\tdataset({len(suite.dataset)}) {suite.path}") - if not yes and not click.confirm("running?"): - return - client_upload(suites=suites, config_inst=config_inst, namespace=namespace) - else: - config = get_config(config_inst) - if config_type == 'min_test': - config_file = config.min_test_data_config - else: - config_file = config.all_examples_data_config - - with open(config_file, 'r', encoding='utf-8') as f: - upload_data = json.loads(f.read()) - - echo.echo(f"\tdataset({len(upload_data['data'])}) {config_file}") - if not yes and not click.confirm("running?"): - return - with Clients(config_inst) as client: - data_upload(client, config_inst, upload_data) - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@data_group.command("delete") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), multiple=True, metavar="", - help="include *benchmark.json under these paths") -@click.option('-e', '--exclude', type=click.Path(exists=True), multiple=True, - help="exclude *benchmark.json under these paths") -@click.option('-g', '--glob', type=str, - help="glob string to filter sub-directory of path specified by ") -@click.option('-s', '--suite-type', required=True, type=click.Choice(["testsuite", "benchmark"]), help="suite type") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def delete(ctx, include, exclude, glob, yes, suite_type, **kwargs): - """ - delete data defined in suite config files - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo("loading testsuites:") - suffix = "benchmark.json" if suite_type == "benchmark" else "testsuite.json" - - suites = _load_testsuites(includes=include, excludes=exclude, glob=glob, - suffix=suffix, suite_type=suite_type) - if not yes and not click.confirm("running?"): - return - - for suite in suites: - echo.echo(f"\tdataset({len(suite.dataset)}) {suite.path}") - if not yes and not click.confirm("running?"): - return - with Clients(config_inst) as client: - for i, suite in enumerate(suites): - _delete_data(client, suite) - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@data_group.command("generate") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), multiple=True, metavar="", - help="include *testsuite.json / *benchmark.json under these paths") -@click.option('-ht', '--host-data-type', default='tag_value', 
type=click.Choice(['dense', 'tag', 'tag_value']), - help="Select the format of the host data") -@click.option('-p', '--encryption-type', type=click.Choice(['sha256', 'md5']), - help="Entry ID encryption method for, sha256 and md5") -@click.option('-m', '--match-rate', default=1.0, type=float, - help="Intersection rate relative to guest, between [0, 1]") -@click.option('-s', '--sparsity', default=0.2, type=float, - help="The sparsity of tag data, The value is between (0-1)") -@click.option('-ng', '--guest-data-size', type=int, default=10000, - help="Set guest data set size, not less than 100") -@click.option('-nh', '--host-data-size', type=int, - help="Set host data set size, not less than 100") -@click.option('-fg', '--guest-feature-num', type=int, default=20, - help="Set guest feature dimensions") -@click.option('-fh', '--host-feature-num', type=int, default=200, - help="Set host feature dimensions; the default is equal to the number of guest's size") -@click.option('-o', '--output-path', type=click.Path(exists=True), - help="Customize the output path of generated data") -@click.option('--force', is_flag=True, default=False, - help="Overwrite existing file") -@click.option('--split-host', is_flag=True, default=False, - help="Divide the amount of host data equally among all the host tables in TestSuite") -@click.option('--upload-data', is_flag=True, default=False, - help="Generated data will be uploaded") -@click.option('--remove-data', is_flag=True, default=False, - help="The generated data will be deleted") -@click.option('--parallelize', is_flag=True, default=False, - help="It is directly used to upload data, and will not generate data") -@click.option('--use-local-data', is_flag=True, default=False, - help="The existing data of the server will be uploaded, This parameter is not recommended for " - "distributed applications") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def generate(ctx, include, host_data_type, encryption_type, match_rate, sparsity, guest_data_size, - host_data_size, guest_feature_num, host_feature_num, output_path, force, split_host, upload_data, - remove_data, use_local_data, parallelize, **kwargs): - """ - create data defined in suite config files - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - config_inst.extend_sid = ctx.obj["extend_sid"] - config_inst.auto_increasing_sid = ctx.obj["auto_increasing_sid"] - if parallelize and upload_data: - upload_data = False - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo("loading testsuites:") - if host_data_size is None: - host_data_size = guest_data_size - suites = _load_testsuites(includes=include, excludes=tuple(), glob=None) - suites += _load_testsuites(includes=include, excludes=tuple(), glob=None, - suffix="benchmark.json", suite_type="benchmark") - for suite in suites: - if upload_data: - echo.echo(f"\tdataget({len(suite.dataset)}) dataset({len(suite.dataset)}) {suite.path}") - else: - echo.echo(f"\tdataget({len(suite.dataset)}) {suite.path}") - if not yes and not click.confirm("running?"): - return - - _big_data_task(include, guest_data_size, host_data_size, guest_feature_num, host_feature_num, host_data_type, - config_inst, encryption_type, match_rate, sparsity, force, split_host, output_path, parallelize) - if upload_data: - if use_local_data: - _config.use_local_data = 0 - _config.data_switch = remove_data - client_upload(suites=suites, 
config_inst=config_inst, namespace=namespace, output_path=output_path) - - -@data_group.command("download") -@click.option("-t", "--type", type=click.Choice(["mnist"]), default="mnist", - help="config file") -@click.option('-o', '--output-path', type=click.Path(exists=True), - help="output path of mnist data, the default path is examples/data") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def download_mnists(ctx, output_path, **kwargs): - """ - download mnist data for flow - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if output_path is None: - config = get_config(config_inst) - output_path = str(config.data_base_dir) + "/examples/data/" - if not yes and not click.confirm("running?"): - return - try: - download_mnist(Path(output_path), "mnist_train") - download_mnist(Path(output_path), "mnist_eval", is_train=False) - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@data_group.command("query_schema") -@click.option('-cpn', '--component-name', required=False, type=str, help="component name", default='dataio_0') -@click.option('-j', '--job-id', required=True, type=str, help="job id") -@click.option('-r', '--role', required=True, type=click.Choice(["guest", "host", "arbiter"]), help="job id") -@click.option('-p', '--party-id', required=True, type=str, help="party id") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def query_schema(ctx, component_name, job_id, role, party_id, **kwargs): - """ - query the meta of the output data of a component - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - yes = ctx.obj["yes"] - config_inst = ctx.obj["config"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - with Clients(config_inst) as client: - query_component_output_data(client, config_inst, component_name, job_id, role, party_id) - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -def get_config(conf: Config): - return conf - - -def query_component_output_data(clients: Clients, config: Config, component_name, job_id, role, party_id): - roles = config.role - clients_role = None - for k, v in roles.items(): - if int(party_id) in v and k == role: - clients_role = role + "_" + str(v.index(int(party_id))) - try: - if clients_role is None: - raise ValueError(f"party id {party_id} does not exist") - - try: - table_info = clients[clients_role].output_data_table(job_id=job_id, role=role, party_id=party_id, - component_name=component_name) - table_info = clients[clients_role].table_info(table_name=table_info['name'], - namespace=table_info['namespace']) - except Exception as e: - raise RuntimeError(f"An exception occurred while getting data {clients_role}<-{component_name}") from e - - echo.echo("query_component_output_data result: {}".format(table_info)) - try: - header = table_info['data']['schema']['header'] - except ValueError as e: - raise ValueError(f"Obtain header from table error, error msg: {e}") - - result = [] - for idx, header_name in enumerate(header[1:]): - result.append((idx, header_name)) - 
echo.echo("Queried header is {}".format(result)) - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - - -def download_mnist(base, name, is_train=True): - import torchvision - - dataset = torchvision.datasets.MNIST( - root=base.joinpath(".cache"), train=is_train, download=True - ) - converted_path = base.joinpath(name) - converted_path.mkdir(exist_ok=True) - - inputs_path = converted_path.joinpath("images") - inputs_path.mkdir(exist_ok=True) - targets_path = converted_path.joinpath("targets") - config_path = converted_path.joinpath("config.yaml") - filenames_path = converted_path.joinpath("filenames") - - with filenames_path.open("w") as filenames: - with targets_path.open("w") as targets: - for idx, (img, target) in enumerate(dataset): - filename = f"{idx:05d}" - # save img - img.save(inputs_path.joinpath(f"{filename}.jpg")) - # save target - targets.write(f"{filename},{target}\n") - # save filenames - filenames.write(f"{filename}\n") - - config = { - "type": "vision", - "inputs": {"type": "images", "ext": "jpg", "PIL_mode": "L"}, - "targets": {"type": "integer"}, - } - with config_path.open("w") as f: - yaml.safe_dump(config, f, indent=2, default_flow_style=False) - - -def client_upload(suites, config_inst, namespace, output_path=None): - with Clients(config_inst) as client: - for i, suite in enumerate(suites): - # noinspection PyBroadException - try: - echo.echo(f"[{i + 1}/{len(suites)}]start at {time.strftime('%Y-%m-%d %X')} {suite.path}", fg='red') - try: - _upload_data(client, suite, config_inst, output_path) - except Exception as e: - raise RuntimeError(f"exception occur while uploading data for {suite.path}") from e - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception in {suite.path}, exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -def data_upload(clients: Clients, conf: Config, upload_config): - def _await_finish(job_id, task_name=None): - deadline = time.time() + sys.maxsize - start = time.time() - param = dict( - job_id=job_id, - role=None - ) - while True: - stdout = clients["guest_0"].flow_client("job/query", param) - status = stdout["data"][0]["f_status"] - elapse_seconds = int(time.time() - start) - date = time.strftime('%Y-%m-%d %X') - if task_name: - log_msg = f"[{date}][{task_name}]{status}, elapse: {timedelta(seconds=elapse_seconds)}" - else: - log_msg = f"[{date}]{job_id} {status}, elapse: {timedelta(seconds=elapse_seconds)}" - if (status == "running" or status == "waiting") and time.time() < deadline: - print(log_msg, end="\r") - time.sleep(1) - continue - else: - print(" " * 60, end="\r") # clean line - echo.echo(log_msg) - return status - - task_data = upload_config["data"] - for i, data in enumerate(task_data): - format_msg = f"@{data['file']} >> {data['namespace']}.{data['table_name']}" - echo.echo(f"[{time.strftime('%Y-%m-%d %X')}]uploading {format_msg}") - try: - data["file"] = str(os.path.join(conf.data_base_dir, data["file"])) - param = dict( - file=data["file"], - head=data["head"], - partition=data["partition"], - table_name=data["table_name"], - namespace=data["namespace"] - ) - stdout = clients["guest_0"].flow_client("data/upload", param, drop=1) - job_id = stdout.get('jobId', None) - echo.echo(f"[{time.strftime('%Y-%m-%d %X')}]upload done {format_msg}, 
job_id={job_id}\n") - if job_id is None: - echo.echo("table already exist. To upload again, Please add '-f 1' in start cmd") - continue - _await_finish(job_id) - param = dict( - table_name=data["table_name"], - namespace=data["namespace"] - ) - stdout = clients["guest_0"].flow_client("table/info", param) - - count = stdout["data"]["count"] - if count != data["count"]: - raise AssertionError("Count of upload file is not as expect, count is: {}," - "expect is: {}".format(count, data["count"])) - echo.echo(f"[{time.strftime('%Y-%m-%d %X')}] check_data_out {stdout} \n") - except Exception as e: - exception_id = uuid.uuid1() - echo.echo(f"exception in {data['file']}, exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - echo.echo(f"upload {i + 1}th data {data['table_name']} fail, exception_id: {exception_id}") - # raise RuntimeError(f"exception occur while uploading data for {data['file']}") from e - finally: - echo.stdout_newline() diff --git a/python/fate_test/fate_test/scripts/flow_test_cli.py b/python/fate_test/fate_test/scripts/flow_test_cli.py deleted file mode 100644 index fd34e4424b..0000000000 --- a/python/fate_test/fate_test/scripts/flow_test_cli.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import time -import uuid -import click -from datetime import timedelta -from pathlib import Path -from ruamel import yaml - -from fate_test._config import Config -from fate_test._io import LOGGER, echo -from fate_test.scripts._options import SharedOptions -from fate_test.flow_test import flow_rest_api, flow_sdk_api, flow_cli_api, flow_process - - -@click.group(name="flow-test") -def flow_group(): - """ - flow test - """ - ... - - -@flow_group.command("process") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def process(ctx, **kwargs): - """ - flow process test - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - try: - start = time.time() - flow_process.run_fate_flow_test(get_role(conf=config_inst)) - echo.echo(f"elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@flow_group.command("rest") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def api(ctx, **kwargs): - """ - flow rest api test - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - try: - start = time.time() - flow_rest_api.run_test_api(get_role(conf=config_inst), namespace) - echo.echo(f"elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@flow_group.command("sdk") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def api(ctx, **kwargs): - """ - flow sdk api test - """ - ctx.obj.update(**kwargs) - 
ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - try: - start = time.time() - flow_sdk_api.run_test_api(get_role(conf=config_inst), namespace) - echo.echo(f"elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@flow_group.command("cli") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def api(ctx, **kwargs): - """ - flow cli api test - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - try: - start = time.time() - flow_cli_api.run_test_api(get_role(conf=config_inst), namespace) - echo.echo(f"elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -def get_role(conf: Config): - flow_services = conf.serving_setting['flow_services'][0]['address'] - path = conf.flow_test_config_dir - if isinstance(path, str): - path = Path(path) - config = {} - if path is not None: - with path.open("r") as f: - config.update(yaml.safe_load(f)) - flow_test_template = config['flow_test_template'] - config_json = {'guest_party_id': conf.role['guest'], - 'host_party_id': [conf.role['host'][0]], - 'arbiter_party_id': conf.role['arbiter'], - 'online_serving': conf.serving_setting['serving_setting']['address'], - 'train_conf_path': os.path.abspath(conf.data_base_dir) + flow_test_template['train_conf_path'], - 'train_dsl_path': os.path.abspath(conf.data_base_dir) + flow_test_template['train_dsl_path'], - 'predict_conf_path': os.path.abspath(conf.data_base_dir) + flow_test_template['predict_conf_path'], - 'predict_dsl_path': os.path.abspath(conf.data_base_dir) + flow_test_template['predict_dsl_path'], - 'upload_file_path': os.path.abspath(conf.data_base_dir) + flow_test_template['upload_conf_path'], - 'model_file_path': os.path.abspath(conf.data_base_dir) + flow_test_template['model_conf_path'], - 'server_url': "http://{}/{}".format(flow_services, config['api_version']), - 'train_auc': config['train_auc'], - 'phone_num': config['phone_num'], - 'component_name': config['component_name'], - 'component_is_homo': config.get('component_is_homo', False), - 'serving_setting': conf.serving_setting['serving_setting']['address'], - 'metric_output_path': config['metric_output_path'], - 'model_output_path': config['model_output_path'], - 'cache_directory': conf.cache_directory, - 'data_base_dir': conf.data_base_dir - } - if flow_test_template.get('homo_deploy_path'): - config_json['homo_deploy_path'] = os.path.abspath(conf.data_base_dir) + flow_test_template['homo_deploy_path'] - if flow_test_template.get('homo_deploy_kube_config_path'): - config_json['homo_deploy_kube_config_path'] = os.path.abspath(conf.data_base_dir) + \ - flow_test_template['homo_deploy_kube_config_path'] - return config_json diff --git 
a/python/fate_test/fate_test/scripts/generate_mock_data.py b/python/fate_test/fate_test/scripts/generate_mock_data.py deleted file mode 100644 index 3170d8e981..0000000000 --- a/python/fate_test/fate_test/scripts/generate_mock_data.py +++ /dev/null @@ -1,345 +0,0 @@ -import hashlib -import json -import os -import random -import threading -import sys -import time -import uuid -import functools -import pandas as pd -import numpy as np - -from fate_test._config import Config -from fate_test._io import echo, LOGGER - - -def import_fate(): - from fate_arch import storage - from fate_flow.utils import data_utils - from fate_arch import session - from fate_arch.storage import StorageEngine - from fate_arch.common.conf_utils import get_base_config - from fate_arch.storage import EggRollStoreType - return storage, data_utils, session, StorageEngine, get_base_config, EggRollStoreType - - -storage, data_utils, session, StorageEngine, get_base_config, EggRollStoreType = import_fate() - - -sys.setrecursionlimit(1000000) - - -class data_progress: - def __init__(self, down_load, time_start): - self.time_start = time_start - self.down_load = down_load - self.time_percent = 0 - self.switch = True - - def set_switch(self, switch): - self.switch = switch - - def get_switch(self): - return self.switch - - def set_time_percent(self, time_percent): - self.time_percent = time_percent - - def get_time_percent(self): - return self.time_percent - - def progress(self, percent): - if percent > 100: - percent = 100 - end = time.time() - if percent != 100: - print(f"\r{self.down_load} %.f%s [%s] running" % (percent, '%', self.timer(end - self.time_start)), - flush=True, end='') - else: - print(f"\r{self.down_load} %.f%s [%s] success" % (percent, '%', self.timer(end - self.time_start)), - flush=True, end='') - - @staticmethod - def timer(times): - hours, rem = divmod(times, 3600) - minutes, seconds = divmod(rem, 60) - return "{:0>2}:{:0>2}:{:0>2}".format(int(hours), int(minutes), int(seconds)) - - -def remove_file(path): - os.remove(path) - - -def id_encryption(encryption_type, start_num, end_num): - if encryption_type == 'md5': - return [hashlib.md5(bytes(str(value), encoding='utf-8')).hexdigest() for value in range(start_num, end_num)] - elif encryption_type == 'sha256': - return [hashlib.sha256(bytes(str(value), encoding='utf-8')).hexdigest() for value in range(start_num, end_num)] - else: - return [str(value) for value in range(start_num, end_num)] - - -def get_big_data(guest_data_size, host_data_size, guest_feature_num, host_feature_num, include_path, host_data_type, - conf: Config, encryption_type, match_rate, sparsity, force, split_host, output_path, parallelize): - global big_data_dir - - def list_tag_value(feature_nums, head): - # data = '' - # for f in range(feature_nums): - # data += head[f] + ':' + str(round(np.random.randn(), 4)) + ";" - # return data[:-1] - return ";".join([head[k] + ':' + str(round(v, 4)) for k, v in enumerate(np.random.randn(feature_nums))]) - - def list_tag(feature_nums, data_list): - data = '' - for f in range(feature_nums): - data += random.choice(data_list) + ";" - return data[:-1] - - def _generate_tag_value_data(data_path, start_num, end_num, feature_nums, progress): - data_num = end_num - start_num - section_data_size = round(data_num / 100) - iteration = round(data_num / section_data_size) - head = ['x' + str(i) for i in range(feature_nums)] - for batch in range(iteration + 1): - progress.set_time_percent(batch) - output_data = pd.DataFrame(columns=["id"]) - if section_data_size 
* (batch + 1) <= data_num: - output_data["id"] = id_encryption(encryption_type, section_data_size * batch + start_num, - section_data_size * (batch + 1) + start_num) - slicing_data_size = section_data_size - elif section_data_size * batch < data_num: - output_data['id'] = id_encryption(encryption_type, section_data_size * batch + start_num, end_num) - slicing_data_size = data_num - section_data_size * batch - else: - break - feature = [list_tag_value(feature_nums, head) for i in range(slicing_data_size)] - output_data['feature'] = feature - output_data.to_csv(data_path, mode='a+', index=False, header=False) - - def _generate_dens_data(data_path, start_num, end_num, feature_nums, label_flag, progress): - if label_flag: - head_1 = ['id', 'y'] - else: - head_1 = ['id'] - data_num = end_num - start_num - head_2 = ['x' + str(i) for i in range(feature_nums)] - df_data_1 = pd.DataFrame(columns=head_1) - head_data = pd.DataFrame(columns=head_1 + head_2) - head_data.to_csv(data_path, mode='a+', index=False) - section_data_size = round(data_num / 100) - iteration = round(data_num / section_data_size) - for batch in range(iteration + 1): - progress.set_time_percent(batch) - if section_data_size * (batch + 1) <= data_num: - df_data_1["id"] = id_encryption(encryption_type, section_data_size * batch + start_num, - section_data_size * (batch + 1) + start_num) - slicing_data_size = section_data_size - elif section_data_size * batch < data_num: - df_data_1 = pd.DataFrame(columns=head_1) - df_data_1["id"] = id_encryption(encryption_type, section_data_size * batch + start_num, end_num) - slicing_data_size = data_num - section_data_size * batch - else: - break - if label_flag: - df_data_1["y"] = [round(np.random.random()) for x in range(slicing_data_size)] - feature = np.random.randint(-10000, 10000, size=[slicing_data_size, feature_nums]) / 10000 - df_data_2 = pd.DataFrame(feature, columns=head_2) - output_data = pd.concat([df_data_1, df_data_2], axis=1) - output_data.to_csv(data_path, mode='a+', index=False, header=False) - - def _generate_tag_data(data_path, start_num, end_num, feature_nums, sparsity, progress): - data_num = end_num - start_num - section_data_size = round(data_num / 100) - iteration = round(data_num / section_data_size) - valid_set = [x for x in range(2019120799, 2019120799 + round(feature_nums / sparsity))] - data = list(map(str, valid_set)) - for batch in range(iteration + 1): - progress.set_time_percent(batch) - output_data = pd.DataFrame(columns=["id"]) - if section_data_size * (batch + 1) <= data_num: - output_data["id"] = id_encryption(encryption_type, section_data_size * batch + start_num, - section_data_size * (batch + 1) + start_num) - slicing_data_size = section_data_size - elif section_data_size * batch < data_num: - output_data["id"] = id_encryption(encryption_type, section_data_size * batch + start_num, end_num) - slicing_data_size = data_num - section_data_size * batch - else: - break - feature = [list_tag(feature_nums, data_list=data) for i in range(slicing_data_size)] - output_data['feature'] = feature - output_data.to_csv(data_path, mode='a+', index=False, header=False) - - def _generate_parallelize_data(start_num, end_num, feature_nums, table_name, namespace, label_flag, data_type, - partition, progress): - def expand_id_range(k, v): - if label_flag: - return [(id_encryption(encryption_type, ids, ids + 1)[0], - ",".join([str(round(np.random.random()))] + [str(round(i, 4)) for i in np.random.randn(v)])) - for ids in range(int(k), min(step + int(k), end_num))] - else: - if 
data_type == 'tag': - valid_set = [x for x in range(2019120799, 2019120799 + round(feature_nums / sparsity))] - data = list(map(str, valid_set)) - return [(id_encryption(encryption_type, ids, ids + 1)[0], - ";".join([random.choice(data) for i in range(int(v))])) - for ids in range(int(k), min(step + int(k), data_num))] - - elif data_type == 'tag_value': - return [(id_encryption(encryption_type, ids, ids + 1)[0], - ";".join([f"x{i}" + ':' + str(round(i, 4)) for i in np.random.randn(v)])) - for ids in range(int(k), min(step + int(k), data_num))] - elif data_type == 'dense': - return [(id_encryption(encryption_type, ids, ids + 1)[0], - ",".join([str(round(i, 4)) for i in np.random.randn(v)])) - for ids in range(int(k), min(step + int(k), data_num))] - data_num = end_num - start_num - step = 10000 if data_num > 10000 else int(data_num / 10) - table_list = [(f"{i * step}", f"{feature_nums}") for i in range(int(data_num / step) + start_num)] - table = sess.computing.parallelize(table_list, partition=partition, include_key=True) - table = table.flatMap(functools.partial(expand_id_range)) - if label_flag: - schema = {"sid": "id", "header": ",".join(["y"] + [f"x{i}" for i in range(feature_nums)])} - else: - schema = {"sid": "id", "header": ",".join([f"x{i}" for i in range(feature_nums)])} - if data_type != "dense": - schema = None - - h_table = sess.get_table(name=table_name, namespace=namespace) - if h_table: - h_table.destroy() - - table_meta = sess.persistent(computing_table=table, name=table_name, namespace=namespace, schema=schema) - - storage_session = sess.storage() - s_table = storage_session.get_table(namespace=table_meta.get_namespace(), name=table_meta.get_name()) - if s_table.count() == data_num: - progress.set_time_percent(100) - from fate_flow.manager.data_manager import DataTableTracker - DataTableTracker.create_table_tracker( - table_name=table_name, - table_namespace=namespace, - entity_info={} - ) - - def data_save(data_info, table_names, namespaces, partition_list): - data_count = 0 - for idx, data_name in enumerate(data_info.keys()): - label_flag = True if 'guest' in data_info[data_name] else False - data_type = 'dense' if 'guest' in data_info[data_name] else host_data_type - if split_host and ('host' in data_info[data_name]): - host_end_num = int(np.ceil(host_data_size / len(data_info))) * (data_count + 1) if np.ceil( - host_data_size / len(data_info)) * (data_count + 1) <= host_data_size else host_data_size - host_start_num = int(np.ceil(host_data_size / len(data_info))) * data_count - data_count += 1 - else: - host_end_num = host_data_size - host_start_num = 0 - out_path = os.path.join(str(big_data_dir), data_name) - if os.path.exists(out_path) and os.path.isfile(out_path) and not parallelize: - if force: - remove_file(out_path) - else: - echo.echo('{} Already exists'.format(out_path)) - continue - data_i = (idx + 1) / len(data_info) - downLoad = f'dataget [{"#" * int(24 * data_i)}{"-" * (24 - int(24 * data_i))}] {idx + 1}/{len(data_info)}' - start = time.time() - progress = data_progress(downLoad, start) - thread = threading.Thread(target=run, args=[progress]) - thread.start() - - try: - if 'guest' in data_info[data_name]: - if not parallelize: - _generate_dens_data(out_path, guest_start_num, guest_end_num, - guest_feature_num, label_flag, progress) - else: - _generate_parallelize_data( - guest_start_num, - guest_end_num, - guest_feature_num, - table_names[idx], - namespaces[idx], - label_flag, - data_type, - partition_list[idx], - progress) - else: - if data_type == 'tag' 
and not parallelize: - _generate_tag_data(out_path, host_start_num, host_end_num, host_feature_num, sparsity, progress) - elif data_type == 'tag_value' and not parallelize: - _generate_tag_value_data(out_path, host_start_num, host_end_num, host_feature_num, progress) - elif data_type == 'dense' and not parallelize: - _generate_dens_data(out_path, host_start_num, host_end_num, - host_feature_num, label_flag, progress) - elif parallelize: - _generate_parallelize_data( - host_start_num, - host_end_num, - host_feature_num, - table_names[idx], - namespaces[idx], - label_flag, - data_type, - partition_list[idx], - progress) - progress.set_switch(False) - time.sleep(1) - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - progress.set_switch(False) - echo.stdout_newline() - - def run(p): - while p.get_switch(): - time.sleep(1) - p.progress(p.get_time_percent()) - - if not match_rate > 0 or not match_rate <= 1: - raise Exception(f"The value is between (0-1), Please check match_rate:{match_rate}") - guest_start_num = host_data_size - int(guest_data_size * match_rate) - guest_end_num = guest_start_num + guest_data_size - - if os.path.isfile(include_path): - with include_path.open("r") as f: - testsuite_config = json.load(f) - else: - raise Exception(f'Input file error, please check{include_path}.') - try: - if output_path is not None: - big_data_dir = os.path.abspath(output_path) - else: - big_data_dir = os.path.abspath(conf.cache_directory) - except Exception: - raise Exception('{}path does not exist'.format(big_data_dir)) - date_set = {} - table_name_list = [] - table_namespace_list = [] - partition_list = [] - for upload_dict in testsuite_config.get('data'): - date_set[os.path.basename(upload_dict.get('file'))] = upload_dict.get('role') - table_name_list.append(upload_dict.get('table_name')) - table_namespace_list.append(upload_dict.get('namespace')) - partition_list.append(upload_dict.get('partition', 8)) - - if parallelize: - with session.Session() as sess: - session_id = str(uuid.uuid1()) - sess.init_computing(session_id) - data_save( - data_info=date_set, - table_names=table_name_list, - namespaces=table_namespace_list, - partition_list=partition_list) - else: - data_save( - data_info=date_set, - table_names=table_name_list, - namespaces=table_namespace_list, - partition_list=partition_list) - echo.echo(f'Data storage address, please check{big_data_dir}') diff --git a/python/fate_test/fate_test/scripts/op_test/__init__.py b/python/fate_test/fate_test/scripts/op_test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate_test/fate_test/scripts/op_test/fate_he_performance_test.py b/python/fate_test/fate_test/scripts/op_test/fate_he_performance_test.py deleted file mode 100644 index 315eb6c6d6..0000000000 --- a/python/fate_test/fate_test/scripts/op_test/fate_he_performance_test.py +++ /dev/null @@ -1,79 +0,0 @@ -import numpy as np -from prettytable import PrettyTable, ORGMODE -from fate_test.scripts.op_test.performance_assess import Metric -from operator import add, mul - - -class PaillierAssess(object): - def __init__(self, method, data_num, test_round): - from federatedml.secureprotol.fate_paillier import PaillierKeypair - self.public_key, self.private_key = PaillierKeypair.generate_keypair() - self.method = method - self.data_num = data_num - self.test_round = test_round - self.float_data_x, self.encrypt_float_data_x, self.int_data_x, 
self.encrypt_int_data_x = self._get_data() - self.float_data_y, self.encrypt_float_data_y, self.int_data_y, self.encrypt_int_data_y = self._get_data() - - def _get_data(self, type_int=True, type_float=True): - if self.method == "Paillier": - key = self.public_key - else: - key = None - encrypt_float_data = [] - encrypt_int_data = [] - float_data = np.random.uniform(-1e9, 1e9, size=self.data_num) - int_data = np.random.randint(-1000, 1000, size=self.data_num) - if type_float: - for i in float_data: - encrypt_float_data.append(key.encrypt(i)) - if type_int: - for i in int_data: - encrypt_int_data.append(key.encrypt(i)) - return float_data, encrypt_float_data, int_data, encrypt_int_data - - def output_table(self): - table = PrettyTable() - table.set_style(ORGMODE) - table.field_names = [self.method, "One time consumption", f"{self.data_num} times consumption", - "relative acc", "log2 acc", "operations per second", "plaintext consumption per second"] - - metric = Metric(data_num=self.data_num, test_round=self.test_round) - - table.add_row(metric.encrypt(self.float_data_x, self.public_key.encrypt)) - decrypt_data = [self.private_key.decrypt(i) for i in self.encrypt_float_data_x] - table.add_row(metric.decrypt(self.encrypt_float_data_x, self.float_data_x, decrypt_data, - self.private_key.decrypt)) - - real_data = list(map(add, self.float_data_x, self.float_data_y)) - encrypt_data = list(map(add, self.encrypt_float_data_x, self.encrypt_float_data_y)) - self.binary_op(table, metric, self.encrypt_float_data_x, self.encrypt_float_data_y, - self.int_data_x, self.int_data_y, real_data, encrypt_data, - add, "float add") - - real_data = list(map(add, self.int_data_x, self.int_data_y)) - encrypt_data = list(map(add, self.encrypt_int_data_x, self.encrypt_int_data_y)) - self.binary_op(table, metric, self.encrypt_int_data_x, self.encrypt_int_data_y, - self.int_data_x, self.int_data_y, real_data, encrypt_data, - add, "int add") - - real_data = list(map(mul, self.float_data_x, self.float_data_y)) - encrypt_data = list(map(mul, self.encrypt_float_data_x, self.float_data_y)) - self.binary_op(table, metric, self.encrypt_float_data_x, self.float_data_y, - self.float_data_x, self.float_data_y, real_data, encrypt_data, - mul, "float mul") - - real_data = list(map(mul, self.int_data_x, self.int_data_y)) - encrypt_data = list(map(mul, self.encrypt_int_data_x, self.int_data_y)) - self.binary_op(table, metric, self.encrypt_int_data_x, self.int_data_y, - self.float_data_x, self.float_data_y, real_data, encrypt_data, - mul, "int mul") - - return table.get_string(title=f"{self.method} Computational performance") - - def binary_op(self, table, metric, encrypt_data_x, encrypt_data_y, raw_data_x, raw_data_y, - real_data, encrypt_data, op, op_name): - decrypt_data = [self.private_key.decrypt(i) for i in encrypt_data] - table.add_row(metric.binary_op(encrypt_data_x, encrypt_data_y, - raw_data_x, raw_data_y, - real_data, decrypt_data, - op, op_name)) diff --git a/python/fate_test/fate_test/scripts/op_test/performance_assess.py b/python/fate_test/fate_test/scripts/op_test/performance_assess.py deleted file mode 100644 index 39c687bcca..0000000000 --- a/python/fate_test/fate_test/scripts/op_test/performance_assess.py +++ /dev/null @@ -1,85 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
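PaillierAssess above times the additively homomorphic operations the Paillier ciphertexts support: ciphertext + ciphertext decrypts to the sum of the plaintexts, and ciphertext * plaintext decrypts to their product, which is why binary_op maps add over two encrypted lists but maps mul over an encrypted list and a raw list. A short sketch of the property being measured, assuming the same federatedml package imported above is installed:

    from operator import add, mul

    from federatedml.secureprotol.fate_paillier import PaillierKeypair

    public_key, private_key = PaillierKeypair.generate_keypair()

    x, y = 3.5, -2.0
    ex, ey = public_key.encrypt(x), public_key.encrypt(y)

    # additive homomorphism: E(x) + E(y) decrypts to x + y
    assert abs(private_key.decrypt(add(ex, ey)) - (x + y)) < 1e-8

    # multiplication by a plaintext scalar: E(x) * y decrypts to x * y
    assert abs(private_key.decrypt(mul(ex, y)) - (x * y)) < 1e-8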
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import time -import numpy as np - - -# Operations -class Metric(object): - def __init__(self, data_num, test_round): - self.operation = None - self.data_num = data_num - self.test_round = test_round - - @staticmethod - def accuracy(rand_data, decrypt_data): - difference = 0 - for x, y in zip(rand_data, decrypt_data): - difference += abs(abs(x) - abs(y)) - abs_acc = abs(difference) / len(rand_data) - difference = 0 - for x, y in zip(rand_data, decrypt_data): - difference += abs(abs(x) - abs(y)) / (1e-100 + max(abs(x), abs(y))) - relative_acc = difference / len(rand_data) - log_acc = -np.log2(relative_acc) if relative_acc != 0 else 0 - - return abs_acc, relative_acc, log_acc - - @staticmethod - def many_call(data_x, unary_op=None, binary_op=None, data_y=None, test_round=1): - if unary_op is not None: - time_start = time.time() - for _ in range(test_round): - _ = list(map(unary_op, data_x)) - final_time = time.time() - time_start - else: - time_start = time.time() - for _ in range(test_round): - _ = list(map(binary_op, data_x, data_y)) - final_time = time.time() - time_start - - return final_time / test_round - - def encrypt(self, data, op): - many_round_encrypt_time = self.many_call(data, unary_op=op, test_round=self.test_round) - single_encrypt_time = many_round_encrypt_time / self.data_num - cals_per_second = self.data_num / many_round_encrypt_time - - return ["encrypt", '%.10f' % single_encrypt_time + 's', '%.10f' % many_round_encrypt_time + 's', "-", "-", - int(cals_per_second), "-"] - - def decrypt(self, encrypt_data, data, decrypt_data, function): - many_round_decrypt_time = self.many_call(encrypt_data, function, test_round=self.test_round) - single_decrypt_time = many_round_decrypt_time / self.data_num - cals_per_second = self.data_num / many_round_decrypt_time - abs_acc, relative_acc, log_acc = self.accuracy(data, decrypt_data) - return ["decrypt", '%.10f' % single_decrypt_time + 's', '%.10f' % many_round_decrypt_time + 's', - relative_acc, log_acc, int(cals_per_second), "-"] - - def binary_op(self, encrypt_data_x, encrypt_data_y, - raw_data_x, raw_data_y, real_ret, decrypt_ret, op, op_name): - many_round_time = self.many_call(data_x=encrypt_data_x, binary_op=op, - data_y=encrypt_data_y, test_round=self.test_round) - single_op_time = many_round_time / self.data_num - cals_per_second = self.data_num / many_round_time - - plaintext_per_second = self.data_num / self.many_call(data_x=raw_data_x, data_y=raw_data_y, - binary_op=op, test_round=self.test_round) - - abs_acc, relative_acc, log_acc = self.accuracy(real_ret, decrypt_ret) - return [op_name, '%.10f' % single_op_time + 's', '%.10f' % many_round_time + 's', - relative_acc, log_acc, int(cals_per_second), int(plaintext_per_second)] diff --git a/python/fate_test/fate_test/scripts/op_test/spdz_conf/__init__.py b/python/fate_test/fate_test/scripts/op_test/spdz_conf/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_conf.json b/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_conf.json deleted file mode 100644 index 
8f10af7513..0000000000 --- a/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_conf.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "dsl_version": 2, - "initiator": { - "role": "guest", - "party_id": 9999 - }, - "role": { - "guest": [ - 9999 - ], - "host": [ - 10000 - ] - }, - "component_parameters": { - "common": { - "spdz_test_0": { - "seed": 1, - "data_partition": 4, - "data_lower_bound": -1024, - "data_upper_bound": 1024, - "test_round": 1, - "data_num": 10000 - } - } - } -} diff --git a/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_dsl.json b/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_dsl.json deleted file mode 100644 index 5bfb9d4e83..0000000000 --- a/python/fate_test/fate_test/scripts/op_test/spdz_conf/job_dsl.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "components": { - "spdz_test_0": { - "module": "SPDZTest" - } - } -} diff --git a/python/fate_test/fate_test/scripts/op_test/spdz_test.py b/python/fate_test/fate_test/scripts/op_test/spdz_test.py deleted file mode 100644 index 12cf8cd735..0000000000 --- a/python/fate_test/fate_test/scripts/op_test/spdz_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import json -import time -from prettytable import PrettyTable, ORGMODE -from flow_sdk.client import FlowClient - - -class SPDZTest(object): - def __init__(self, flow_address, params, conf_path, dsl_path, guest_party_id, host_party_id): - self.client = FlowClient(ip=flow_address.split(":")[0], - port=flow_address.split(":")[1], - version="v1") - - self.dsl = self._get_json_file(dsl_path) - self.conf = self._get_json_file(conf_path) - self.conf["role"] = dict(guest=guest_party_id, host=host_party_id) - self.conf["component_parameters"]["common"]["spdz_test_0"].update(params) - self.conf["initiator"]["party_id"] = guest_party_id[0] - self.guest_party_id = guest_party_id[0] - - @staticmethod - def _get_json_file(path): - with open(path, "r") as fin: - ret = json.loads(fin.read()) - - return ret - - def run(self): - result = self.client.job.submit(config_data=self.conf, dsl_data=self.dsl) - - try: - if 'retcode' not in result or result["retcode"] != 0: - raise ValueError(f"retcode err") - - if "jobId" not in result: - raise ValueError(f"jobID not in result: {result}") - - job_id = result["jobId"] - except ValueError: - raise ValueError("job submit failed, err msg: {}".format(result)) - - while True: - info = self.client.job.query(job_id=job_id, role="guest", party_id=self.guest_party_id) - data = info["data"][0] - status = data["f_status"] - if status == "success": - break - elif status == "failed": - raise ValueError(f"job is failed, jobid is {job_id}") - - time.sleep(1) - - summary = self.client.component.get_summary(job_id=job_id, role="guest", - party_id=self.guest_party_id, - component_name="spdz_test_0") - - summary = summary["data"] - field_name = summary["field_name"] - - tables = [] - for tensor_type in summary["tensor_type"]: - table = PrettyTable() - 
table.set_style(ORGMODE) - table.field_names = field_name - for op_type in summary["op_test_list"]: - table.add_row(summary[tensor_type][op_type]) - - tables.append(table.get_string(title=f"SPDZ {tensor_type} Computational performance")) - - return tables diff --git a/python/fate_test/fate_test/scripts/performance_cli.py b/python/fate_test/fate_test/scripts/performance_cli.py deleted file mode 100644 index 6cf097413a..0000000000 --- a/python/fate_test/fate_test/scripts/performance_cli.py +++ /dev/null @@ -1,365 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import json -import os -import time -import uuid -from datetime import timedelta -import click -import glob - -from fate_test import _config -from fate_test._client import Clients -from fate_test._config import Config -from fate_test.utils import TxtStyle -from fate_test._flow_client import JobProgress, SubmitJobResponse, QueryJobResponse -from fate_test._io import LOGGER, echo -from prettytable import PrettyTable, ORGMODE -from fate_test._parser import JSON_STRING, Testsuite -from fate_test.scripts._options import SharedOptions -from fate_test.scripts._utils import _load_testsuites, _upload_data, _delete_data, _load_module_from_script, \ - _add_replace_hook - - -@click.command("performance") -@click.option('-t', '--job-type', type=click.Choice(['intersect', 'intersect_multi', 'hetero_lr', 'hetero_sbt']), - help="Select the job type, you can also set through include") -@click.option('-i', '--include', type=click.Path(exists=True), multiple=True, metavar="", - help="include *testsuite.json under these paths") -@click.option('-r', '--replace', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in data/conf/dsl") -@click.option('-m', '--timeout', type=int, default=3600, - help="maximun running time of job") -@click.option('-e', '--max-iter', type=int, help="When the algorithm model is LR, the number of iterations is set") -@click.option('-d', '--max-depth', type=int, - help="When the algorithm model is SecureBoost, set the number of model layers") -@click.option('-nt', '--num-trees', type=int, help="When the algorithm model is SecureBoost, set the number of trees") -@click.option('-p', '--task-cores', type=int, help="processors per node") -@click.option('-uj', '--update-job-parameters', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in conf.job_parameters") -@click.option('-uc', '--update-component-parameters', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in conf.component_parameters") -@click.option('-s', '--storage-tag', type=str, - help="tag for storing performance time consuming, for future comparison") -@click.option('-v', '--history-tag', type=str, multiple=True, - help="Extract performance time consuming from history tags for comparison") -@click.option("--skip-data", is_flag=True, default=False, - help="skip uploading data 
specified in testsuite") -@click.option("--provider", type=str, - help="Select the fate version, for example: fate@1.7") -@click.option("--disable-clean-data", "clean_data", flag_value=False, default=None) -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def run_task(ctx, job_type, include, replace, timeout, update_job_parameters, update_component_parameters, max_iter, - max_depth, num_trees, task_cores, storage_tag, history_tag, skip_data, clean_data, provider, **kwargs): - """ - Test the performance of big data tasks, alias: bp - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - config_inst = ctx.obj["config"] - config_inst.extend_sid = ctx.obj["extend_sid"] - config_inst.auto_increasing_sid = ctx.obj["auto_increasing_sid"] - namespace = ctx.obj["namespace"] - yes = ctx.obj["yes"] - data_namespace_mangling = ctx.obj["namespace_mangling"] - if clean_data is None: - clean_data = config_inst.clean_data - - def get_perf_template(conf: Config, job_type): - perf_dir = os.path.join(os.path.abspath(conf.perf_template_dir) + '/' + job_type + '/' + "*testsuite.json") - return glob.glob(perf_dir) - - if not include: - include = get_perf_template(config_inst, job_type) - # prepare output dir and json hooks - _add_replace_hook(replace) - - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo("loading testsuites:") - suites = _load_testsuites(includes=include, excludes=tuple(), glob=None, provider=provider) - for i, suite in enumerate(suites): - echo.echo(f"\tdataset({len(suite.dataset)}) dsl jobs({len(suite.jobs)}) {suite.path}") - - if not yes and not click.confirm("running?"): - return - - echo.stdout_newline() - with Clients(config_inst) as client: - - for i, suite in enumerate(suites): - # noinspection PyBroadException - try: - start = time.time() - echo.echo(f"[{i + 1}/{len(suites)}]start at {time.strftime('%Y-%m-%d %X')} {suite.path}", fg='red') - - if not skip_data: - try: - _upload_data(client, suite, config_inst) - except Exception as e: - raise RuntimeError(f"exception occur while uploading data for {suite.path}") from e - - echo.stdout_newline() - try: - time_consuming = _submit_job(client, suite, namespace, config_inst, timeout, update_job_parameters, - storage_tag, history_tag, update_component_parameters, max_iter, - max_depth, num_trees, task_cores) - except Exception as e: - raise RuntimeError(f"exception occur while submit job for {suite.path}") from e - - try: - _run_pipeline_jobs(config_inst, suite, namespace, data_namespace_mangling) - except Exception as e: - raise RuntimeError(f"exception occur while running pipeline jobs for {suite.path}") from e - - echo.echo(f"[{i + 1}/{len(suites)}]elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - if not skip_data and clean_data: - _delete_data(client, suite) - echo.echo(suite.pretty_final_summary(time_consuming), fg='red') - - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception in {suite.path}, exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -def _submit_job(clients: Clients, suite: Testsuite, namespace: str, config: Config, timeout, update_job_parameters, - storage_tag, history_tag, update_component_parameters, max_iter, max_depth, num_trees, task_cores): - # submit jobs - with click.progressbar(length=len(suite.jobs), - label="jobs", - show_eta=False, - show_pos=True, - width=24) as bar: - 
time_list = [] - for job in suite.jobs_iter(): - start = time.time() - job_progress = JobProgress(job.job_name) - - def _raise(): - exception_id = str(uuid.uuid1()) - job_progress.exception(exception_id) - suite.update_status(job_name=job.job_name, exception_id=exception_id) - echo.file(f"exception({exception_id})") - LOGGER.exception(f"exception id: {exception_id}") - - # noinspection PyBroadException - try: - if max_iter is not None: - job.job_conf.update_component_parameters('max_iter', max_iter) - if max_depth is not None: - job.job_conf.update_component_parameters('max_depth', max_depth) - if num_trees is not None: - job.job_conf.update_component_parameters('num_trees', num_trees) - if task_cores is not None: - job.job_conf.update_job_common_parameters(task_cores=task_cores) - job.job_conf.update(config.parties, timeout, update_job_parameters, update_component_parameters) - except Exception: - _raise() - continue - - def update_bar(n_step): - bar.item_show_func = lambda x: job_progress.show() - time.sleep(0.1) - bar.update(n_step) - - update_bar(1) - - def _call_back(resp: SubmitJobResponse): - if isinstance(resp, SubmitJobResponse): - job_progress.submitted(resp.job_id) - echo.file(f"[jobs] {resp.job_id} ", nl=False) - suite.update_status(job_name=job.job_name, job_id=resp.job_id) - - if isinstance(resp, QueryJobResponse): - job_progress.running(resp.status, resp.progress) - - update_bar(0) - - # noinspection PyBroadException - try: - response = clients["guest_0"].submit_job(job=job, callback=_call_back) - - # noinspection PyBroadException - try: - # add notes - notes = f"{job.job_name}@{suite.path}@{namespace}" - for role, party_id_list in job.job_conf.role.items(): - for i, party_id in enumerate(party_id_list): - clients[f"{role}_{i}"].add_notes(job_id=response.job_id, role=role, party_id=party_id, - notes=notes) - except Exception: - pass - except Exception: - _raise() - else: - job_progress.final(response.status) - suite.update_status(job_name=job.job_name, status=response.status.status) - if response.status.is_success(): - if suite.model_in_dep(job.job_name): - dependent_jobs = suite.get_dependent_jobs(job.job_name) - for predict_job in dependent_jobs: - model_info, table_info, cache_info, model_loader_info = None, None, None, None - for i in _config.deps_alter[predict_job.job_name]: - if isinstance(i, dict): - name = i.get('name') - data_pre = i.get('data') - - if 'data_deps' in _config.deps_alter[predict_job.job_name]: - roles = list(data_pre.keys()) - table_info, hierarchy = [], [] - for role_ in roles: - role, index = role_.split("_") - input_ = data_pre[role_] - for data_input, cpn in input_.items(): - try: - table_name = clients["guest_0"].output_data_table( - job_id=response.job_id, - role=role, - party_id=config.role[role][int(index)], - component_name=cpn) - except Exception: - _raise() - if predict_job.job_conf.dsl_version == 2: - hierarchy.append([role, index, data_input]) - table_info.append({'table': table_name}) - else: - hierarchy.append([role, 'args', 'data']) - table_info.append({data_input: [table_name]}) - table_info = {'hierarchy': hierarchy, 'table_info': table_info} - if 'model_deps' in _config.deps_alter[predict_job.job_name]: - if predict_job.job_conf.dsl_version == 2: - # noinspection PyBroadException - try: - model_info = clients["guest_0"].deploy_model( - model_id=response.model_info["model_id"], - model_version=response.model_info["model_version"], - dsl=predict_job.job_dsl.as_dict()) - except Exception: - _raise() - else: - model_info = 
response.model_info - if 'cache_deps' in _config.deps_alter[predict_job.job_name]: - cache_dsl = predict_job.job_dsl.as_dict() - cache_info = [] - for cpn in cache_dsl.get("components").keys(): - if "CacheLoader" in cache_dsl.get("components").get(cpn).get("module"): - cache_info.append({cpn: {'job_id': response.job_id}}) - cache_info = {'hierarchy': [""], 'cache_info': cache_info} - if 'model_loader_deps' in _config.deps_alter[predict_job.job_name]: - model_loader_dsl = predict_job.job_dsl.as_dict() - model_loader_info = [] - for cpn in model_loader_dsl.get("components").keys(): - if "ModelLoader" in model_loader_dsl.get("components").get(cpn).get("module"): - model_loader_info.append({cpn: response.model_info}) - model_loader_info = {'hierarchy': [""], 'model_loader_info': model_loader_info} - - suite.feed_dep_info(predict_job, name, model_info=model_info, table_info=table_info, - cache_info=cache_info, model_loader_info=model_loader_info) - suite.remove_dependency(job.job_name) - update_bar(0) - time_consuming = time.time() - start - performance_dir = "/".join( - [os.path.join(os.path.abspath(config.cache_directory), 'benchmark_history', "performance.json")]) - fate_version = clients["guest_0"].get_version() - if history_tag: - history_tag = ["_".join([i, job.job_name]) for i in history_tag] - comparison_quality(job.job_name, history_tag, performance_dir, time_consuming) - if storage_tag: - storage_tag = "_".join(['FATE', fate_version, storage_tag, job.job_name]) - save_quality(storage_tag, performance_dir, time_consuming) - echo.stdout_newline() - time_list.append(time_consuming) - return [str(int(i)) + "s" for i in time_list] - - -def _run_pipeline_jobs(config: Config, suite: Testsuite, namespace: str, data_namespace_mangling: bool): - # pipeline demo goes here - job_n = len(suite.pipeline_jobs) - for i, pipeline_job in enumerate(suite.pipeline_jobs): - echo.echo(f"Running [{i + 1}/{job_n}] job: {pipeline_job.job_name}") - - def _raise(err_msg, status="failed"): - exception_id = str(uuid.uuid1()) - suite.update_status(job_name=job_name, exception_id=exception_id, status=status) - echo.file(f"exception({exception_id}), error message:\n{err_msg}") - # LOGGER.exception(f"exception id: {exception_id}") - - job_name, script_path = pipeline_job.job_name, pipeline_job.script_path - mod = _load_module_from_script(script_path) - try: - if data_namespace_mangling: - try: - mod.main(config=config, namespace=f"_{namespace}") - suite.update_status(job_name=job_name, status="success") - except Exception as e: - _raise(e) - continue - else: - try: - mod.main(config=config) - suite.update_status(job_name=job_name, status="success") - except Exception as e: - _raise(e) - continue - except Exception as e: - _raise(e, status="not submitted") - continue - - -def comparison_quality(group_name, history_tags, history_info_dir, time_consuming): - assert os.path.exists(history_info_dir), f"Please check the {history_info_dir} Is it deleted" - with open(history_info_dir, 'r') as f: - benchmark_quality = json.load(f, object_hook=dict) - benchmark_performance = {} - for history_tag in history_tags: - for tag in benchmark_quality: - if '_'.join(tag.split("_")[2:]) == history_tag: - benchmark_performance[tag] = benchmark_quality[tag] - if benchmark_performance is not None: - benchmark_performance[group_name] = time_consuming - - table = PrettyTable() - table.set_style(ORGMODE) - table.field_names = ["Script Model Name", "time consuming"] - for script_model_name in benchmark_performance: - 
table.add_row([f"{script_model_name}"] + - [f"{TxtStyle.FIELD_VAL}{benchmark_performance[script_model_name]}{TxtStyle.END}"]) - print("\n") - print(table.get_string(title=f"{TxtStyle.TITLE}Performance comparison results{TxtStyle.END}")) - print("#" * 60) - - -def save_quality(storage_tag, save_dir, time_consuming): - os.makedirs(os.path.dirname(save_dir), exist_ok=True) - if os.path.exists(save_dir): - with open(save_dir, 'r') as f: - benchmark_quality = json.load(f, object_hook=dict) - else: - benchmark_quality = {} - benchmark_quality.update({storage_tag: time_consuming}) - try: - with open(save_dir, 'w') as fp: - json.dump(benchmark_quality, fp, indent=2) - print("\n" + "Storage successful, please check: ", save_dir) - except Exception: - print("\n" + "Storage failed, please check: ", save_dir) diff --git a/python/fate_test/fate_test/scripts/pipeline_conversion_cli.py b/python/fate_test/fate_test/scripts/pipeline_conversion_cli.py deleted file mode 100644 index 1c3521b67d..0000000000 --- a/python/fate_test/fate_test/scripts/pipeline_conversion_cli.py +++ /dev/null @@ -1,361 +0,0 @@ -import copy -import os -import shutil -import sys -import time -import uuid -import json -import click -import importlib - -from fate_test._config import Config -from fate_test._io import LOGGER, echo -from fate_test.scripts._options import SharedOptions - - -@click.group(name="convert") -def convert_group(): - """ - Converting pipeline files to dsl v2 - """ - ... - - -@convert_group.command("pipeline-to-dsl") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), multiple=True, metavar="", - help="include *pipeline.py under these paths") -@click.option('-o', '--output-path', type=click.Path(exists=True), help="DSL output path, default to *pipeline.py path") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def to_dsl(ctx, include, output_path, **kwargs): - """ - This command will run pipeline, make sure data is uploaded - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"converting namespace: {namespace}", fg='red') - for path in include: - echo.echo(f"pipeline path: {os.path.abspath(path)}") - if not yes and not click.confirm("running?"): - return - config_yaml_file = './examples/config.yaml' - temp_file_path = f'./logs/{namespace}/temp_pipeline.py' - - for i in include: - try: - convert(i, temp_file_path, config_yaml_file, output_path, config_inst) - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - echo.farewell() - echo.echo(f"converting namespace: {namespace}", fg='red') - - -@convert_group.command("pipeline-testsuite-to-dsl-testsuite") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), metavar="", - help="include is the pipeline test folder containing *testsuite.py") -@click.option('-t', '--template-path', required=False, type=click.Path(exists=True), metavar="", - help="specify the test template to use") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def to_testsuite(ctx, include, template_path, **kwargs): - """ - convert pipeline testsuite to dsl testsuite - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - echo.welcome() - if not 
os.path.isdir(include): - raise Exception("Please fill in a folder.") - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo(f"pipeline path: {os.path.abspath(include)}") - if not yes and not click.confirm("running?"): - return - input_path = os.path.abspath(include) - input_list = [input_path] - i = 0 - while i < len(input_list): - dirs = os.listdir(input_list[i]) - for d in dirs: - if os.path.isdir(d): - input_list.append(d) - i += 1 - - for file_path in input_list: - try: - module_name = os.path.basename(file_path) - do_generated(file_path, module_name, template_path, config_inst) - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - echo.farewell() - echo.echo(f"converting namespace: {namespace}", fg='red') - - -def make_temp_pipeline(pipeline_file, temp_file_path, folder_name): - def _conf_file_update(_line, k, end, conf_file=None): - if ")" in _line[0]: - if conf_file is None: - conf_file = os.path.abspath(folder_name + "/" + _line[0].replace("'", "").replace('"', ""). - replace(")", "").replace(":", "").replace("\n", "")) - _line = k + conf_file + end - else: - if conf_file is None: - conf_file = os.path.abspath(folder_name + "/" + _line[0].replace('"', "")) - _line = k + conf_file + '",' + _line[-1] - - return conf_file, _line - - def _get_conf_file(_lines): - param_default = False - conf_file = None - for _line in _lines: - if "--param" in _line or param_default: - if "default" in _line: - _line_start = _line.split("default=") - _line_end = _line_start[1].split(",") - conf_file, _ = _conf_file_update(_line_end, 'default="', '")') - param_default = False - else: - param_default = True - return conf_file - - code_list = [] - with open(pipeline_file, 'r') as f: - lines = f.readlines() - start_main = False - has_returned = False - space_num = 0 - conf_file_dir = _get_conf_file(lines) - for line in lines: - if line is None: - continue - elif "def main" in line: - for char in line: - if char.isspace(): - space_num += 1 - else: - break - start_main = True - if "param=" in line: - line_start = line.split("param=") - line_end = line_start[1].split(",") - conf_file_dir, line = _conf_file_update(line_end, 'param="', '")', conf_file_dir) - line = line_start[0] + line - elif start_main and "def " in line and not has_returned: - code_list.append(" " * (space_num + 4) + "return pipeline\n") - start_main = False - elif start_main and "return " in line: - code_list.append(" " * (space_num + 4) + "return pipeline\n") - start_main = False - continue - elif start_main and 'if __name__ ==' in line: - code_list.append(" " * (space_num + 4) + "return pipeline\n") - start_main = False - code_list.append(line) - if start_main: - code_list.append(" " * (space_num + 4) + "return pipeline\n") - - with open(temp_file_path, 'w') as f: - f.writelines(code_list) - - -def convert(pipeline_file, temp_file_path, config_yaml_file, output_path, config: Config): - folder_name, file_name = os.path.split(pipeline_file) - if output_path is not None: - folder_name = output_path - echo.echo(f"folder_name: {os.path.abspath(folder_name)}, file_name: {file_name}") - conf_name = file_name.replace('.py', '_conf.json') - dsl_name = file_name.replace('.py', '_dsl.json') - conf_name = os.path.join(folder_name, conf_name) - dsl_name = os.path.join(folder_name, dsl_name) - - make_temp_pipeline(pipeline_file, temp_file_path, folder_name) - additional_path = 
os.path.realpath(os.path.join(os.path.curdir, pipeline_file, os.pardir, os.pardir)) - if additional_path not in sys.path: - sys.path.append(additional_path) - loader = importlib.machinery.SourceFileLoader("main", str(temp_file_path)) - spec = importlib.util.spec_from_loader(loader.name, loader) - mod = importlib.util.module_from_spec(spec) - loader.exec_module(mod) - my_pipeline = mod.main(os.path.join(config.data_base_dir, config_yaml_file)) - conf = my_pipeline.get_train_conf() - dsl = my_pipeline.get_train_dsl() - os.remove(temp_file_path) - - with open(conf_name, 'w') as f: - json.dump(conf, f, indent=4) - echo.echo('conf name is {}'.format(os.path.abspath(conf_name))) - with open(dsl_name, 'w') as f: - json.dump(dsl, f, indent=4) - echo.echo('dsl name is {}'.format(os.path.abspath(dsl_name))) - - -def insert_extract_code(file_path): - code_lines = [] - code = \ - """ -import json -import os -def extract(my_pipeline, file_name, output_path='dsl_testsuite'): - out_name = file_name.split('/')[-1] - out_name = out_name.replace('pipeline-', '').replace('.py', '').replace('-', '_') - conf = my_pipeline.get_train_conf() - dsl = my_pipeline.get_train_dsl() - cur_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - conf_name = os.path.join(cur_dir, output_path, f"{out_name}_conf.json") - dsl_name = os.path.join(cur_dir, output_path, f"{out_name}_dsl.json") - json.dump(conf, open(conf_name, 'w'), indent=4) - json.dump(dsl, open(dsl_name, 'w'), indent=4) - """ - - code_lines.append(code) - screen_keywords = [".predict(", ".fit(", ".deploy_component(", "predict_pipeline ", - "predict_pipeline."] - continue_to_screen = False - has_return = False - - with open(file_path, 'r') as f: - lines = f.readlines() - for l in lines: - if ".predict(" in l or ".fit(" in l: - code_lines.append(f"# {l}") - - elif 'if __name__ == "__main__":' in l: - if not has_return: - code_lines.append(" extract(pipeline, __file__)\n") - code_lines.append(l) - - elif 'return' in l: - code_lines.append(" extract(pipeline, __file__)\n") - # code_lines.append(l) - has_return = True - - elif "get_summary()" in l: - continue - elif continue_to_screen: - code_lines.append(f"# {l}") - if ")" in l: - continue_to_screen = False - else: - should_append = True - for key_word in screen_keywords: - if key_word in l: - code_lines.append(f"# {l}") - should_append = False - if ")" not in l: - continue_to_screen = True - if should_append: - code_lines.append(l) - - return code_lines - - -def get_testsuite_file(testsuite_file_path): - echo.echo(f"testsuite_file_path: {testsuite_file_path}") - with open(testsuite_file_path, 'r', encoding='utf-8') as load_f: - testsuite_json = json.load(load_f) - if "tasks" in testsuite_json: - del testsuite_json["tasks"] - if "pipeline_tasks" in testsuite_json: - del testsuite_json["pipeline_tasks"] - return testsuite_json - - -def do_generated(file_path, fold_name, template_path, config: Config): - yaml_file = os.path.join(config.data_base_dir, "./examples/config.yaml") - PYTHONPATH = os.environ.get('PYTHONPATH') + ":" + str(config.data_base_dir) - os.environ['PYTHONPATH'] = PYTHONPATH - if not os.path.isdir(file_path): - return - files = os.listdir(file_path) - if template_path is None: - for f in files: - if "testsuite" in f and "generated_testsuite" not in f: - template_path = os.path.join(file_path, f) - break - if template_path is None: - return - - suite_json = get_testsuite_file(template_path) - pipeline_suite = copy.deepcopy(suite_json) - suite_json["tasks"] = {} - 
pipeline_suite["pipeline_tasks"] = {} - replaced_path = os.path.join(file_path, 'replaced_code') - generated_path = os.path.join(file_path, 'dsl_testsuite') - - if not os.path.exists(replaced_path): - os.system('mkdir {}'.format(replaced_path)) - - if not os.path.exists(generated_path): - os.system('mkdir {}'.format(generated_path)) - - for f in files: - if not f.startswith("pipeline"): - continue - echo.echo(f) - task_name = f.replace(".py", "") - task_name = "-".join(task_name.split('-')[1:]) - pipeline_suite["pipeline_tasks"][task_name] = { - "script": f - } - f_path = os.path.join(file_path, f) - code_str = insert_extract_code(f_path) - pipeline_file_path = os.path.join(replaced_path, f) - open(pipeline_file_path, 'w').writelines(code_str) - - exe_files = os.listdir(replaced_path) - fail_job_count = 0 - task_type_list = [] - exe_conf_file = None - exe_dsl_file = None - for i, f in enumerate(exe_files): - abs_file = os.path.join(replaced_path, f) - echo.echo('\n' + '[{}/{}] executing {}'.format(i + 1, len(exe_files), abs_file), fg='red') - result = os.system(f"python {abs_file} -config {yaml_file}") - if not result: - time.sleep(3) - conf_files = os.listdir(generated_path) - f_dsl = {"_".join(f.split('_')[:-1]): f for f in conf_files if 'dsl.json' in f} - f_conf = {"_".join(f.split('_')[:-1]): f for f in conf_files if 'conf.json' in f} - for task_type, dsl_file in f_dsl.items(): - if task_type not in task_type_list: - exe_dsl_file = dsl_file - task_type_list.append(task_type) - exe_conf_file = f_conf[task_type] - suite_json['tasks'][task_type] = { - "conf": exe_conf_file, - "dsl": exe_dsl_file - } - echo.echo('conf name is {}'.format(os.path.join(file_path, "dsl_testsuite", exe_conf_file))) - echo.echo('dsl name is {}'.format(os.path.join(file_path, "dsl_testsuite", exe_dsl_file))) - else: - echo.echo('profile generation failed') - fail_job_count += 1 - - suite_path = os.path.join(generated_path, f"{fold_name}_testsuite.json") - with open(suite_path, 'w', encoding='utf-8') as json_file: - json.dump(suite_json, json_file, ensure_ascii=False, indent=4) - - suite_path = os.path.join(file_path, f"{fold_name}_pipeline_testsuite.json") - with open(suite_path, 'w', encoding='utf-8') as json_file: - json.dump(pipeline_suite, json_file, ensure_ascii=False, indent=4) - - shutil.rmtree(replaced_path) - if not fail_job_count: - echo.echo("Generate testsuite and dsl&conf finished!") - else: - echo.echo("Generate testsuite and dsl&conf finished! {} failures".format(fail_job_count)) diff --git a/python/fate_test/fate_test/scripts/quick_test_cli.py b/python/fate_test/fate_test/scripts/quick_test_cli.py deleted file mode 100644 index 61caec5e5c..0000000000 --- a/python/fate_test/fate_test/scripts/quick_test_cli.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import subprocess -import click -from fate_test._config import Config -from fate_test._io import LOGGER, echo -from fate_test.scripts._options import SharedOptions - - -@click.group(name="unittest") -def unittest_group(): - """ - unit test - """ - ... 
- - -@unittest_group.command("federatedml") -@click.option('-i', '--include', type=click.Path(exists=True), multiple=True, metavar="", - help="Specify federatedml test units for testing") -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def unit_test(ctx, include, **kwargs): - """ - federatedml unit test - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - config_inst = ctx.obj["config"] - yes = ctx.obj["yes"] - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - - error_log_file = f"./logs/{namespace}/error_test.log" - os.makedirs(os.path.dirname(error_log_file), exist_ok=True) - run_test(includes=include, conf=config_inst, error_log_file=error_log_file) - - -def run_test(includes, conf: Config, error_log_file): - def error_log(stdout): - if stdout is None: - return os.path.abspath(error_log_file) - with open(error_log_file, "a") as f: - f.write(stdout) - - def run_test(file): - global failed_count - echo.echo("start to run test {}".format(file)) - try: - subp = subprocess.Popen(["python", file], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - stdout, stderr = subp.communicate() - stdout = stdout.decode("utf-8") - echo.echo(stdout) - if "FAILED" in stdout: - failed_count += 1 - error_log(stdout=f"error sequence {failed_count}: {file}") - error_log(stdout=stdout) - except Exception: - return - - def traverse_folder(file_fullname): - if os.path.isfile(file_fullname): - if "_test.py" in file_fullname and "ftl" not in file_fullname: - run_test(file_fullname) - else: - for file in os.listdir(file_fullname): - file_fullname_new = os.path.join(file_fullname, file) - if os.path.isdir(file_fullname_new): - traverse_folder(file_fullname_new) - if "_test.py" in file and ("/test" in file_fullname or "tests" in file_fullname): - if "ftl" in file_fullname_new: - continue - else: - run_test(file_fullname_new) - - global failed_count - failed_count = 0 - fate_base = conf.fate_base - ml_dir = os.path.join(fate_base, "python/federatedml") - PYTHONPATH = os.environ.get('PYTHONPATH') + ":" + os.path.join(fate_base, "python") - os.environ['PYTHONPATH'] = PYTHONPATH - if len(includes) == 0: - traverse_folder(ml_dir) - else: - ml_dir = includes - for v in ml_dir: - traverse_folder(os.path.abspath(v)) - - echo.echo(f"there are {failed_count} failed test") - if failed_count > 0: - print('Please check the error content: {}'.format(error_log(None))) diff --git a/python/fate_test/fate_test/scripts/secure_protocol_cli.py b/python/fate_test/fate_test/scripts/secure_protocol_cli.py deleted file mode 100755 index dd92241901..0000000000 --- a/python/fate_test/fate_test/scripts/secure_protocol_cli.py +++ /dev/null @@ -1,92 +0,0 @@ -import click -import os -from fate_test._io import LOGGER, echo -from fate_test.scripts._options import SharedOptions -from fate_test.scripts.op_test.fate_he_performance_test import PaillierAssess -from fate_test.scripts.op_test.spdz_test import SPDZTest - - -@click.group(name="secure_protocol") -def secure_protocol_group(): - """ - secureprotol test - """ - ... 
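# A minimal, self-contained sketch of the subprocess-driven runner pattern used by
# run_test()/traverse_folder() in quick_test_cli.py above: every *_test.py module is
# executed in its own Python process and the captured output is scanned for unittest's
# "FAILED" marker. The root path in the usage note is a hypothetical example.
import os
import subprocess

def run_single_test(path: str) -> bool:
    """Run one test module in a subprocess; return True if it reported no failures."""
    proc = subprocess.Popen(["python", path], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    stdout, _ = proc.communicate()
    text = stdout.decode("utf-8")
    print(text)
    return "FAILED" not in text

def collect_tests(root: str):
    """Yield every *_test.py under root, skipping ftl modules, as traverse_folder() does."""
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if name.endswith("_test.py") and "ftl" not in name:
                yield os.path.join(dirpath, name)

# usage (hypothetical path):
#     failed = sum(not run_single_test(f) for f in collect_tests("python/federatedml"))
#     print(f"there are {failed} failed test(s)")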
- - -@secure_protocol_group.command("paillier") -@click.option("-round", "--test-round", type=int, help="", default=1) -@click.option("-num", "--data-num", type=int, help="", default=10000) -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def paillier_test(ctx, data_num, test_round, **kwargs): - """ - paillier - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - - for method in ["Paillier"]: - assess_table = PaillierAssess(method=method, data_num=data_num, test_round=test_round) - table = assess_table.output_table() - echo.echo(table) - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -@secure_protocol_group.command("spdz") -@click.option("-round", "--test-round", type=int, help="", default=1) -@click.option("-num", "--data-num", type=int, help="", default=10000) -@click.option("-partition", "--data-partition", type=int, help="", default=4) -@click.option("-lower_bound", "--data-lower-bound", type=int, help="", default=-1e9) -@click.option("-upper_bound", "--data-upper-bound", type=int, help="", default=1e9) -@click.option("-seed", "--seed", type=int, help="", default=123) -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def spdz_test(ctx, data_num, seed, data_partition, test_round, - data_lower_bound, data_upper_bound, **kwargs): - """ - spdz_test - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - namespace = ctx.obj["namespace"] - yes = ctx.obj["yes"] - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - if not yes and not click.confirm("running?"): - return - - conf = ctx.obj["config"] - runtime_config_path_prefix = \ - os.path.abspath(conf.fate_base) + "/python/fate_test/fate_test/scripts/op_test/spdz_conf/" - - params = dict(data_num=data_num, seed=seed, data_partition=data_partition, - test_round=test_round, data_lower_bound=data_lower_bound, - data_upper_bound=data_upper_bound) - - flow_address = None - for idx, address in enumerate(conf.serving_setting["flow_services"]): - if conf.role["guest"][0] in address["parties"]: - flow_address = address["address"] - - spdz_test = SPDZTest(params=params, - conf_path=runtime_config_path_prefix + "job_conf.json", - dsl_path=runtime_config_path_prefix + "job_dsl.json", - flow_address=flow_address, - guest_party_id=[conf.role["guest"][0]], - host_party_id=[conf.role["host"][0]]) - - tables = spdz_test.run() - for table in tables: - echo.echo(table) - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') diff --git a/python/fate_test/fate_test/scripts/testsuite_cli.py b/python/fate_test/fate_test/scripts/testsuite_cli.py deleted file mode 100644 index 913be2846a..0000000000 --- a/python/fate_test/fate_test/scripts/testsuite_cli.py +++ /dev/null @@ -1,316 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -import time -import uuid -from datetime import timedelta - -import click - -from fate_test import _config -from fate_test._client import Clients -from fate_test._config import Config -from fate_test._flow_client import JobProgress, SubmitJobResponse, QueryJobResponse -from fate_test._io import LOGGER, echo -from fate_test._parser import JSON_STRING, Testsuite, non_success_summary -from fate_test.scripts._options import SharedOptions -from fate_test.scripts._utils import _load_testsuites, _upload_data, _delete_data, _load_module_from_script, \ - _add_replace_hook - - -@click.command("suite") -@click.option('-i', '--include', required=True, type=click.Path(exists=True), multiple=True, metavar="", - help="include *testsuite.json under these paths") -@click.option('-e', '--exclude', type=click.Path(exists=True), multiple=True, - help="exclude *testsuite.json under these paths") -@click.option('-r', '--replace', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in data/conf/dsl") -@click.option("-g", '--glob', type=str, - help="glob string to filter sub-directory of path specified by ") -@click.option('-m', '--timeout', type=int, default=3600, help="maximun running time of job") -@click.option('-p', '--task-cores', type=int, help="processors per node") -@click.option('-uj', '--update-job-parameters', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in conf.job_parameters") -@click.option('-uc', '--update-component-parameters', default="{}", type=JSON_STRING, - help="a json string represents mapping for replacing fields in conf.component_parameters") -@click.option("--skip-dsl-jobs", is_flag=True, default=False, - help="skip dsl jobs defined in testsuite") -@click.option("--skip-pipeline-jobs", is_flag=True, default=False, - help="skip pipeline jobs defined in testsuite") -@click.option("--skip-data", is_flag=True, default=False, - help="skip uploading data specified in testsuite") -@click.option("--data-only", is_flag=True, default=False, - help="upload data only") -@click.option("--provider", type=str, - help="Select the fate version, for example: fate@1.7") -@click.option("--disable-clean-data", "clean_data", flag_value=False, default=None) -@click.option("--enable-clean-data", "clean_data", flag_value=True, default=None) -@SharedOptions.get_shared_options(hidden=True) -@click.pass_context -def run_suite(ctx, replace, include, exclude, glob, timeout, update_job_parameters, update_component_parameters, - skip_dsl_jobs, skip_pipeline_jobs, skip_data, data_only, clean_data, task_cores, provider, **kwargs): - """ - process testsuite - """ - ctx.obj.update(**kwargs) - ctx.obj.post_process() - config_inst = ctx.obj["config"] - config_inst.extend_sid = ctx.obj["extend_sid"] - config_inst.auto_increasing_sid = ctx.obj["auto_increasing_sid"] - if clean_data is None: - clean_data = config_inst.clean_data - namespace = ctx.obj["namespace"] - yes = ctx.obj["yes"] - data_namespace_mangling = ctx.obj["namespace_mangling"] - # prepare output dir and json hooks - _add_replace_hook(replace) - echo.welcome() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - echo.echo("loading testsuites:") - suites = _load_testsuites(includes=include, excludes=exclude, glob=glob, provider=provider) - for suite in suites: - _config.jobs_num += len(suite.jobs) - echo.echo(f"\tdataset({len(suite.dataset)}) dsl 
jobs({len(suite.jobs)}) " - f"pipeline jobs ({len(suite.pipeline_jobs)}) {suite.path}") - if not yes and not click.confirm("running?"): - return - - echo.stdout_newline() - with Clients(config_inst) as client: - for i, suite in enumerate(suites): - # noinspection PyBroadException - try: - start = time.time() - echo.echo(f"[{i + 1}/{len(suites)}]start at {time.strftime('%Y-%m-%d %X')} {suite.path}", fg='red') - if not skip_data: - try: - _upload_data(client, suite, config_inst) - except Exception as e: - raise RuntimeError(f"exception occur while uploading data for {suite.path}") from e - if data_only: - continue - - if not skip_dsl_jobs: - echo.stdout_newline() - try: - time_consuming = _submit_job(client, suite, namespace, config_inst, timeout, - update_job_parameters, update_component_parameters, task_cores) - except Exception as e: - raise RuntimeError(f"exception occur while submit job for {suite.path}") from e - - if not skip_pipeline_jobs: - try: - _run_pipeline_jobs(config_inst, suite, namespace, data_namespace_mangling) - except Exception as e: - raise RuntimeError(f"exception occur while running pipeline jobs for {suite.path}") from e - - if not skip_data and clean_data: - _delete_data(client, suite) - echo.echo(f"[{i + 1}/{len(suites)}]elapse {timedelta(seconds=int(time.time() - start))}", fg='red') - if not skip_dsl_jobs or not skip_pipeline_jobs: - suite_file = str(suite.path).split("/")[-1] - echo.echo(suite.pretty_final_summary(time_consuming, suite_file)) - - except Exception: - exception_id = uuid.uuid1() - echo.echo(f"exception in {suite.path}, exception_id={exception_id}") - LOGGER.exception(f"exception id: {exception_id}") - finally: - echo.stdout_newline() - non_success_summary() - echo.farewell() - echo.echo(f"testsuite namespace: {namespace}", fg='red') - - -def _submit_job(clients: Clients, suite: Testsuite, namespace: str, config: Config, timeout, update_job_parameters, - update_component_parameters, task_cores): - # submit jobs - with click.progressbar(length=len(suite.jobs), - label="jobs ", - show_eta=False, - show_pos=True, - width=24) as bar: - time_list = [] - for job in suite.jobs_iter(): - job_progress = JobProgress(job.job_name) - start = time.time() - _config.jobs_progress += 1 - - def _raise(): - exception_id = str(uuid.uuid1()) - job_progress.exception(exception_id) - suite.update_status(job_name=job.job_name, exception_id=exception_id) - echo.file(f"exception({exception_id})") - LOGGER.exception(f"exception id: {exception_id}") - - # noinspection PyBroadException - try: - if task_cores is not None: - job.job_conf.update_job_common_parameters(task_cores=task_cores) - job.job_conf.update(config.parties, timeout, update_job_parameters, - update_component_parameters) - except Exception: - _raise() - continue - - def update_bar(n_step): - bar.item_show_func = lambda x: job_progress.show() - time.sleep(0.1) - bar.update(n_step) - - update_bar(1) - - def _call_back(resp: SubmitJobResponse): - if isinstance(resp, SubmitJobResponse): - progress_tracking = "/".join([str(_config.jobs_progress), str(_config.jobs_num)]) - if _config.jobs_num != len(suite.jobs): - job_progress.set_progress_tracking(progress_tracking) - job_progress.submitted(resp.job_id) - echo.file(f"[jobs] {resp.job_id} ", nl=False) - suite.update_status(job_name=job.job_name, job_id=resp.job_id) - - if isinstance(resp, QueryJobResponse): - job_progress.running(resp.status, resp.progress) - - update_bar(0) - - # noinspection PyBroadException - try: - response = 
clients["guest_0"].submit_job(job=job, callback=_call_back) - - # noinspection PyBroadException - try: - # add notes - notes = f"{job.job_name}@{suite.path}@{namespace}" - for role, party_id_list in job.job_conf.role.items(): - for i, party_id in enumerate(party_id_list): - clients[f"{role}_{i}"].add_notes(job_id=response.job_id, role=role, party_id=party_id, - notes=notes) - except Exception: - pass - except Exception: - _raise() - else: - job_progress.final(response.status) - job_name = job.job_name - suite.update_status(job_name=job_name, status=response.status.status) - if suite.model_in_dep(job_name): - _config.jobs_progress += 1 - if not response.status.is_success(): - suite.remove_dependency(job_name) - else: - dependent_jobs = suite.get_dependent_jobs(job_name) - for predict_job in dependent_jobs: - model_info, table_info, cache_info, model_loader_info = None, None, None, None - deps_data = _config.deps_alter[predict_job.job_name] - - if 'data_deps' in deps_data.keys() and deps_data.get('data', None) is not None and\ - job_name == deps_data.get('data_deps', None).get('name', None): - for k, v in deps_data.get('data'): - if job_name == k: - data_pre = v - roles = list(data_pre.keys()) - table_info, hierarchy = [], [] - for role_ in roles: - role, index = role_.split("_") - input_ = data_pre[role_] - for data_input, cpn in input_.items(): - try: - table_name = clients["guest_0"].output_data_table( - job_id=response.job_id, - role=role, - party_id=config.role[role][int(index)], - component_name=cpn) - except Exception: - _raise() - if predict_job.job_conf.dsl_version == 2: - hierarchy.append([role, index, data_input]) - table_info.append({'table': table_name}) - else: - hierarchy.append([role, 'args', 'data']) - table_info.append({data_input: [table_name]}) - table_info = {'hierarchy': hierarchy, 'table_info': table_info} - if 'model_deps' in deps_data.keys() and \ - job_name == deps_data.get('model_deps', None).get('name', None): - if predict_job.job_conf.dsl_version == 2: - # noinspection PyBroadException - try: - model_info = clients["guest_0"].deploy_model( - model_id=response.model_info["model_id"], - model_version=response.model_info["model_version"], - dsl=predict_job.job_dsl.as_dict()) - except Exception: - _raise() - else: - model_info = response.model_info - if 'cache_deps' in deps_data.keys() and \ - job_name == deps_data.get('cache_deps', None).get('name', None): - cache_dsl = predict_job.job_dsl.as_dict() - cache_info = [] - for cpn in cache_dsl.get("components").keys(): - if "CacheLoader" in cache_dsl.get("components").get(cpn).get("module"): - cache_info.append({cpn: {'job_id': response.job_id}}) - cache_info = {'hierarchy': [""], 'cache_info': cache_info} - - if 'model_loader_deps' in deps_data.keys() and \ - job_name == deps_data.get('model_loader_deps', None).get('name', None): - model_loader_dsl = predict_job.job_dsl.as_dict() - model_loader_info = [] - for cpn in model_loader_dsl.get("components").keys(): - if "ModelLoader" in model_loader_dsl.get("components").get(cpn).get("module"): - model_loader_info.append({cpn: response.model_info}) - model_loader_info = {'hierarchy': [""], 'model_loader_info': model_loader_info} - - suite.feed_dep_info(predict_job, job_name, model_info=model_info, table_info=table_info, - cache_info=cache_info, model_loader_info=model_loader_info) - suite.remove_dependency(job_name) - update_bar(0) - echo.stdout_newline() - time_list.append(time.time() - start) - return [str(int(i)) + "s" for i in time_list] - - -def 
_run_pipeline_jobs(config: Config, suite: Testsuite, namespace: str, data_namespace_mangling: bool): - # pipeline demo goes here - job_n = len(suite.pipeline_jobs) - for i, pipeline_job in enumerate(suite.pipeline_jobs): - echo.echo(f"Running [{i + 1}/{job_n}] job: {pipeline_job.job_name}") - - def _raise(err_msg, status="failed"): - exception_id = str(uuid.uuid1()) - suite.update_status(job_name=job_name, exception_id=exception_id, status=status) - echo.file(f"exception({exception_id}), error message:\n{err_msg}") - # LOGGER.exception(f"exception id: {exception_id}") - - job_name, script_path = pipeline_job.job_name, pipeline_job.script_path - mod = _load_module_from_script(script_path) - try: - if data_namespace_mangling: - try: - mod.main(config=config, namespace=f"_{namespace}") - suite.update_status(job_name=job_name, status="success") - except Exception as e: - _raise(e) - continue - else: - try: - mod.main(config=config) - suite.update_status(job_name=job_name, status="success") - except Exception as e: - _raise(e) - continue - except Exception as e: - _raise(e, status="not submitted") - continue diff --git a/python/fate_test/fate_test/utils.py b/python/fate_test/fate_test/utils.py deleted file mode 100644 index de33f91bf7..0000000000 --- a/python/fate_test/fate_test/utils.py +++ /dev/null @@ -1,348 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import json -import os - -from colorama import init, deinit, Fore, Style -import math -import numpy as np -from fate_test._io import echo -from prettytable import PrettyTable, ORGMODE - -SCRIPT_METRICS = "script_metrics" -DISTRIBUTION_METRICS = "distribution_metrics" -ALL = "all" -RELATIVE = "relative" -ABSOLUTE = "absolute" - - -class TxtStyle: - TRUE_VAL = Fore.GREEN - FALSE_VAL = Fore.RED + Style.BRIGHT - TITLE = Fore.BLUE - FIELD_VAL = Fore.YELLOW - DATA_FIELD_VAL = Fore.CYAN - END = Style.RESET_ALL - - -def show_data(data): - data_table = PrettyTable() - data_table.set_style(ORGMODE) - data_table.field_names = ["Data", "Information"] - for name, table_name in data.items(): - row = [name, f"{TxtStyle.DATA_FIELD_VAL}{table_name}{TxtStyle.END}"] - data_table.add_row(row) - echo.echo(data_table.get_string(title=f"{TxtStyle.TITLE}Data Summary{TxtStyle.END}")) - echo.echo("\n") - - -def _get_common_metrics(**results): - common_metrics = None - for result in results.values(): - if common_metrics is None: - common_metrics = set(result.keys()) - else: - common_metrics = common_metrics & result.keys() - if SCRIPT_METRICS in common_metrics: - common_metrics.remove(SCRIPT_METRICS) - return list(common_metrics) - - -def _filter_results(metrics, **results): - filtered_results = {} - for model_name, result in results.items(): - model_result = [result.get(metric, None) for metric in metrics] - if None in model_result: - continue - filtered_results[model_name] = model_result - return filtered_results - - -def style_table(txt): - colored_txt = txt.replace("True", f"{TxtStyle.TRUE_VAL}True{TxtStyle.END}") - colored_txt = colored_txt.replace("False", f"{TxtStyle.FALSE_VAL}False{TxtStyle.END}") - return colored_txt - - -def evaluate_almost_equal(metrics, results, abs_tol=None, rel_tol=None): - """ - Evaluate for each given metric if values in results are almost equal - Parameters - ---------- - metrics: List[str], metrics names - results: dict, results to be evaluated - abs_tol: float, absolute error tolerance - rel_tol: float, relative difference tolerance - Returns - ------- - bool, return True if all metrics in results are almost equal - """ - # return False if empty - if len(metrics) == 0: - return False - eval_summary = {} - for i, metric in enumerate(metrics): - v_eval = [res[i] for res in results.values()] - first_v = v_eval[0] - if metric == SCRIPT_METRICS: - continue - if abs_tol is not None and rel_tol is not None: - eval_summary[metric] = all(math.isclose(v, first_v, abs_tol=abs_tol, rel_tol=rel_tol) for v in v_eval) - elif abs_tol is not None: - eval_summary[metric] = all(math.isclose(v, first_v, abs_tol=abs_tol) for v in v_eval) - elif rel_tol is not None: - eval_summary[metric] = all(math.isclose(v, first_v, rel_tol=rel_tol) for v in v_eval) - else: - eval_summary[metric] = all(math.isclose(v, first_v) for v in v_eval) - all_match = all(eval_summary.values()) - return eval_summary, all_match - - -def _distribution_metrics(**results): - filtered_metric_group = _filter_results([DISTRIBUTION_METRICS], **results) - for script, model_results_pair in filtered_metric_group.items(): - metric_results = model_results_pair[0] - common_metrics = _get_common_metrics(**metric_results) - filtered_results = _filter_results(common_metrics, **metric_results) - table = PrettyTable() - table.set_style(ORGMODE) - script_model_names = list(filtered_results.keys()) - table.field_names = ["Script Model Name"] + common_metrics - for script_model_name in script_model_names: - row = 
[f"{script}-{script_model_name}"] + [f"{TxtStyle.FIELD_VAL}{v}{TxtStyle.END}" for v in - filtered_results[script_model_name]] - table.add_row(row) - echo.echo(table.get_string(title=f"{TxtStyle.TITLE}{script} distribution metrics{TxtStyle.END}")) - echo.echo("\n" + "#" * 60) - - -def match_script_metrics(abs_tol, rel_tol, match_details, **results): - filtered_metric_group = _filter_results([SCRIPT_METRICS], **results) - for script, model_results_pair in filtered_metric_group.items(): - metric_results = model_results_pair[0] - common_metrics = _get_common_metrics(**metric_results) - filtered_results = _filter_results(common_metrics, **metric_results) - table = PrettyTable() - table.set_style(ORGMODE) - script_model_names = list(filtered_results.keys()) - table.field_names = ["Script Model Name"] + common_metrics - for script_model_name in script_model_names: - row = [f"{script_model_name}-{script}"] + [f"{TxtStyle.FIELD_VAL}{v}{TxtStyle.END}" for v in - filtered_results[script_model_name]] - table.add_row(row) - echo.echo(table.get_string(title=f"{TxtStyle.TITLE}{script} Script Metrics Summary{TxtStyle.END}")) - _all_match(common_metrics, filtered_results, abs_tol, rel_tol, script, match_details=match_details) - - -def match_metrics(evaluate, group_name, abs_tol=None, rel_tol=None, storage_tag=None, history_tag=None, - fate_version=None, cache_directory=None, match_details=None, **results): - """ - Get metrics - Parameters - ---------- - evaluate: bool, whether to evaluate metrics are almost equal, and include compare results in output report - group_name: str, group name of all models - abs_tol: float, max tolerance of absolute error to consider two metrics to be almost equal - rel_tol: float, max tolerance of relative difference to consider two metrics to be almost equal - storage_tag: str, metrics information storage tag - history_tag: str, historical metrics information comparison tag - fate_version: str, FATE version - cache_directory: str, Storage path of metrics information - match_details: str, Error value display in algorithm comparison - results: dict of model name: metrics - Returns - ------- - match result - """ - init(autoreset=True) - common_metrics = _get_common_metrics(**results) - filtered_results = _filter_results(common_metrics, **results) - table = PrettyTable() - table.set_style(ORGMODE) - model_names = list(filtered_results.keys()) - table.field_names = ["Model Name"] + common_metrics - for model_name in model_names: - row = [f"{model_name}-{group_name}"] + [f"{TxtStyle.FIELD_VAL}{v}{TxtStyle.END}" for v in - filtered_results[model_name]] - table.add_row(row) - echo.echo(table.get_string(title=f"{TxtStyle.TITLE}Metrics Summary{TxtStyle.END}")) - - if evaluate and len(filtered_results.keys()) > 1: - _all_match(common_metrics, filtered_results, abs_tol, rel_tol, match_details=match_details) - - _distribution_metrics(**results) - match_script_metrics(abs_tol, rel_tol, match_details, **results) - if history_tag: - history_tag = ["_".join([i, group_name]) for i in history_tag] - comparison_quality(group_name, history_tag, cache_directory, abs_tol, rel_tol, match_details, **results) - if storage_tag: - storage_tag = "_".join(['FATE', fate_version, storage_tag, group_name]) - _save_quality(storage_tag, cache_directory, **results) - deinit() - - -def _match_error(metrics, results): - relative_error_list = [] - absolute_error_list = [] - if len(metrics) == 0: - return False - for i, v in enumerate(metrics): - v_eval = [res[i] for res in results.values()] - 
absolute_error_list.append(f"{TxtStyle.FIELD_VAL}{abs(max(v_eval) - min(v_eval))}{TxtStyle.END}") - relative_error_list.append( - f"{TxtStyle.FIELD_VAL}{abs((max(v_eval) - min(v_eval)) / max(v_eval))}{TxtStyle.END}") - return relative_error_list, absolute_error_list - - -def _all_match(common_metrics, filtered_results, abs_tol, rel_tol, script=None, match_details=None): - eval_summary, all_match = evaluate_almost_equal(common_metrics, filtered_results, abs_tol, rel_tol) - eval_table = PrettyTable() - eval_table.set_style(ORGMODE) - field_names = ["Metric", "All Match"] - relative_error_list, absolute_error_list = _match_error(common_metrics, filtered_results) - for i, metric in enumerate(eval_summary.keys()): - row = [metric, eval_summary.get(metric)] - if match_details == ALL: - field_names = ["Metric", "All Match", "max_relative_error", "max_absolute_error"] - row += [relative_error_list[i], absolute_error_list[i]] - elif match_details == RELATIVE: - field_names = ["Metric", "All Match", "max_relative_error"] - row += [relative_error_list[i]] - elif match_details == ABSOLUTE: - field_names = ["Metric", "All Match", "max_absolute_error"] - row += [absolute_error_list[i]] - eval_table.add_row(row) - eval_table.field_names = field_names - - echo.echo(style_table(eval_table.get_string(title=f"{TxtStyle.TITLE}Match Results{TxtStyle.END}"))) - script = "" if script is None else f"{script} " - if all_match: - echo.echo(f"All {script}Metrics Match: {TxtStyle.TRUE_VAL}{all_match}{TxtStyle.END}") - else: - echo.echo(f"All {script}Metrics Match: {TxtStyle.FALSE_VAL}{all_match}{TxtStyle.END}") - - -def comparison_quality(group_name, history_tags, cache_directory, abs_tol, rel_tol, match_details, **results): - def regression_group(results_dict): - metric = {} - for k, v in results_dict.items(): - if not isinstance(v, dict): - metric[k] = v - return metric - - def class_group(class_dict): - metric = {} - for k, v in class_dict.items(): - if not isinstance(v, dict): - metric[k] = v - for k, v in class_dict['distribution_metrics'].items(): - metric.update(v) - return metric - - history_info_dir = "/".join([os.path.join(os.path.abspath(cache_directory), 'benchmark_history', - "benchmark_quality.json")]) - assert os.path.exists(history_info_dir), f"Please check the {history_info_dir} Is it deleted" - with open(history_info_dir, 'r') as f: - benchmark_quality = json.load(f, object_hook=dict) - regression_metric = {} - regression_quality = {} - class_quality = {} - for history_tag in history_tags: - for tag in benchmark_quality: - if '_'.join(tag.split("_")[2:]) == history_tag and SCRIPT_METRICS in results["FATE"]: - regression_metric[tag] = regression_group(benchmark_quality[tag]['FATE']) - for key, value in _filter_results([SCRIPT_METRICS], **benchmark_quality[tag])['FATE'][0].items(): - regression_quality["_".join([tag, key])] = value - elif '_'.join(tag.split("_")[2:]) == history_tag and DISTRIBUTION_METRICS in results["FATE"]: - class_quality[tag] = class_group(benchmark_quality[tag]['FATE']) - - if SCRIPT_METRICS in results["FATE"] and regression_metric: - regression_metric[group_name] = regression_group(results['FATE']) - metric_compare(abs_tol, rel_tol, match_details, **regression_metric) - for key, value in _filter_results([SCRIPT_METRICS], **results)['FATE'][0].items(): - regression_quality["_".join([group_name, key])] = value - metric_compare(abs_tol, rel_tol, match_details, **regression_quality) - echo.echo("\n" + "#" * 60) - elif DISTRIBUTION_METRICS in results["FATE"] and class_quality: - - 
class_quality[group_name] = class_group(results['FATE']) - metric_compare(abs_tol, rel_tol, match_details, **class_quality) - echo.echo("\n" + "#" * 60) - - -def metric_compare(abs_tol, rel_tol, match_details, **metric_results): - common_metrics = _get_common_metrics(**metric_results) - filtered_results = _filter_results(common_metrics, **metric_results) - table = PrettyTable() - table.set_style(ORGMODE) - script_model_names = list(filtered_results.keys()) - table.field_names = ["Script Model Name"] + common_metrics - for script_model_name in script_model_names: - table.add_row([f"{script_model_name}"] + - [f"{TxtStyle.FIELD_VAL}{v}{TxtStyle.END}" for v in filtered_results[script_model_name]]) - print( - table.get_string(title=f"{TxtStyle.TITLE}Comparison results of all metrics of Script Model FATE{TxtStyle.END}")) - _all_match(common_metrics, filtered_results, abs_tol, rel_tol, match_details=match_details) - - -def _save_quality(storage_tag, cache_directory, **results): - save_dir = "/".join([os.path.join(os.path.abspath(cache_directory), 'benchmark_history', "benchmark_quality.json")]) - os.makedirs(os.path.dirname(save_dir), exist_ok=True) - if os.path.exists(save_dir): - with open(save_dir, 'r') as f: - benchmark_quality = json.load(f, object_hook=dict) - else: - benchmark_quality = {} - if storage_tag in benchmark_quality: - print("This tag already exists in the history and will be updated to the record information.") - benchmark_quality.update({storage_tag: results}) - try: - with open(save_dir, 'w') as fp: - json.dump(benchmark_quality, fp, indent=2) - print("Storage success, please check: ", save_dir) - except Exception: - print("Storage failed, please check: ", save_dir) - - -def parse_summary_result(rs_dict): - for model_key in rs_dict: - rs_content = rs_dict[model_key] - if 'validate' in rs_content: - return rs_content['validate'] - else: - return rs_content['train'] - - -def extract_data(df, col_name, convert_float=True, keep_id=False): - """ - component output data to numpy array - Parameters - ---------- - df: dataframe - col_name: column to extract - convert_float: whether to convert extracted value to float value - keep_id: whether to keep id - Returns - ------- - array of extracted data, optionally with id - """ - if keep_id: - if convert_float: - df[col_name] = df[col_name].to_numpy().astype(np.float64) - - return df[[df.columns[0], col_name]].to_numpy() - else: - return df[col_name].to_numpy().astype(np.float64) diff --git a/python/fate_test/pyproject.toml b/python/fate_test/pyproject.toml deleted file mode 100644 index 988f72fe3e..0000000000 --- a/python/fate_test/pyproject.toml +++ /dev/null @@ -1,45 +0,0 @@ -[tool.poetry] -name = "fate_test" -version = "1.9.0" -description = "test tools for FATE" -authors = ["FederatedAI "] -license = "Apache-2.0" - -homepage = "https://fate.fedai.org/" -repository = "https://github.com/FederatedAI/FATE" -documentation = "https://fate.readthedocs.io/en/latest/?badge=latest" -keywords = ["FATE", "Federated Learning", "Testsuite"] - -classifiers = [ - "Development Status :: 5 - Production/Stable", - "Environment :: Console", - "Topic :: Software Development :: Testing", - "Intended Audience :: Developers", - "Intended Audience :: Education" -] - -packages = [ - { include = "fate_test" } -] - -[tool.poetry.dependencies] -python = "^3.6" -requests_toolbelt = "^0.9.1" -requests = "^2.24.0" -click = "^7.1.2" -"ruamel.yaml" = "^0.16.10" -loguru = ">=0.6.0" -prettytable = "^1.0.0" -sshtunnel = "^0.1.5" -fate_client = "^1.9" -pandas = 
">=1.1.5" -colorama = "^0.4.4" - -[tool.poetry.dev-dependencies] - -[tool.poetry.scripts] -fate_test = "fate_test.scripts.cli:cli" - -[build-system] -requires = ["poetry>=0.12", "setuptools>=50.0,<51.0"] -build-backend = "poetry.masonry.api" diff --git a/python/fate_test/setup.py b/python/fate_test/setup.py deleted file mode 100644 index 83c7a65aa9..0000000000 --- a/python/fate_test/setup.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -from setuptools import setup - -packages = ["fate_test", "fate_test.scripts", "fate_test.scripts.op_test", "fate_test.flow_test"] - -package_data = {"": ["*"]} - -install_requires = [ - "click>=7.1.2,<8.0.0", - "fate_client>=1.9,<2.0", - "loguru>=0.6.0", - "pandas>=1.1.5", - "poetry>=0.12", - "prettytable>=1.0.0,<2.0.0", - "requests>=2.24.0,<3.0.0", - "requests_toolbelt>=0.9.1,<0.10.0", - "ruamel.yaml>=0.16.10,<0.17.0", - "sshtunnel>=0.1.5,<0.2.0", - 'colorama>=0.4.4' -] - -entry_points = {"console_scripts": ["fate_test = fate_test.scripts.cli:cli"]} - -setup_kwargs = { - "name": "fate-test", - "version": "1.9.0", - "description": "test tools for FATE", - "long_description": 'FATE Test\n=========\n\nA collection of useful tools to running FATE\'s test.\n\n.. image:: images/tutorial.gif\n :align: center\n :alt: tutorial\n\nquick start\n-----------\n\n1. (optional) create virtual env\n\n .. code-block:: bash\n\n python -m venv venv\n source venv/bin/activate\n pip install -U pip\n\n\n2. install fate_test\n\n .. code-block:: bash\n\n pip install fate_test\n fate_test --help\n\n\n3. edit default fate_test_config.yaml\n\n .. code-block:: bash\n\n # edit priority config file with system default editor\n # filling some field according to comments\n fate_test config edit\n\n4. configure FATE-Pipeline and FATE-Flow Commandline server setting\n\n.. code-block:: bash\n\n # configure FATE-Pipeline server setting\n pipeline init --port 9380 --ip 127.0.0.1\n # configure FATE-Flow Commandline server setting\n flow init --port 9380 --ip 127.0.0.1\n\n5. run some fate_test suite\n\n .. code-block:: bash\n\n fate_test suite -i \n\n\n6. run some fate_test benchmark\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i \n\n7. useful logs or exception will be saved to logs dir with namespace shown in last step\n\ndevelop install\n---------------\nIt is more convenient to use the editable mode during development: replace step 2 with flowing steps\n\n.. code-block:: bash\n\n pip install -e ${FATE}/python/fate_client && pip install -e ${FATE}/python/fate_test\n\n\n\ncommand types\n-------------\n\n- suite: used for running testsuites, collection of FATE jobs\n\n .. code-block:: bash\n\n fate_test suite -i \n\n\n- benchmark-quality used for comparing modeling quality between FATE and other machine learning systems\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i \n\n\n\nconfiguration by examples\n--------------------------\n\n1. no need ssh tunnel:\n\n - 9999, service: service_a\n - 10000, service: service_b\n\n and both service_a, service_b can be requested directly:\n\n .. code-block:: yaml\n\n work_mode: 1 # 0 for standalone, 1 for cluster\n data_base_dir: \n parties:\n guest: [10000]\n host: [9999, 10000]\n arbiter: [9999]\n services:\n - flow_services:\n - {address: service_a, parties: [9999]}\n - {address: service_b, parties: [10000]}\n\n2. need ssh tunnel:\n\n - 9999, service: service_a\n - 10000, service: service_b\n\n service_a, can be requested directly while service_b don\'t,\n but you can request service_b in other node, say B:\n\n .. 
code-block:: yaml\n\n work_mode: 0 # 0 for standalone, 1 for cluster\n data_base_dir: \n parties:\n guest: [10000]\n host: [9999, 10000]\n arbiter: [9999]\n services:\n - flow_services:\n - {address: service_a, parties: [9999]}\n - flow_services:\n - {address: service_b, parties: [10000]}\n ssh_tunnel: # optional\n enable: true\n ssh_address: :\n ssh_username: \n ssh_password: # optional\n ssh_priv_key: "~/.ssh/id_rsa"\n\n\nTestsuite\n---------\n\nTestsuite is used for running a collection of jobs in sequence. Data used for jobs could be uploaded before jobs are\nsubmitted, and are cleaned when jobs finished. This tool is useful for FATE\'s release test.\n\ncommand options\n~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n fate_test suite --help\n\n1. include:\n\n .. code-block:: bash\n\n fate_test suite -i \n\n will run testsuites in *path1*\n\n2. exclude:\n\n .. code-block:: bash\n\n fate_test suite -i -e -e ...\n\n will run testsuites in *path1* but not in *path2* and *path3*\n\n3. glob:\n\n .. code-block:: bash\n\n fate_test suite -i -g "hetero*"\n\n will run testsuites in sub directory start with *hetero* of *path1*\n\n4. replace:\n\n .. code-block:: bash\n\n fate_test suite -i -r \'{"maxIter": 5}\'\n\n will find all key-value pair with key "maxIter" in `data conf` or `conf` or `dsl` and replace the value with 5\n\n\n5. skip-data:\n\n .. code-block:: bash\n\n fate_test suite -i --skip-data\n\n will run testsuites in *path1* without uploading data specified in *benchmark.json*.\n\n\n6. yes:\n\n .. code-block:: bash\n\n fate_test suite -i --yes\n\n will run testsuites in *path1* directly, skipping double check\n\n7. skip-dsl-jobs:\n\n .. code-block:: bash\n\n fate_test suite -i --skip-dsl-jobs\n\n will run testsuites in *path1* but skip all *tasks* in testsuites. It\'s would be useful when only pipeline tasks needed.\n\n8. skip-pipeline-jobs:\n\n .. code-block:: bash\n\n fate_test suite -i --skip-pipeline-jobs\n\n will run testsuites in *path1* but skip all *pipeline tasks* in testsuites. It\'s would be useful when only dsl tasks needed.\n\n\nBenchmark Quality\n------------------\n\nBenchmark-quality is used for comparing modeling quality between FATE\nand other machine learning systems. Benchmark produces a metrics comparison\nsummary for each benchmark job group.\n\n.. code-block:: bash\n\n fate_test benchmark-quality -i examples/benchmark_quality/hetero_linear_regression\n\n.. 
code-block:: bash\n\n +-------+--------------------------------------------------------------+\n | Data | Name |\n +-------+--------------------------------------------------------------+\n | train | {\'guest\': \'motor_hetero_guest\', \'host\': \'motor_hetero_host\'} |\n | test | {\'guest\': \'motor_hetero_guest\', \'host\': \'motor_hetero_host\'} |\n +-------+--------------------------------------------------------------+\n +------------------------------------+--------------------+--------------------+-------------------------+---------------------+\n | Model Name | explained_variance | r2_score | root_mean_squared_error | mean_squared_error |\n +------------------------------------+--------------------+--------------------+-------------------------+---------------------+\n | local-linear_regression-regression | 0.9035168452250094 | 0.9035070863155368 | 0.31340413289880553 | 0.09822215051805216 |\n | FATE-linear_regression-regression | 0.903146386539082 | 0.9031411831961411 | 0.3139977881119483 | 0.09859461093919596 |\n +------------------------------------+--------------------+--------------------+-------------------------+---------------------+\n +-------------------------+-----------+\n | Metric | All Match |\n +-------------------------+-----------+\n | explained_variance | True |\n | r2_score | True |\n | root_mean_squared_error | True |\n | mean_squared_error | True |\n +-------------------------+-----------+\n\ncommand options\n~~~~~~~~~~~~~~~\n\nuse the following command to show help message\n\n.. code-block:: bash\n\n fate_test benchmark-quality --help\n\n1. include:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i \n\n will run benchmark testsuites in *path1*\n\n2. exclude:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i -e -e ...\n\n will run benchmark testsuites in *path1* but not in *path2* and *path3*\n\n3. glob:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i -g "hetero*"\n\n will run benchmark testsuites in sub directory start with *hetero* of *path1*\n\n4. tol:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i -t 1e-3\n\n will run benchmark testsuites in *path1* with absolute tolerance of difference between metrics set to 0.001.\n If absolute difference between metrics is smaller than *tol*, then metrics are considered\n almost equal. Check benchmark testsuite `writing guide <#benchmark-testsuite>`_ on setting alternative tolerance.\n\n5. skip-data:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i --skip-data\n\n will run benchmark testsuites in *path1* without uploading data specified in *benchmark.json*.\n\n\n6. yes:\n\n .. code-block:: bash\n\n fate_test benchmark-quality -i --yes\n\n will run benchmark testsuites in *path1* directly, skipping double check\n\n\nbenchmark testsuite\n~~~~~~~~~~~~~~~~~~~\n\nConfiguration of jobs should be specified in a benchmark testsuite whose file name ends\nwith "\\*benchmark.json". 
For benchmark testsuite example,\nplease refer `here <../../examples/benchmark_quality>`_.\n\nA benchmark testsuite includes the following elements:\n\n- data: list of local data to be uploaded before running FATE jobs\n\n - file: path to original data file to be uploaded, should be relative to testsuite or FATE installation path\n - head: whether file includes header\n - partition: number of partition for data storage\n - table_name: table name in storage\n - namespace: table namespace in storage\n - role: which role to upload the data, as specified in fate_test.config;\n naming format is: "{role_type}_{role_index}", index starts at 0\n\n .. code-block:: json\n\n "data": [\n {\n "file": "examples/data/motor_hetero_host.csv",\n "head": 1,\n "partition": 8,\n "table_name": "motor_hetero_host",\n "namespace": "experiment",\n "role": "host_0"\n }\n ]\n\n- job group: each group includes arbitrary number of jobs with paths to corresponding script and configuration\n\n - job: name of job to be run, must be unique within each group list\n\n - script: path to `testing script <#testing-script>`_, should be relative to testsuite\n - conf: path to job configuration file for script, should be relative to testsuite\n\n .. code-block:: json\n\n "local": {\n "script": "./local-linr.py",\n "conf": "./linr_config.yaml"\n }\n\n - compare_setting: additional setting for quality metrics comparison, currently only takes ``relative_tol``\n\n If metrics *a* and *b* satisfy *abs(a-b) <= max(relative_tol \\* max(abs(a), abs(b)), absolute_tol)*\n (from `math module `_),\n they are considered almost equal. In the below example, metrics from "local" and "FATE" jobs are\n considered almost equal if their relative difference is smaller than\n *0.05 \\* max(abs(local_metric), abs(pipeline_metric)*.\n\n .. code-block:: json\n\n "linear_regression-regression": {\n "local": {\n "script": "./local-linr.py",\n "conf": "./linr_config.yaml"\n },\n "FATE": {\n "script": "./fate-linr.py",\n "conf": "./linr_config.yaml"\n },\n "compare_setting": {\n "relative_tol": 0.01\n }\n }\n\n\ntesting script\n~~~~~~~~~~~~~~\n\nAll job scripts need to have ``Main`` function as an entry point for executing jobs; scripts should\nreturn two dictionaries: first with data information key-value pairs: {data_type}: {data_name_dictionary};\nthe second contains {metric_name}: {metric_value} key-value pairs for metric comparison.\n\nBy default, the final data summary shows the output from the job named "FATE"; if no such job exists,\ndata information returned by the first job is shown. For clear presentation, we suggest that user follow\nthis general `guideline <../../examples/data/README.md#data-set-naming-rule>`_ for data set naming. 
In the case of multi-host\ntask, consider numbering host as such:\n\n::\n\n {\'guest\': \'default_credit_homo_guest\',\n \'host_1\': \'default_credit_homo_host_1\',\n \'host_2\': \'default_credit_homo_host_2\'}\n\nReturned quality metrics of the same key are to be compared.\nNote that only **real-value** metrics can be compared.\n\n- FATE script: ``Main`` always has three inputs:\n\n - config: job configuration, `JobConfig <../fate_client/pipeline/utils/tools.py#L64>`_ object loaded from "fate_test_config.yaml"\n - param: job parameter setting, dictionary loaded from "conf" file specified in benchmark testsuite\n - namespace: namespace suffix, user-given *namespace* or generated timestamp string when using *namespace-mangling*\n\n- non-FATE script: ``Main`` always has one input:\n\n - param: job parameter setting, dictionary loaded from "conf" file specified in benchmark testsuite\n\n\ndata\n----\n\n`Data` sub-command is used for upload or delete dataset in suite\'s.\n\ncommand options\n~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n fate_test data --help\n\n1. include:\n\n .. code-block:: bash\n\n fate_test data [upload|delete] -i \n\n will upload/delete dataset in testsuites in *path1*\n\n2. exclude:\n\n .. code-block:: bash\n\n fate_test data [upload|delete] -i -e -e ...\n\n will upload/delete dataset in testsuites in *path1* but not in *path2* and *path3*\n\n3. glob:\n\n .. code-block:: bash\n\n fate_test data [upload|delete] -i -g "hetero*"\n\n will upload/delete dataset in testsuites in sub directory start with *hetero* of *path1*\n\n\nfull command options\n---------------------\n\n.. click:: fate_test.scripts.cli:cli\n :prog: fate_test\n :show-nested:\n', - "author": "FederatedAI", - "author_email": "contact@FedAI.org", - "maintainer": None, - "maintainer_email": None, - "url": "https://fate.fedai.org/", - "packages": packages, - "package_data": package_data, - "install_requires": install_requires, - "entry_points": entry_points, - "python_requires": ">=3.6,<4.0", -} - - -setup(**setup_kwargs)
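
Note on the pipeline-job runner hunk deleted near the top of this section: it loads each job script as a module and invokes its main() with the shared Config and an optionally mangled namespace, recording success or an exception id per job. The sketch below approximates that flow; _load_module_from_script is not shown in this diff, so the importlib-based loader and the simplified status strings here are assumptions rather than the original implementation.

# Approximate sketch of the deleted runner logic; the loader below is an
# assumption (the real _load_module_from_script is not part of this hunk).
import importlib.util
import uuid


def load_module_from_script(script_path):
    # Import a job script from an arbitrary path as a throwaway module.
    spec = importlib.util.spec_from_file_location("pipeline_job", script_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def run_one_pipeline_job(script_path, config, namespace, data_namespace_mangling):
    mod = load_module_from_script(script_path)
    try:
        if data_namespace_mangling:
            # suffix table namespaces so repeated runs do not collide
            mod.main(config=config, namespace=f"_{namespace}")
        else:
            mod.main(config=config)
        return "success"
    except Exception as err:
        # the original also distinguishes a "not submitted" state and logs an exception id
        print(f"exception({uuid.uuid1()}), error message:\n{err}")
        return "failed"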
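
The deleted fate_test/utils.py centres on evaluate_almost_equal(), which declares two runs equivalent when every common metric passes math.isclose under the configured abs_tol/rel_tol. A minimal, self-contained sketch of that rule follows; the metric names and values are purely illustrative.

# Minimal sketch of the tolerance rule applied by the removed
# evaluate_almost_equal(); metric names and values are illustrative only.
import math


def metrics_almost_equal(metrics, results, abs_tol=None, rel_tol=None):
    # `results` maps model name -> list of values ordered like `metrics`
    tolerances = {}
    if abs_tol is not None:
        tolerances["abs_tol"] = abs_tol
    if rel_tol is not None:
        tolerances["rel_tol"] = rel_tol
    summary = {}
    for i, metric in enumerate(metrics):
        values = [model_values[i] for model_values in results.values()]
        summary[metric] = all(math.isclose(v, values[0], **tolerances) for v in values)
    return summary, all(summary.values())


# Hypothetical comparison of FATE against a local baseline:
print(metrics_almost_equal(
    ["auc", "ks"],
    {"FATE": [0.893, 0.571], "local": [0.894, 0.570]},
    rel_tol=1e-2,
))  # -> ({'auc': True, 'ks': True}, True)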
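
The removed comparison_quality() and _save_quality() helpers keep a benchmark_history/benchmark_quality.json file keyed by storage tag (built as "FATE_{fate_version}_{storage_tag}_{group_name}" in match_metrics()). A rough sketch of that bookkeeping, assuming the same cache-directory layout; the tag and payload below are invented for illustration.

# Rough sketch of the history bookkeeping done by the removed _save_quality();
# the tag, payload, and cache directory below are invented for illustration.
import json
import os


def save_quality(storage_tag, cache_directory, **results):
    path = os.path.join(
        os.path.abspath(cache_directory), "benchmark_history", "benchmark_quality.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    history = {}
    if os.path.exists(path):
        with open(path, "r") as f:
            history = json.load(f)
    if storage_tag in history:
        print("This tag already exists in the history and will be updated.")
    history[storage_tag] = results
    with open(path, "w") as f:
        json.dump(history, f, indent=2)
    return path


print(save_quality("FATE_1.9.0_nightly_hetero_linr", "./cache", FATE={"r2_score": 0.903}))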
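
The packaging files removed above carry fate_test's README, which states that every benchmark job script exposes a ``Main`` entry point returning two dictionaries: data information and real-valued metrics (FATE scripts take config, param, and namespace; non-FATE scripts take only param). A hypothetical non-FATE skeleton under that contract is sketched below; the dataset names echo the README's example output, while the metric values, the param handling, and the lowercase main spelling (matching the runner deleted earlier in this diff) are assumptions.

# Hypothetical non-FATE benchmark script skeleton; metric values and the
# param handling are placeholders, not taken from any shipped example.
def main(param=None):
    # `param` is the dict loaded from the "conf" file named in the benchmark testsuite
    param = param or {"penalty": "L2", "max_iter": 100}

    data_summary = {
        "train": {"guest": "motor_hetero_guest", "host": "motor_hetero_host"},
        "test": {"guest": "motor_hetero_guest", "host": "motor_hetero_host"},
    }
    # only real-valued metrics with matching keys are compared across jobs
    metric_summary = {"r2_score": 0.90, "root_mean_squared_error": 0.31}
    return data_summary, metric_summary


if __name__ == "__main__":
    print(main())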