From 1cf1c7ec78750374fc4ab6a996d4a00708232950 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 27 Nov 2023 11:41:01 +0800 Subject: [PATCH 01/42] add sshe lr Signed-off-by: Yu Wu --- python/fate/components/components/sshe_lr.py | 253 +++++++++++++++++++ python/fate/ml/mpc/sshe_lr.py | 77 +++++- 2 files changed, 323 insertions(+), 7 deletions(-) create mode 100644 python/fate/components/components/sshe_lr.py diff --git a/python/fate/components/components/sshe_lr.py b/python/fate/components/components/sshe_lr.py new file mode 100644 index 0000000000..5d3f4f2009 --- /dev/null +++ b/python/fate/components/components/sshe_lr.py @@ -0,0 +1,253 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools +from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params +from fate.ml.mpc.sshe_lr import SSHELogisticRegression + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST, ARBITER], provider="fate") +def sshe_lr(ctx, role): + ... + + +@sshe_lr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + ), + encrypted_reveal: cpn.parameter(type=bool, default=True, + desc="whether reveal encrypted result every epoch, if False, only reveal at the end of training"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), +): + logger.info(f"enter sshe lr train") + # temp code start + init_param = init_param.dict() + + train_model( + ctx, + role, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + tol, + early_stop, + init_param, + encrypted_reveal, + threshold, + warm_start_model + ) + + +@sshe_lr.predict() +def predict( + ctx, + role: Role, + # threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5), + 
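The batch-size convention declared in the parameters above (None for full batch, otherwise at least 10 rows) can be summarized with a small helper; this is an illustrative sketch, not a function from the FATE codebase:

    def resolve_batch_size(n_rows: int, batch_size=None) -> int:
        # None means full batch, per the `batch_size` description above
        if batch_size is None:
            return n_rows
        if batch_size < 10:
            raise ValueError("batch size should be no less than 10")
        return min(batch_size, n_rows)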
test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), +): + predict_from_model(ctx, role, input_model, test_data, test_output_data) + + +@sshe_lr.cross_validation() +def cross_validation( + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + encrypted_reveal: cpn.parameter(type=bool, default=True, + desc="whether reveal encrypted result every epoch, if False, only reveal at the end of training"), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["auc"]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), +): + init_param = init_param.dict() + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + if role.is_guest: + module = SSHELogisticRegression( + epochs=epochs, + batch_size=batch_size, + learning_rate=learning_rate, + tol=tol, + early_stop=early_stop, + init_param=init_param, + threshold=threshold, + encrypted_reveal=encrypted_reveal, + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_df = module.predict(sub_ctx, train_data) + """train_predict_result = transform_to_predict_result( + train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="train" + )""" + train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_df = module.predict(sub_ctx, validate_data) + """validate_predict_result = transform_to_predict_result( + validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, + data_type="predict" + )""" + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + + # evaluation = evaluate(predicted) + elif role.is_host: + module = SSHELogisticRegression( + epochs=epochs, + batch_size=batch_size, + init_param=init_param, + encrypted_reveal=encrypted_reveal, + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = 
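For reference, the fold splitting that `KFold` performs above on the distributed frame can be sketched over plain index lists; `kfold_indices` is a hypothetical stand-in, not the fate.arch.dataframe.KFold implementation:

    import random

    def kfold_indices(n_samples, n_splits=5, shuffle=False, random_state=None):
        indices = list(range(n_samples))
        if shuffle:
            random.Random(random_state).shuffle(indices)
        fold_size, remainder = divmod(n_samples, n_splits)
        start = 0
        for fold in range(n_splits):
            # earlier folds absorb the remainder, one extra sample each
            stop = start + fold_size + (1 if fold < remainder else 0)
            yield indices[:start] + indices[stop:], indices[start:stop]
            start = stop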
fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 + + +def train_model( + ctx, + role, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + tol, + early_stop, + learning_rate, + init_param, + encrypted_reveal, + threshold, + input_model +): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = SSHELogisticRegression.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + + else: + module = SSHELogisticRegression( + epochs=epochs, + batch_size=batch_size, + tol=tol, + early_stop=early_stop, + learning_rate=learning_rate, + init_param=init_param, + threshold=threshold, + encrypted_reveal=encrypted_reveal + ) + # optimizer = optimizer_factory(optimizer_param) + logger.info(f"sshe lr guest start train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + + if validate_data is not None: + logger.info(f"validate data provided") + validate_data = validate_data.read() + + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={}) + + sub_ctx = ctx.sub_ctx("predict") + + predict_df = module.predict(sub_ctx, train_data) + + if role.is_guest: + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + predict_df = module.predict(sub_ctx, validate_data) + if ctx.is_guest: + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([predict_result, validate_predict_result]) + train_output_data.write(predict_result) + + +def predict_from_model(ctx, role, input_model, test_data, test_output_data): + logger.info(f"sshe lr guest start predict") + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + module = SSHELogisticRegression.from_model(model) + # if module.threshold != 0.5: + # module.threshold = threshold + test_data = test_data.read() + predict_df = module.predict(sub_ctx, test_data) + if role.is_guest: + predict_result = tools.add_dataset_type(predict_df, consts.TEST_SET) + test_output_data.write(predict_result) diff --git a/python/fate/ml/mpc/sshe_lr.py b/python/fate/ml/mpc/sshe_lr.py index 6aaeb84be5..c04e3ee600 100644 --- a/python/fate/ml/mpc/sshe_lr.py +++ b/python/fate/ml/mpc/sshe_lr.py @@ -1,26 +1,87 @@ import logging - +from typing import Union from fate.arch import Context -from ..abc.module import Module from fate.arch.dataframe import DataFrame from fate.arch.protocol.mpc.nn.sshe.lr_layer import ( SSHELogisticRegressionLayer, SSHELogisticRegressionLossLayer, SSHEOptimizerSGD, ) +from ..abc.module import Module, Model, HeteroModule logger = logging.getLogger(__name__) class SSHELogisticRegression(Module): - def __init__(self, lr=0.05): - self.lr = lr + def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_param, + encrypted_reveal=True, threshold=0.5): + self.learning_rate = learning_rate + self.epochs = epochs + self.batch_size = batch_size + self.tol = tol + self.early_stop = early_stop + self.learning_rate = learning_rate + self.init_param = init_param + self.threshold = threshold + self.encrypted_reveal = encrypted_reveal + + self.estimator = None + self.ovr = False + self.labels = None + + def fit(self, ctx: Context, train_data: DataFrame, validate_data=None): + train_data_binarized_label = train_data.label.get_dummies() + label_count = 
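The `get_dummies` call above drives the one-vs-rest decision: one binary column per class, with `label_count` taken from the column count. A minimal pandas sketch of the same bookkeeping (the real call runs on the distributed DataFrame):

    import pandas as pd

    y = pd.Series([0, 1, 2, 1, 0])
    binarized = pd.get_dummies(y, prefix="label")   # columns: label_0, label_1, label_2
    label_count = binarized.shape[1]                # 3 -> one-vs-rest training
    labels = sorted(int(c.split("_")[1]) for c in binarized.columns)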
train_data_binarized_label.shape[1] + ctx.hosts.put("label_count", label_count) + + def get_model(self): + all_estimator = {} + if self.ovr: + for label, estimator in self.estimator.items(): + all_estimator[label] = estimator.get_model() + else: + all_estimator = self.estimator.get_model() + return { + "data": {"estimator": all_estimator}, + "meta": { + "epochs": self.epochs, + "batch_size": self.batch_size, + "learning_rate": self.learning_rate, + "init_param": self.init_param, + "optimizer_param": self.optimizer_param, + "labels": self.labels, + "ovr": self.ovr, + "threshold": self.threshold, + "encrypted_reveal": self.encrypted_reveal, + }, + } + + def from_model(cls, model: Union[dict, Model]): + pass + + +class SSHELREstimator(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate=None, init_param=None, + encrypted_reveal=True): + self.epochs = epochs + self.batch_size = batch_size + self.optimizer = optimizer + self.lr = learning_rate + self.init_param = init_param + self.encrypted_reveal = encrypted_reveal + + self.w = None + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + self.header = None + + def fit_binary_model(self, ctx: Context, train_data: DataFrame) -> None: - def fit(self, ctx: Context, input_data: DataFrame) -> None: rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank - y = ctx.mpc.cond_call(lambda: input_data.label.as_tensor(), lambda: None, dst=rank_b) - h = input_data.as_tensor() + y = ctx.mpc.cond_call(lambda: train_data.label.as_tensor(), lambda: None, dst=rank_b) + h = train_data.as_tensor() # generator = torch.Generator().manual_seed(0) layer = SSHELogisticRegressionLayer( ctx, @@ -35,6 +96,8 @@ def fit(self, ctx: Context, input_data: DataFrame) -> None: optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr) for i in range(20): + # mpc encrypted [wx] + # to get decrypted wx: z.get_plain_text() z = layer(h) loss = loss_fn(z, y) if i % 3 == 0: From 8a355681ef1d7dbeea244e1f4bbc3ce3c0d93f9f Mon Sep 17 00:00:00 2001 From: sagewe Date: Tue, 28 Nov 2023 15:18:51 +0800 Subject: [PATCH 02/42] fix dependencies Signed-off-by: sagewe --- python/fate/arch/protocol/mpc/nn/__init__.py | 10 +++++----- python/fate/arch/tensor/distributed/_ops_binary.py | 2 +- python/fate/arch/utils/trace.py | 1 + python/fate/ml/mpc/sshe_lr.py | 11 ++++++++--- python/requirements-eggroll.txt | 4 ---- python/requirements-fate.txt | 8 ++++++-- 6 files changed, 21 insertions(+), 15 deletions(-) diff --git a/python/fate/arch/protocol/mpc/nn/__init__.py b/python/fate/arch/protocol/mpc/nn/__init__.py index fd3dee37c2..9a49b2ccd4 100644 --- a/python/fate/arch/protocol/mpc/nn/__init__.py +++ b/python/fate/arch/protocol/mpc/nn/__init__.py @@ -71,7 +71,7 @@ Unsqueeze, Where, ) -from .onnx_converter import from_onnx, from_pytorch, from_tensorflow, TF_AND_TF2ONNX +# from .onnx_converter import from_onnx, from_pytorch, from_tensorflow, TF_AND_TF2ONNX # expose contents of package __all__ = [ # noqa: F405 @@ -107,9 +107,9 @@ "Exp", "Expand", "Flatten", - "from_pytorch", - "from_onnx", - "from_tensorflow", + # "from_pytorch", + # "from_onnx", + # "from_tensorflow", "Gather", "Gemm", "GlobalAveragePool", @@ -142,7 +142,7 @@ "Squeeze", "Sub", "Sum", - "TF_AND_TF2ONNX", + # "TF_AND_TF2ONNX", "Transpose", "Unsqueeze", "Where", diff --git a/python/fate/arch/tensor/distributed/_ops_binary.py b/python/fate/arch/tensor/distributed/_ops_binary.py index bf74afdbbd..db65bcdc30 100644 --- a/python/fate/arch/tensor/distributed/_ops_binary.py +++ 
b/python/fate/arch/tensor/distributed/_ops_binary.py @@ -25,7 +25,7 @@ def mul(input, other): def _create_meta_tensor(x): if isinstance(x, (torch.Tensor, DTensor)): - return torch.rand(*x.shape, device=torch.device("meta"), dtype=x.dtype) + return torch.zeros(*x.shape, device=torch.device("meta"), dtype=x.dtype) else: return torch.tensor(x, device=torch.device("meta")) diff --git a/python/fate/arch/utils/trace.py b/python/fate/arch/utils/trace.py index a3874affa1..6202438417 100644 --- a/python/fate/arch/utils/trace.py +++ b/python/fate/arch/utils/trace.py @@ -17,6 +17,7 @@ def _is_tracing_enabled(): def setup_tracing(service_name, endpoint: str = None): if not _is_tracing_enabled(): return + from opentelemetry.sdk.resources import SERVICE_NAME, Resource from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace import TracerProvider diff --git a/python/fate/ml/mpc/sshe_lr.py b/python/fate/ml/mpc/sshe_lr.py index 6aaeb84be5..3950a81b20 100644 --- a/python/fate/ml/mpc/sshe_lr.py +++ b/python/fate/ml/mpc/sshe_lr.py @@ -1,5 +1,6 @@ import logging +import torch from fate.arch import Context from ..abc.module import Module @@ -21,7 +22,7 @@ def fit(self, ctx: Context, input_data: DataFrame) -> None: rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank y = ctx.mpc.cond_call(lambda: input_data.label.as_tensor(), lambda: None, dst=rank_b) h = input_data.as_tensor() - # generator = torch.Generator().manual_seed(0) + layer = SSHELogisticRegressionLayer( ctx, in_features_a=ctx.mpc.option_call(lambda: h.shape[1], dst=rank_a), @@ -29,15 +30,19 @@ def fit(self, ctx: Context, input_data: DataFrame) -> None: out_features=1, rank_a=rank_a, rank_b=rank_b, - # generator=generator, + wa_init_fn=lambda shape: torch.rand(shape), + wb_init_fn=lambda shape: torch.rand(shape), ) + loss_fn = SSHELogisticRegressionLossLayer(ctx, rank_a=rank_a, rank_b=rank_b) optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr) - for i in range(20): + for i in range(1): z = layer(h) loss = loss_fn(z, y) if i % 3 == 0: logger.info(f"loss: {loss.get()}") loss.backward() optimizer.step() + wa = layer.wa.get_plain_text() + wb = layer.wb.get_plain_text() diff --git a/python/requirements-eggroll.txt b/python/requirements-eggroll.txt index e2b71ed0ae..2565931fbc 100644 --- a/python/requirements-eggroll.txt +++ b/python/requirements-eggroll.txt @@ -1,9 +1,5 @@ grpcio==1.46.3 grpcio-tools==1.46.3 -numba==0.56.4 protobuf==3.19.6 -pyarrow==6.0.1 -mmh3==3.0.0 -cachetools>=3.0.0 cloudpickle==2.1.0 psutil>=5.7.0 diff --git a/python/requirements-fate.txt b/python/requirements-fate.txt index 3a26f5e6aa..fc49498ac7 100644 --- a/python/requirements-fate.txt +++ b/python/requirements-fate.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cpu lmdb==1.3.0 -torch==1.13.1+cpu +torch==1.13.1 fate_utils pydantic==1.10.12 cloudpickle==2.1.0 @@ -17,4 +17,8 @@ grpcio==1.46.3 protobuf==3.19.6 scikit-learn omegaconf -``` +rich +opentelemetry-api +opentelemetry-sdk +opentelemetry-exporter-otlp-proto-grpc +mmh3==3.0.0 From ff99e19053879dbf890b36043663d82880fc4e1d Mon Sep 17 00:00:00 2001 From: sagewe Date: Fri, 1 Dec 2023 11:35:19 +0800 Subject: [PATCH 03/42] enhance profile Signed-off-by: sagewe --- python/fate/arch/computing/eggroll/_table.py | 2 +- .../fate/arch/computing/standalone/_table.py | 2 +- python/fate/arch/computing/table.py | 80 ++++++++++++------- python/fate/arch/context/_context.py | 2 + .../arch/launchers/multiprocess_launcher.py | 5 +- 5 files changed, 59 
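Patch 02 above swaps `torch.rand` for `torch.zeros` in `_create_meta_tensor`. Meta tensors carry only shape/dtype metadata, so result shapes can be inferred without allocating real data, and `torch.zeros` additionally avoids touching the RNG; a small self-contained illustration:

    import torch

    a = torch.zeros(4, 3, device=torch.device("meta"))
    b = torch.zeros(3, device=torch.device("meta"))
    out = a * b                       # broadcasting is resolved on metadata only
    assert out.shape == (4, 3)
    assert out.device.type == "meta"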
insertions(+), 32 deletions(-) diff --git a/python/fate/arch/computing/eggroll/_table.py b/python/fate/arch/computing/eggroll/_table.py index 728953ffa9..690bdb7521 100644 --- a/python/fate/arch/computing/eggroll/_table.py +++ b/python/fate/arch/computing/eggroll/_table.py @@ -42,7 +42,7 @@ def __init__(self, rp: RollPair): def engine(self): return self._engine - def _map_reduce_partitions_with_index( + def _impl_map_reduce_partitions_with_index( self, map_partition_op: Callable[[int, Iterable], Iterable], reduce_partition_op: Callable[[Any, Any], Any], diff --git a/python/fate/arch/computing/standalone/_table.py b/python/fate/arch/computing/standalone/_table.py index 9f12c32ad2..437cabcb2b 100644 --- a/python/fate/arch/computing/standalone/_table.py +++ b/python/fate/arch/computing/standalone/_table.py @@ -52,7 +52,7 @@ def _drop_num(self, num: int, partitioner): self._table.delete(k, partitioner=partitioner) return Table(table=self._table) - def _map_reduce_partitions_with_index( + def _impl_map_reduce_partitions_with_index( self, map_partition_op: Callable[[int, Iterable[Tuple[K, V]]], Iterable], reduce_partition_op: Callable[[Any, Any], Any], diff --git a/python/fate/arch/computing/table.py b/python/fate/arch/computing/table.py index fb9ac7aca9..ad8b124c19 100644 --- a/python/fate/arch/computing/table.py +++ b/python/fate/arch/computing/table.py @@ -9,6 +9,7 @@ from fate.arch.utils.trace import auto_trace from ..unify import URI import functools +from ._profile import computing_profile as _compute_info logger = logging.getLogger(__name__) @@ -26,26 +27,27 @@ def _add_padding(message, count): return "\n".join(padded_lines) -def _compute_info(func): - return func - - # @functools.wraps(func) - # def wrapper(*args, **kwargs): - # global _level - # logger.debug(_add_padding(f"computing enter {func.__name__}", _level * 2)) - # try: - # _level += 1 - # stacks = _add_padding("".join(traceback.format_stack(limit=5)[:-1]), _level * 2) - # logger.debug(f'{_add_padding("stack:", _level * 2)}\n{stacks}') - # return func(*args, **kwargs) - # finally: - # _level -= 1 - # logger.debug(f"{' ' * _level}computing exit {func.__name__}") - # - # return wrapper +# def _compute_info(func): +# return func +# +# # @functools.wraps(func) +# # def wrapper(*args, **kwargs): +# # global _level +# # logger.debug(_add_padding(f"computing enter {func.__name__}", _level * 2)) +# # try: +# # _level += 1 +# # stacks = _add_padding("".join(traceback.format_stack(limit=5)[:-1]), _level * 2) +# # logger.debug(f'{_add_padding("stack:", _level * 2)}\n{stacks}') +# # return func(*args, **kwargs) +# # finally: +# # _level -= 1 +# # logger.debug(f"{' ' * _level}computing exit {func.__name__}") +# # +# # return wrapper class KVTableContext: + @_compute_info def parallelize( self, data, include_key=True, partition=None, key_serdes_type=0, value_serdes_type=0, partitioner_type=0 ) -> "KVTable": @@ -130,7 +132,7 @@ def _drop_num(self, num: int, partitioner): raise NotImplementedError(f"{self.__class__.__name__}._drop_num") @abc.abstractmethod - def _map_reduce_partitions_with_index( + def _impl_map_reduce_partitions_with_index( self, map_partition_op: Callable[[int, Iterable[Tuple[K, V]]], Iterable], reduce_partition_op: Optional[Callable[[Any, Any], Any]], @@ -224,6 +226,26 @@ def map_reduce_partitions_with_index( output_value_serdes_type=None, output_partitioner_type=None, output_num_partitions=None, + ): + return self._map_reduce_partitions_with_index( + map_partition_op=map_partition_op, + 
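The table.py change above retires the commented-out no-op `_compute_info` in favor of `computing_profile` imported from `._profile`. That module is not shown in this series, so the following is only an assumed shape for such a timing decorator, not the actual implementation:

    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)

    def computing_profile(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - start
                logger.debug(f"computing {func.__name__} took {elapsed:.6f}s")
        return wrapper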
reduce_partition_op=reduce_partition_op, + shuffle=shuffle, + output_key_serdes_type=output_key_serdes_type, + output_value_serdes_type=output_value_serdes_type, + output_partitioner_type=output_partitioner_type, + output_num_partitions=output_num_partitions, + ) + + def _map_reduce_partitions_with_index( + self, + map_partition_op: Callable[[int, Iterable], Iterable], + reduce_partition_op: Callable[[Any, Any], Any] = None, + shuffle=True, + output_key_serdes_type=None, + output_value_serdes_type=None, + output_partitioner_type=None, + output_num_partitions=None, ): if not shuffle and reduce_partition_op is not None: raise ValueError("when shuffle is False, it is not allowed to specify reduce_partition_op") @@ -241,7 +263,7 @@ def map_reduce_partitions_with_index( output_key_serdes = get_serdes_by_type(output_key_serdes_type) output_value_serdes = get_serdes_by_type(output_value_serdes_type) output_partitioner = get_partitioner_by_type(output_partitioner_type) - return self._map_reduce_partitions_with_index( + return self._impl_map_reduce_partitions_with_index( map_partition_op=_lifted_mpwi_map_to_serdes( map_partition_op, self.key_serdes, self.value_serdes, output_key_serdes, output_value_serdes ), @@ -271,7 +293,7 @@ def mapPartitionsWithIndex( output_value_serdes_type=None, output_partitioner_type=None, ): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( map_partition_op=map_partition_op, shuffle=True, output_key_serdes_type=output_key_serdes_type, @@ -290,7 +312,7 @@ def mapReducePartitions( output_value_serdes_type=None, output_partitioner_type=None, ): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( map_partition_op=_lifted_map_reduce_partitions_to_mpwi(map_partition_op), reduce_partition_op=reduce_partition_op, shuffle=shuffle, @@ -302,7 +324,7 @@ def mapReducePartitions( @auto_trace @_compute_info def applyPartitions(self, func, output_value_serdes_type=None): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( map_partition_op=_lifted_apply_partitions_to_mpwi(func), shuffle=False, output_key_serdes_type=self.key_serdes_type, @@ -316,7 +338,7 @@ def mapPartitions( ): if use_previous_behavior: raise NotImplementedError("use_previous_behavior is not supported") - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( map_partition_op=_lifted_map_partitions_to_mpwi(func), shuffle=not preserves_partitioning, output_key_serdes_type=self.key_serdes_type, @@ -332,7 +354,7 @@ def map( output_value_serdes_type=None, output_partitioner_type=None, ): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( _lifted_map_to_mpwi(map_op), shuffle=True, output_key_serdes_type=output_key_serdes_type, @@ -343,7 +365,7 @@ def map( @auto_trace @_compute_info def mapValues(self, map_value_op: Callable[[Any], Any], output_value_serdes_type=None): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( _lifted_map_values_to_mpwi(map_value_op), shuffle=False, output_key_serdes_type=self.key_serdes_type, @@ -363,7 +385,7 @@ def flatMap( output_key_serdes_type=None, output_value_serdes_type=None, ): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( _lifted_flat_map_to_mpwi(flat_map_op), shuffle=True, output_key_serdes_type=output_key_serdes_type, @@ -373,7 +395,7 @@ def flatMap( @auto_trace 
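The `filter` method below is implemented by lifting a per-value predicate into a partition-wise mapper over (key, value) pairs; an in-memory sketch of that lifting, with a dict standing in for the partitioned table:

    def lift_filter(filter_op):
        def mapper(partition_index, kv_iter):
            return ((k, v) for k, v in kv_iter if filter_op(v))
        return mapper

    partitions = {0: [("a", 1), ("b", 2)], 1: [("c", 3)]}
    mapper = lift_filter(lambda v: v % 2 == 1)
    kept = {i: list(mapper(i, kv)) for i, kv in partitions.items()}
    # {0: [('a', 1)], 1: [('c', 3)]}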
@_compute_info def filter(self, filter_op: Callable[[Any], bool]): - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( lambda i, x: ((k, v) for k, v in x if filter_op(v)), shuffle=False, output_key_serdes_type=self.key_serdes_type, @@ -381,7 +403,7 @@ def filter(self, filter_op: Callable[[Any], bool]): ) def _sample(self, fraction, seed=None) -> "KVTable": - return self.map_reduce_partitions_with_index( + return self._map_reduce_partitions_with_index( _lifted_sample_to_mpwi(fraction, seed), shuffle=False, output_key_serdes_type=self.key_serdes_type, @@ -522,7 +544,7 @@ def repartition(self, num_partitions, partitioner_type=None, key_serdes_type=Non output_key_serdes, self.value_serdes, ) - return self._map_reduce_partitions_with_index( + return self._impl_map_reduce_partitions_with_index( map_partition_op=mapper, reduce_partition_op=None, shuffle=True, diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 39bc75b324..5e5c16b191 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -22,6 +22,7 @@ from ._metrics import InMemoryMetricsHandler, MetricsWrap from ._namespace import NS, default_ns from ..unify import device +from fate.arch.utils.trace import auto_trace logger = logging.getLogger(__name__) @@ -263,6 +264,7 @@ def _get_computing(self): raise RuntimeError(f"computing not set") return self._computing + @auto_trace def destroy(self): if not self._is_destroyed: try: diff --git a/python/fate/arch/launchers/multiprocess_launcher.py b/python/fate/arch/launchers/multiprocess_launcher.py index 73e16d45cc..abfaf8db18 100644 --- a/python/fate/arch/launchers/multiprocess_launcher.py +++ b/python/fate/arch/launchers/multiprocess_launcher.py @@ -131,6 +131,7 @@ def _run_process( from fate.arch.utils.logger import set_up_logging from fate.arch.launchers.context_helper import init_context from fate.arch.utils.trace import setup_tracing + from fate.arch.computing._profile import profile_start, profile_ends if args.rank >= len(args.parties): raise ValueError(f"rank {args.rank} is out of range {len(args.parties)}") @@ -150,7 +151,9 @@ def _run_process( ctx = init_context() try: + profile_start() f(ctx) + profile_ends() output_or_exception_q.put((args.rank, None, None)) safe_to_exit.wait() @@ -185,7 +188,7 @@ def wait(self) -> int: def terminate(self): self.safe_to_exit.set() - time.sleep(1) # wait for 1 second to let all processes has a chance to exit + time.sleep(5) # wait for 1 second to let all processes has a chance to exit for process in self.processes: if process.is_alive(): process.terminate() From 1249dfae7ecab8a08b8872a2f773a777daf27830 Mon Sep 17 00:00:00 2001 From: sagewe Date: Fri, 1 Dec 2023 11:46:47 +0800 Subject: [PATCH 04/42] set default partitioner to mmh3 Signed-off-by: sagewe --- python/fate/arch/unify/partitioner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/fate/arch/unify/partitioner.py b/python/fate/arch/unify/partitioner.py index e871cff5f6..87c8273622 100644 --- a/python/fate/arch/unify/partitioner.py +++ b/python/fate/arch/unify/partitioner.py @@ -30,7 +30,8 @@ def _java_string_like_partitioner(key, total_partitions): def get_default_partitioner(): - return _java_string_like_partitioner + return mmh3_partitioner + # return _java_string_like_partitioner def get_partitioner_by_type(partitioner_type: int): From b5f2916d2668ca497d2605dd6e917e3026744af7 Mon Sep 17 00:00:00 2001 From: zhihuiwan <15779896112@163.com> Date: 
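Patch 04 above makes `mmh3_partitioner` the default partitioner. Its implementation is imported rather than shown in this series, so the following is only an assumed sketch of an mmh3-style partitioner (the seed and signedness are guesses):

    import mmh3

    def mmh3_partitioner(key: bytes, total_partitions: int) -> int:
        # unsigned murmur3 hash bucketed into [0, total_partitions)
        return mmh3.hash(key, signed=False) % total_partitions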
Mon, 4 Dec 2023 12:58:24 +0800 Subject: [PATCH 05/42] update dataframe transformer Signed-off-by: zhihuiwan <15779896112@163.com> --- .../fate/components/components/dataframe_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/fate/components/components/dataframe_transformer.py b/python/fate/components/components/dataframe_transformer.py index 8bafd30fb3..5d57c713db 100644 --- a/python/fate/components/components/dataframe_transformer.py +++ b/python/fate/components/components/dataframe_transformer.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from fate.components.core import LOCAL, Role, cpn +from fate.components.core import LOCAL, Role, cpn, GUEST, HOST -@cpn.component(roles=[LOCAL]) +@cpn.component(roles=[LOCAL, GUEST, HOST]) def dataframe_transformer( ctx, role: Role, - table: cpn.table_input(roles=[LOCAL]), - dataframe_output: cpn.dataframe_output(roles=[LOCAL]), + table: cpn.table_input(roles=[LOCAL, GUEST, HOST]), + dataframe_output: cpn.dataframe_output(roles=[LOCAL, GUEST, HOST]), namespace: cpn.parameter(type=str, default=None, optional=True), name: cpn.parameter(type=str, default=None, optional=True), site_name: cpn.parameter(type=str, default=None, optional=True), From b7d4a58b299286ea89e3b56876fd85b174377cab Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 4 Dec 2023 19:35:36 +0800 Subject: [PATCH 06/42] edit mpc Signed-off-by: Yu Wu --- python/fate/arch/protocol/mpc/mpc.py | 8 ++++---- python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py | 11 ++++++----- .../fate/arch/protocol/mpc/primitives/arithmetic.py | 13 +++++++------ python/fate/arch/tensor/distributed/_ops_binary.py | 2 +- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/fate/arch/protocol/mpc/mpc.py b/python/fate/arch/protocol/mpc/mpc.py index de3395a87a..997848b16e 100644 --- a/python/fate/arch/protocol/mpc/mpc.py +++ b/python/fate/arch/protocol/mpc/mpc.py @@ -171,13 +171,13 @@ def cpu(self): self.share = self.share.cpu() return self - def get_plain_text(self, dst=None): + def get_plain_text(self, dst=None, group=None): """Decrypts the tensor.""" - return self._tensor.get_plain_text(dst=dst) + return self._tensor.get_plain_text(dst=dst, group=group) - def reveal(self, dst=None): + def reveal(self, dst=None, group=None): """Decrypts the tensor without any downscaling.""" - return self._tensor.reveal(dst=dst) + return self._tensor.reveal(dst=dst, group=group) def __repr__(self): """Returns a representation of the tensor useful for debugging.""" diff --git a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py index 2ff939285a..ca8f56436b 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py @@ -25,7 +25,8 @@ def __init__( self.ctx = ctx self.rank_a = rank_a self.rank_b = rank_b - self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_aggregator_layer") + self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], + f"{ctx.namespace.federation_tag}.sshe_aggregator_layer") if sync_shape: ctx.mpc.option_assert(in_features_a is not None, "in_features_a must be specified", dst=rank_a) @@ -36,8 +37,8 @@ def __init__( ctx.mpc.option_assert( in_features_a is None, "in_features_a must be None when sync_shape is True", dst=rank_b ) - in_features_a = ctx.mpc.communicator.broadcast_obj(obj=in_features_a, src=rank_a) - in_features_b = 
ctx.mpc.communicator.broadcast_obj(obj=in_features_b, src=rank_b) + in_features_a = ctx.mpc.communicator.broadcast_obj(obj=in_features_a, src=rank_a, group=self.group) + in_features_b = ctx.mpc.communicator.broadcast_obj(obj=in_features_b, src=rank_b, group=self.group) else: ctx.mpc.option_assert( in_features_a is not None, "in_features_a must be specified when sync_shape is False", dst=rank_a @@ -148,7 +149,7 @@ def __call__(self, dz): class SSHELogisticRegressionLossLayer: def __init__(self, ctx: Context, rank_a, rank_b): self.ctx = ctx - self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_loss_layer") + self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], f"{ctx.namespace.federation_tag}.sshe_loss_layer") self.rank_a = rank_a self.rank_b = rank_b self.phe_cipher = ctx.cipher.phe.setup() @@ -197,7 +198,7 @@ def get(self): cipher_a=self.ctx.mpc.option(self.phe_cipher, self.rank_a), ) .mean() - .get_plain_text() + .get_plain_text(group=self.group) ) return 2 * dz_mean_square - 0.5 + torch.log(torch.tensor(2.0)) diff --git a/python/fate/arch/protocol/mpc/primitives/arithmetic.py b/python/fate/arch/protocol/mpc/primitives/arithmetic.py index 9c6823feed..b9f1df95c9 100644 --- a/python/fate/arch/protocol/mpc/primitives/arithmetic.py +++ b/python/fate/arch/protocol/mpc/primitives/arithmetic.py @@ -5,6 +5,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging + # dependencies: import torch @@ -19,7 +21,6 @@ from fate.arch.protocol.mpc.encoder import FixedPointEncoder from fate.arch.protocol.mpc.functions import regular from . import beaver, replicated # noqa: F401 -import logging logger = logging.getLogger(__name__) @@ -277,20 +278,20 @@ def reveal_batch(tensor_or_list, dst=None): else: return comm.get().reduce(shares, dst, batched=True) - def reveal(self, dst=None): + def reveal(self, dst=None, group=None): """Decrypts the tensor without any downscaling.""" tensor = self.share.clone() if dst is None: - return comm.get().all_reduce(tensor) + return comm.get().all_reduce(tensor, group=group) else: - return comm.get().reduce(tensor, dst) + return comm.get().reduce(tensor, dst, group=group) - def get_plain_text(self, dst=None): + def get_plain_text(self, dst=None, group=None): """Decrypts the tensor.""" # Edge case where share becomes 0 sized (e.g. 
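`reveal` above sums the parties' shares (an all-reduce, or a reduce toward `dst`), and `get_plain_text` then decodes the fixed-point result. Ignoring the ring modular arithmetic the real ArithmeticSharedTensor uses, a two-party sketch of that round trip:

    import torch

    scale = 2 ** 16                                           # fixed-point precision
    secret = torch.tensor([1.5, -2.25])
    encoded = (secret * scale).long()
    share_a = torch.randint(-2 ** 40, 2 ** 40, secret.shape)  # random mask
    share_b = encoded - share_a                               # the other party's share
    revealed = share_a + share_b                              # what the all-reduce computes
    assert torch.allclose(revealed.float() / scale, secret)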
result of split) if self.nelement() < 1: return torch.empty(self.share.size()) - return self.encoder.decode(self.reveal(dst=dst)) + return self.encoder.decode(self.reveal(dst=dst, group=group)) def encode_(self, new_encoder): """Rescales the input to a new encoding in-place""" diff --git a/python/fate/arch/tensor/distributed/_ops_binary.py b/python/fate/arch/tensor/distributed/_ops_binary.py index bf74afdbbd..db65bcdc30 100644 --- a/python/fate/arch/tensor/distributed/_ops_binary.py +++ b/python/fate/arch/tensor/distributed/_ops_binary.py @@ -25,7 +25,7 @@ def mul(input, other): def _create_meta_tensor(x): if isinstance(x, (torch.Tensor, DTensor)): - return torch.rand(*x.shape, device=torch.device("meta"), dtype=x.dtype) + return torch.zeros(*x.shape, device=torch.device("meta"), dtype=x.dtype) else: return torch.tensor(x, device=torch.device("meta")) From 1ca773e41c6ef4e0ff8ab859fb24ea4c932d7d07 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 4 Dec 2023 19:37:22 +0800 Subject: [PATCH 07/42] add sshe lr Signed-off-by: Yu Wu --- python/fate/components/components/sshe_lr.py | 87 +++--- python/fate/ml/mpc/sshe_lr.py | 286 +++++++++++++++++-- 2 files changed, 299 insertions(+), 74 deletions(-) diff --git a/python/fate/components/components/sshe_lr.py b/python/fate/components/components/sshe_lr.py index 5d3f4f2009..0cfb4bc5dd 100644 --- a/python/fate/components/components/sshe_lr.py +++ b/python/fate/components/components/sshe_lr.py @@ -44,10 +44,12 @@ def train( early_stop: cpn.parameter( type=params.string_choice(["weight_diff", "diff", "abs"]), default="diff", - desc="early stopping criterion, choose from {weight_diff, diff, abs, val_metrics}", + desc="early stopping criterion, choose from {weight_diff, diff, abs}", ), - encrypted_reveal: cpn.parameter(type=bool, default=True, - desc="whether reveal encrypted result every epoch, if False, only reveal at the end of training"), + learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"), + reveal_every_epoch: cpn.parameter(type=bool, default=False, + desc="whether reveal encrypted result every epoch, " + "if False, only reveal at the end of training"), init_param: cpn.parameter( type=params.init_param(), default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), @@ -56,10 +58,12 @@ def train( threshold: cpn.parameter( type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" ), + reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1, + desc="rounds to reveal training loss, " + "only effective if `early_stop` is 'loss'"), train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), - warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True), -): + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True)): logger.info(f"enter sshe lr train") # temp code start init_param = init_param.dict() @@ -73,10 +77,12 @@ def train( output_model, epochs, batch_size, + learning_rate, tol, early_stop, init_param, - encrypted_reveal, + reveal_every_epoch, + reveal_loss_freq, threshold, warm_start_model ) @@ -104,6 +110,12 @@ def cross_validation( type=params.conint(ge=10), default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), 
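The `early_stop` choices above (`weight_diff`, `diff`, `abs`) come from `converge_func_factory` in fate.ml.utils._convergence, which is not shown in this series; a hedged sketch of what each criterion checks:

    import torch

    def is_converged(criterion, tol, prev, curr):
        if criterion == "abs":          # absolute loss already below tolerance
            return abs(float(curr)) < tol
        if criterion == "diff":         # loss barely moved between epochs
            return abs(float(curr) - float(prev)) < tol
        if criterion == "weight_diff":  # weight vector barely moved
            return float(torch.norm(curr - prev)) < tol
        raise ValueError(criterion)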
+ default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs}", + ), learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"), init_param: cpn.parameter( type=params.init_param(), @@ -113,8 +125,12 @@ def cross_validation( threshold: cpn.parameter( type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" ), - encrypted_reveal: cpn.parameter(type=bool, default=True, - desc="whether reveal encrypted result every epoch, if False, only reveal at the end of training"), + reveal_every_epoch: cpn.parameter(type=bool, default=False, + desc="whether reveal encrypted result every epoch, " + "if False, only reveal at the end of training"), + reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1, + desc="rounds to reveal training loss, " + "only effective if `early_stop` is 'loss'"), cv_param: cpn.parameter(type=params.cv_param(), default=params.CVParam(n_splits=5, shuffle=False, random_state=None), desc="cross validation param"), @@ -129,46 +145,29 @@ def cross_validation( i = 0 for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): logger.info(f"enter fold {i}") - if role.is_guest: - module = SSHELogisticRegression( - epochs=epochs, - batch_size=batch_size, - learning_rate=learning_rate, - tol=tol, - early_stop=early_stop, - init_param=init_param, - threshold=threshold, - encrypted_reveal=encrypted_reveal, - ) - module.fit(fold_ctx, train_data, validate_data) - if output_cv_data: + module = SSHELogisticRegression( + epochs=epochs, + batch_size=batch_size, + learning_rate=learning_rate, + tol=tol, + early_stop=early_stop, + init_param=init_param, + threshold=threshold, + reveal_every_epoch=reveal_every_epoch, + reveal_loss_freq=reveal_loss_freq + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + if role.is_guest: sub_ctx = fold_ctx.sub_ctx("predict_train") predict_df = module.predict(sub_ctx, train_data) - """train_predict_result = transform_to_predict_result( - train_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, - data_type="train" - )""" train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) sub_ctx = fold_ctx.sub_ctx("predict_validate") predict_df = module.predict(sub_ctx, validate_data) - """validate_predict_result = transform_to_predict_result( - validate_data, predict_score, module.labels, threshold=module.threshold, is_ovr=module.ovr, - data_type="predict" - )""" validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) next(cv_output_datas).write(df=predict_result) - - # evaluation = evaluate(predicted) - elif role.is_host: - module = SSHELogisticRegression( - epochs=epochs, - batch_size=batch_size, - init_param=init_param, - encrypted_reveal=encrypted_reveal, - ) - module.fit(fold_ctx, train_data, validate_data) - if output_cv_data: + elif role.is_host: sub_ctx = fold_ctx.sub_ctx("predict_train") module.predict(sub_ctx, train_data) sub_ctx = fold_ctx.sub_ctx("predict_validate") @@ -185,11 +184,12 @@ def train_model( output_model, epochs, batch_size, + learning_rate, tol, early_stop, - learning_rate, init_param, - encrypted_reveal, + reveal_every_epoch, + reveal_loss_freq, threshold, input_model ): @@ -209,7 +209,8 @@ def train_model( learning_rate=learning_rate, init_param=init_param, threshold=threshold, - encrypted_reveal=encrypted_reveal + 
reveal_every_epoch=reveal_every_epoch, + reveal_loss_freq=reveal_loss_freq ) # optimizer = optimizer_factory(optimizer_param) logger.info(f"sshe lr guest start train") diff --git a/python/fate/ml/mpc/sshe_lr.py b/python/fate/ml/mpc/sshe_lr.py index c04e3ee600..47bdd16591 100644 --- a/python/fate/ml/mpc/sshe_lr.py +++ b/python/fate/ml/mpc/sshe_lr.py @@ -1,21 +1,26 @@ import logging -from typing import Union -from fate.arch import Context +import torch + +from fate.arch import Context, dataframe from fate.arch.dataframe import DataFrame from fate.arch.protocol.mpc.nn.sshe.lr_layer import ( SSHELogisticRegressionLayer, SSHELogisticRegressionLossLayer, SSHEOptimizerSGD, ) -from ..abc.module import Module, Model, HeteroModule +from fate.ml.utils import predict_tools +from fate.ml.utils._convergence import converge_func_factory +from fate.ml.utils._model_param import get_initialize_func +from fate.ml.utils._model_param import serialize_param, deserialize_param +from ..abc.module import Module, HeteroModule logger = logging.getLogger(__name__) class SSHELogisticRegression(Module): def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_param, - encrypted_reveal=True, threshold=0.5): + reveal_every_epoch=False, reveal_loss_freq=1, threshold=0.5): self.learning_rate = learning_rate self.epochs = epochs self.batch_size = batch_size @@ -24,16 +29,83 @@ def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_para self.learning_rate = learning_rate self.init_param = init_param self.threshold = threshold - self.encrypted_reveal = encrypted_reveal + self.reveal_every_epoch = reveal_every_epoch + self.reveal_loss_freq = reveal_loss_freq self.estimator = None self.ovr = False self.labels = None def fit(self, ctx: Context, train_data: DataFrame, validate_data=None): - train_data_binarized_label = train_data.label.get_dummies() - label_count = train_data_binarized_label.shape[1] - ctx.hosts.put("label_count", label_count) + if ctx.is_on_guest: + train_data_binarized_label = train_data.label.get_dummies() + label_count = train_data_binarized_label.shape[1] + ctx.hosts.put("label_count", label_count) + labels = [int(label_name.split("_")[1]) for label_name in train_data_binarized_label.columns] + if self.labels is None: + self.labels = sorted(labels) + else: + label_count = ctx.guest.get("label_count") + if label_count > 2 or self.ovr: + logger.info(f"OVR data provided, will train OVR models.") + self.ovr = True + warm_start = True + if self.estimator is None: + self.estimator = {} + warm_start = False + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(label_count): + logger.info(f"start train for {i}th class") + if not warm_start: + single_estimator = SSHELREstimator( + epochs=self.epochs, + batch_size=self.batch_size, + learning_rate=self.learning_rate, + init_param=self.init_param, + reveal_every_epoch=self.reveal_every_epoch, + reveal_loss_freq=self.reveal_loss_freq, + early_stop=self.early_stop, + tol=self.tol + ) + else: + # warm start + logger.info("estimator is not none, will train with warm start") + # single_estimator = self.estimator[self.labels.index(labels[i])] + single_estimator = self.estimator[i] + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + class_train_data = train_data.copy() + class_validate_data = validate_data + if validate_data: + class_validate_data = validate_data.copy() + if ctx.is_on_guest: + class_train_data.label = train_data_binarized_label[train_data_binarized_label.columns[i]] + 
single_estimator.fit_single_model(class_ctx, class_train_data, class_validate_data) + + self.estimator[i] = single_estimator + + else: + if self.estimator is None: + single_estimator = SSHELREstimator( + epochs=self.epochs, + batch_size=self.batch_size, + learning_rate=self.learning_rate, + init_param=self.init_param, + reveal_every_epoch=self.reveal_every_epoch, + reveal_loss_freq=self.reveal_loss_freq, + early_stop=self.early_stop, + tol=self.tol + ) + else: + logger.info("estimator is not none, will train with warm start") + single_estimator = self.estimator + single_estimator.epochs = self.epochs + single_estimator.batch_size = self.batch_size + train_data_fit = train_data.copy() + validate_data_fit = validate_data + if validate_data: + validate_data_fit = validate_data.copy() + single_estimator.fit_single_model(ctx, train_data_fit, validate_data_fit) + self.estimator = single_estimator def get_model(self): all_estimator = {} @@ -49,58 +121,210 @@ def get_model(self): "batch_size": self.batch_size, "learning_rate": self.learning_rate, "init_param": self.init_param, - "optimizer_param": self.optimizer_param, + # "optimizer_param": self.optimizer_param, "labels": self.labels, "ovr": self.ovr, "threshold": self.threshold, - "encrypted_reveal": self.encrypted_reveal, + "reveal_every_epoch": self.reveal_every_epoch, + "reveal_loss_freq": self.reveal_loss_freq, + "tol": self.tol }, } - def from_model(cls, model: Union[dict, Model]): - pass + @classmethod + def from_model(cls, model): + lr = SSHELogisticRegression( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + learning_rate=model["meta"]["learning_rate"], + threshold=model["meta"]["threshold"], + init_param=model["meta"]["init_param"], + reveal_every_epoch=model["meta"]["reveal_every_epoch"], + reveal_loss_freq=model["meta"]["reveal_loss_freq"], + tol=model["meta"]["tol"], + early_stop=model["meta"]["early_stop"] + ) + lr.ovr = model["meta"]["ovr"] + lr.labels = model["meta"]["labels"] + + all_estimator = model["data"]["estimator"] + lr.estimator = {} + if lr.ovr: + for label, d in all_estimator.items(): + estimator = SSHELREstimator( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + reveal_every_epoch=model["meta"]["reveal_every_epoch"], + reveal_loss_freq=model["meta"]["reveal_loss_freq"], + tol=model["meta"]["tol"], + early_stop=model["meta"]["early_stop"] + ) + estimator.restore(d) + lr.estimator[int(label)] = estimator + else: + estimator = SSHELREstimator( + epochs=model["meta"]["epochs"], + batch_size=model["meta"]["batch_size"], + init_param=model["meta"]["init_param"], + reveal_every_epoch=model["meta"]["reveal_every_epoch"], + reveal_loss_freq=model["meta"]["reveal_loss_freq"], + tol=model["meta"]["tol"], + early_stop=model["meta"]["early_stop"] + ) + estimator.restore(all_estimator) + lr.estimator = estimator + + return lr + + def predict(self, ctx, test_data) -> DataFrame: + pred_df = test_data.create_frame(with_label=True, with_weight=False) + if self.ovr: + pred_score = test_data.create_frame(with_label=False, with_weight=False) + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): + estimator = self.estimator[i] + pred = estimator.predict(class_ctx, test_data) + pred_score[str(self.labels[i])] = pred + pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)]) + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.MULTI, classes=self.labels + ) + 
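In the OVR branch above, each class model contributes one sigmoid score and `compute_predict_details` turns the score vector into a label; the exact detail format lives in fate.ml.utils.predict_tools, so this is only an illustrative reduction:

    import torch

    labels = [0, 1, 2]
    class_scores = torch.tensor([0.20, 0.70, 0.40])  # one binary model per class
    predicted_label = labels[int(torch.argmax(class_scores))]  # -> 1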
else: + predict_score = self.estimator.predict(ctx, test_data) + pred_df[predict_tools.PREDICT_SCORE] = predict_score + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.BINARY, classes=self.labels, threshold=self.threshold + ) + + return predict_result class SSHELREstimator(HeteroModule): def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate=None, init_param=None, - encrypted_reveal=True): + reveal_every_epoch=True, reveal_loss_freq=3, early_stop=None, tol=None): self.epochs = epochs self.batch_size = batch_size self.optimizer = optimizer self.lr = learning_rate self.init_param = init_param - self.encrypted_reveal = encrypted_reveal + self.reveal_every_epoch = reveal_every_epoch + self.reveal_loss_freq = reveal_loss_freq + self.early_stop = early_stop + self.tol = tol self.w = None self.start_epoch = 0 self.end_epoch = -1 self.is_converged = False self.header = None + self.converge_func = None + if early_stop is not None: + self.converge_func = converge_func_factory(self.early_stop, self.tol) - def fit_binary_model(self, ctx: Context, train_data: DataFrame) -> None: - + def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: DataFrame) -> None: rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank - y = ctx.mpc.cond_call(lambda: train_data.label.as_tensor(), lambda: None, dst=rank_b) - h = train_data.as_tensor() - # generator = torch.Generator().manual_seed(0) + initialize_func = get_initialize_func(**self.init_param) + if self.init_param.get("fit_intercept"): + train_data["intercept"] = 1.0 layer = SSHELogisticRegressionLayer( ctx, - in_features_a=ctx.mpc.option_call(lambda: h.shape[1], dst=rank_a), - in_features_b=ctx.mpc.option_call(lambda: h.shape[1], dst=rank_b), + in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a), + in_features_b=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_b), out_features=1, rank_a=rank_a, rank_b=rank_b, - # generator=generator, + wa_init_fn=initialize_func, + wb_init_fn=initialize_func, ) loss_fn = SSHELogisticRegressionLossLayer(ctx, rank_a=rank_a, rank_b=rank_b) optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr) + wa = layer.wa + wb = layer.wb + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=False) + for i, epoch_ctx in ctx.on_iterations.ctxs_range(self.epochs): + epoch_loss = 0 + logger.info(f"self.optimizer set epoch {i}") + for batch_ctx, batch_data in epoch_ctx.on_batches.ctxs_zip(batch_loader): + h = batch_data.x + y = ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b) + z = layer(h) + loss = loss_fn(z, y) + if i % self.reveal_loss_freq == 0: + epoch_loss += loss.get() + loss.backward() + optimizer.step() + if self.reveal_every_epoch: + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + if ctx.is_on_guest: + if self.early_stop == "weight_diff": + if self.reveal_every_epoch: + wa_p_delta = self.converge_func.compute_weight_diff(wa_p) + w_diff = ctx.guest.put("wa_p_delta", wa_p_delta) + if w_diff < self.tol: + self.is_converged = True + else: + raise ValueError(f"early stop {self.early_stop} is not supported when " + f"reveal_every_epoch is False") + else: + if i % self.reveal_loss_freq == 0: + if epoch_loss is not None: + print(f"epoch {i} loss: {epoch_loss.tolist()}") + epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist()) + if self.early_stop != "weight_diff": + self.is_converged = 
self.converge_func.is_converge(epoch_loss) + epoch_ctx.hosts.put("converge_flag", self.is_converged) + else: + if self.early_stop == "weight_diff": + if self.reveal_every_epoch: + wb_p_delta = self.converge_func.compute_weight_diff(wb_p) + ctx.guest.put("wb_p_delta", wb_p_delta) + self.is_converged = epoch_ctx.guest.get("converge_flag") + if self.is_converged: + self.end_epoch = i + break + if not self.is_converged: + self.end_epoch = self.epochs + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + if ctx.is_on_host: + self.w = wa_p + else: + self.w = wb_p + + def predict(self, ctx, test_data): + if ctx.is_on_guest: + if self.init_param.get("fit_intercept"): + test_data["intercept"] = 1.0 + X = test_data.values.as_tensor() + # logger.info(f"in predict, w: {self.w}") + pred = torch.matmul(X, self.w.detach()) + h_pred = ctx.hosts.get("h_pred")[0] + pred += h_pred + pred = torch.sigmoid(pred) + + return pred + else: + X = test_data.values.as_tensor() + output = torch.matmul(X, self.w) + ctx.guest.put("h_pred", output) + + def get_model(self): + param = serialize_param(self.w, self.init_param.get("fit_intercept")) + return { + "param": param, + # "optimizer": self.optimizer.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + "fit_intercept": self.init_param.get("fit_intercept"), + "header": self.header + } - for i in range(20): - # mpc encrypted [wx] - # to get decrypted wx: z.get_plain_text() - z = layer(h) - loss = loss_fn(z, y) - if i % 3 == 0: - logger.info(f"loss: {loss.get()}") - loss.backward() - optimizer.step() + def restore(self, model): + self.w = deserialize_param(model["param"], model["fit_intercept"]) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] + self.header = model["header"] + self.init_param["fit_intercept"] = model["fit_intercept"] + # self.optimizer.load_state_dict(model["optimizer"]) From 61f5ff2b1665117439b76bd694445b7f011bb31a Mon Sep 17 00:00:00 2001 From: sagewe Date: Tue, 5 Dec 2023 23:30:13 +0800 Subject: [PATCH 08/42] fix stream federation Signed-off-by: sagewe --- .../org/fedai/osx/broker/eggroll/ErJob.java | 48 +++++----- .../org/fedai/osx/broker/eggroll/ErJobIO.java | 84 +++++++++++++++++ .../osx/broker/eggroll/ErPartitioner.java | 65 +++++++++++++ .../fedai/osx/broker/eggroll/ErSerdes.java | 67 +++++++++++++ .../org/fedai/osx/broker/eggroll/ErStore.java | 18 ---- .../osx/broker/eggroll/ErStoreLocator.java | 66 ++++++++----- .../fedai/osx/broker/eggroll/RollPair.java | 2 +- .../osx/broker/eggroll/RollPairContext.java | 3 +- .../grpc/QueuePushReqStreamObserver.java | 15 ++- java/osx/proto/eggroll/command.proto | 5 +- java/osx/proto/eggroll/meta.proto | 60 ++++++++++-- java/osx/proto/eggroll/transfer.proto | 40 ++++++++ .../fate/arch/computing/eggroll/_csession.py | 23 ++--- python/fate/arch/computing/eggroll/_table.py | 2 +- python/fate/arch/federation/_federation.py | 1 + .../arch/federation/eggroll/_federation.py | 22 +++-- python/fate/arch/federation/federation.py | 15 ++- python/fate/arch/launchers/context_helper.py | 18 +++- .../arch/launchers/multiprocess_launcher.py | 3 +- python/fate/arch/utils/trace.py | 94 ++++++++++++++++++- 20 files changed, 519 insertions(+), 132 deletions(-) create mode 100644 java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJobIO.java create mode 100644 java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartitioner.java create mode 100644 
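In `SSHELREstimator.predict` above, the host ships its partial linear score to the guest, which adds its own partial score and applies the sigmoid; a local sketch of that aggregation with random stand-in data:

    import torch

    x_guest, w_guest = torch.randn(8, 3), torch.randn(3)
    x_host, w_host = torch.randn(8, 2), torch.randn(2)

    host_partial = x_host @ w_host      # host side: ctx.guest.put("h_pred", ...)
    guest_partial = x_guest @ w_guest   # guest side adds the received h_pred
    pred = torch.sigmoid(guest_partial + host_partial)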
java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSerdes.java
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java
index 6058d27cd2..6493fe8432 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJob.java
@@ -27,13 +27,13 @@ public class ErJob extends BaseProto<Meta.Job> {
     String id;
     String name;
-    List<ErStore> inputs;
-    List<ErStore> outputs;
+    List<ErJobIO> inputs;
+    List<ErJobIO> outputs;
     List<ErFunctor> functors;
     Map<String, String> options;

-    public ErJob(String id, String name, List<ErStore> inputs,
-                 List<ErStore> outputs,
+    public ErJob(String id, String name, List<ErJobIO> inputs,
+                 List<ErJobIO> outputs,
                  List<ErFunctor> functors,
                  Map<String, String> options) {
         this.id = id;
@@ -51,15 +51,15 @@ public static ErJob parseFromPb(Meta.Job job) {
         String id = job.getId();
         String name = job.getName();
         Map<String, String> options = job.getOptionsMap();
-        List<Meta.Store> inputMeta = job.getInputsList();
-        List<ErStore> input = Lists.newArrayList();
+        List<Meta.JobIO> inputMeta = job.getInputsList();
+        List<ErJobIO> input = Lists.newArrayList();
         if (inputMeta != null) {
-            input = inputMeta.stream().map(ErStore::parseFromPb).collect(Collectors.toList());
+            input = inputMeta.stream().map(ErJobIO::parseFromPb).collect(Collectors.toList());
         }
-        List<Meta.Store> outputMeta = job.getOutputsList();
-        List<ErStore> output = Lists.newArrayList();
+        List<Meta.JobIO> outputMeta = job.getOutputsList();
+        List<ErJobIO> output = Lists.newArrayList();
         if (output != null) {
-            output = outputMeta.stream().map(ErStore::parseFromPb).collect(Collectors.toList());
+            output = outputMeta.stream().map(ErJobIO::parseFromPb).collect(Collectors.toList());
         }
         List<ErFunctor> functors = Lists.newArrayList();
         List<Meta.Functor> functorMeta = job.getFunctorsList();
@@ -90,19 +90,19 @@ public void setName(String name) {
         this.name = name;
     }

-    public List<ErStore> getInputs() {
+    public List<ErJobIO> getInputs() {
         return inputs;
     }

-    public void setInputs(List<ErStore> inputs) {
+    public void setInputs(List<ErJobIO> inputs) {
         this.inputs = inputs;
     }

-    public List<ErStore> getOutputs() {
+    public List<ErJobIO> getOutputs() {
         return outputs;
     }

-    public void setOutputs(List<ErStore> outputs) {
+    public void setOutputs(List<ErJobIO> outputs) {
         this.outputs = outputs;
     }

@@ -124,18 +124,12 @@ public void setOptions(Map<String, String> options) {

     @Override
     Meta.Job toProto() {
-
-        return Meta.Job.newBuilder().setId(id).setName(name)
+        return Meta.Job.newBuilder()
+                .setId(id)
+                .setName(name)
                 .addAllFunctors(this.functors.stream().map(ErFunctor::toProto).collect(Collectors.toList())).
-                addAllInputs(inputs.stream().map(ErStore::toProto).collect(Collectors.toList())).putAllOptions(options).build();
-
+                addAllInputs(inputs.stream().map(ErJobIO::toProto).collect(Collectors.toList()))
+                .addAllOutputs(outputs.stream().map(ErJobIO::toProto).collect(Collectors.toList()))
+                .putAllOptions(options).build();
     }
-}
-
-
-//case class ErJob(id: String,
-//                 name: String = StringConstants.EMPTY,
-//                 inputs: Array[ErStore],
-//                 outputs: Array[ErStore] = Array(),
-//                 functors: Array[ErFunctor],
-//                 options: Map[String, String] = Map[String, String]())
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJobIO.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJobIO.java
new file mode 100644
index 0000000000..872c9bd75a
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErJobIO.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.fedai.osx.broker.eggroll;
+
+import com.webank.eggroll.core.meta.Meta;
+
+public class ErJobIO extends BaseProto<Meta.JobIO> {
+
+    ErStore erStore;
+    ErSerdes keySerdes;
+    ErSerdes valueSerdes;
+    ErPartitioner partitioner;
+
+    public ErJobIO(ErStore erStore, ErSerdes keySerdes, ErSerdes valueSerdes, ErPartitioner partitioner) {
+        this.erStore = erStore;
+        this.keySerdes = keySerdes;
+        this.valueSerdes = valueSerdes;
+        this.partitioner = partitioner;
+    }
+
+    public static ErJobIO parseFromPb(Meta.JobIO jobIO) {
+        ErStore erStore = ErStore.parseFromPb(jobIO.getStore());
+        ErSerdes key_serdes = ErSerdes.parseFromPb(jobIO.getKeySerdes());
+        ErSerdes value_serdes = ErSerdes.parseFromPb(jobIO.getValueSerdes());
+        ErPartitioner partitioner = ErPartitioner.parseFromPb(jobIO.getPartitioner());
+        ErJobIO erJobIO = new ErJobIO(erStore, key_serdes, value_serdes, partitioner);
+        return erJobIO;
+    }
+
+    public ErSerdes getKeySerdes() {
+        return keySerdes;
+    }
+
+    public void setKeySerdes(ErSerdes keySerdes) {
+        this.keySerdes = keySerdes;
+    }
+
+    public ErSerdes getValueSerdes() {
+        return valueSerdes;
+    }
+
+    public void setValueSerdes(ErSerdes valueSerdes) {
+        this.valueSerdes = valueSerdes;
+    }
+
+    public ErPartitioner getPartitioner() {
+        return partitioner;
+    }
+
+    public void setPartitioner(ErPartitioner partitioner) {
+        this.partitioner = partitioner;
+    }
+
+    public ErStore getErStore() {
+        return erStore;
+    }
+
+    public void setErStore(ErStore erStore) {
+        this.erStore = erStore;
+    }
+
+    @Override
+    Meta.JobIO toProto() {
+        Meta.JobIO.Builder builder = Meta.JobIO.newBuilder();
+        builder.setStore(erStore.toProto())
+                .setKeySerdes(keySerdes.toProto())
+                .setValueSerdes(valueSerdes.toProto())
+                .setPartitioner(partitioner.toProto());
+        return builder.build();
+    }
+}
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartitioner.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartitioner.java
new file mode 100644
index 0000000000..dd1f0bab03
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErPartitioner.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.fedai.osx.broker.eggroll;
+
+import com.google.protobuf.ByteString;
+import com.webank.eggroll.core.meta.Meta;
+
+public class ErPartitioner extends BaseProto<Meta.Partitioner> {
+
+    Integer t;
+    ByteString body;
+
+    public ErPartitioner(int t, ByteString body) {
+        this.t = t;
+        this.body = body;
+    }
+
+    public ErPartitioner(int t) {
+        this(t, ByteString.EMPTY);
+    }
+
+    public static ErPartitioner parseFromPb(Meta.Partitioner partitioner) {
+        int t = partitioner.getType();
+        ByteString body = partitioner.getBody();
+        return new ErPartitioner(t, body);
+    }
+
+    public int getT() {
+        return t;
+    }
+
+    public void setT(Integer t) {
+        this.t = t;
+    }
+
+    public ByteString getBody() {
+        return body;
+    }
+
+    public void setBody(ByteString body) {
+        this.body = body;
+    }
+
+
+    @Override
+    Meta.Partitioner toProto() {
+        Meta.Partitioner.Builder builder = Meta.Partitioner.newBuilder()
+                .setType(this.t)
+                .setBody(this.body);
+        return builder.build();
+    }
+}
\ No newline at end of file
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSerdes.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSerdes.java
new file mode 100644
index 0000000000..5ba258fd21
--- /dev/null
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErSerdes.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2019 The FATE Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.fedai.osx.broker.eggroll;
+
+import com.webank.eggroll.core.meta.Meta;
+
+import com.google.protobuf.ByteString;
+
+public class ErSerdes extends BaseProto<Meta.Serdes> {
+
+    int t;
+    ByteString body;
+
+    public ErSerdes(int t, ByteString body) {
+        this.t = t;
+        this.body = body;
+    }
+
+    public ErSerdes(int t) {
+        this(t, ByteString.EMPTY);
+    }
+
+
+    public static ErSerdes parseFromPb(Meta.Serdes serdes) {
+        Integer t = serdes.getType();
+        ByteString body = serdes.getBody();
+        return new ErSerdes(t, body);
+    }
+
+    public int getT() {
+        return t;
+    }
+
+    public void setT(Integer t) {
+        this.t = t;
+    }
+
+    public ByteString getBody() {
+        return body;
+    }
+
+    public void setBody(ByteString body) {
+        this.body = body;
+    }
+
+
+    @Override
+    Meta.Serdes toProto() {
+        Meta.Serdes.Builder builder = Meta.Serdes.newBuilder()
+                .setType(this.t)
+                .setBody(this.body);
+        return builder.build();
+    }
+}
\ No newline at end of file
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java
index 105b99dd98..cc475708af 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStore.java
@@ -44,24 +44,6 @@ public static ErStore parseFromPb(Meta.Store store) {
         return erStore;
     }

-    public static void main(String[] args) {
-
-//        String namespace ,String name,String storeType,
-//        int totalPartitions ,String partitioner,
-//        String serdes
-
-        ErStoreLocator erStoreLocator = new ErStoreLocator("mynamespace",
-                "myname", "mypath", "mystoreType", 1, "xxxx", "myserdes");
-
-        List<ErPartition> partitions = Lists.newArrayList();
-        ErPartition erPartition = new ErPartition(11, null, null, 33);
-        partitions.add(erPartition);
-
-        ErStore erStore = new ErStore(erStoreLocator, partitions, Maps.newHashMap());
-
-
-    }
-
     public String toString() {
         return JsonUtil.object2Json(this);
     }
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java
index 4ee3975727..dec64ac841 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/ErStoreLocator.java
@@ -26,35 +26,41 @@ public class ErStoreLocator extends BaseProto<Meta.StoreLocator> {
     String name;
     String path;
     int totalPartitions;
-    String partitioner;
-    String serdes;
+
+    int KeySerdesType;
+    int ValueSerdesType;
+    int PartitionerType;

     public ErStoreLocator(String namespace,
                           String name,
                           String path,
                           String storeType,
-                          int totalPartitions, String partitioner,
-                          String serdes
+                          int totalPartitions,
+                          int KeySerdesType,
+                          int ValueSerdesType,
+                          int PartitionerType
     ) {
+        this.storeType = storeType;
         this.namespace = namespace;
         this.name = name;
         this.path = path;
-        this.storeType = storeType;
-        this.partitioner = partitioner;
         this.totalPartitions = totalPartitions;
-        this.serdes = serdes;
-
+        this.KeySerdesType = KeySerdesType;
+        this.ValueSerdesType = ValueSerdesType;
+        this.PartitionerType = PartitionerType;
     }

     public static ErStoreLocator parseFromPb(Meta.StoreLocator storeLocator) {
-        ErStoreLocator erStoreLocator = new ErStoreLocator(storeLocator.getNamespace(),
+        ErStoreLocator erStoreLocator = new ErStoreLocator(
+                storeLocator.getNamespace(),
                 storeLocator.getName(),
                 storeLocator.getPath(),
                 storeLocator.getStoreType(),
                 storeLocator.getTotalPartitions(),
-                storeLocator.getPartitioner(),
-                storeLocator.getSerdes()
+                storeLocator.getKeySerdesType(),
+                storeLocator.getValueSerdesType(),
+                storeLocator.getPartitionerType()
         );

         return erStoreLocator;
@@ -100,20 +106,28 @@ public void setTotalPartitions(int totalPartitions) {
         this.totalPartitions = totalPartitions;
     }

-    public String getPartitioner() {
-        return partitioner;
+    public int getKeySerdesType() {
+        return KeySerdesType;
+    }
+
+    public void setKeySerdesType(int keySerdesType) {
+        KeySerdesType = keySerdesType;
+    }
+
+    public int getValueSerdesType() {
+        return ValueSerdesType;
     }

-    public void setPartitioner(String partitioner) {
-        this.partitioner = partitioner;
+    public void setValueSerdesType(int valueSerdesType) {
+        ValueSerdesType = valueSerdesType;
     }

-    public String getSerdes() {
-        return serdes;
+    public int getPartitionerType() {
+        return PartitionerType;
     }

-    public void setSerdes(String serdes) {
-        this.serdes = serdes;
+    public void setPartitionerType(int partitionerType) {
+        PartitionerType = partitionerType;
     }

     String toPath(String delim) {
@@ -131,13 +145,13 @@ Meta.StoreLocator toProto() {

         return builder.setName(name)
                 .setNamespace(namespace)
-                .setPartitioner(partitioner)
-                .setStoreType(storeType).
-                setPath(path).
-                setTotalPartitions(totalPartitions).
-                setSerdes(serdes).build();
-
-
+                .setStoreType(storeType)
+                .setPath(path)
+                .setTotalPartitions(totalPartitions)
+                .setKeySerdesType(KeySerdesType)
+                .setValueSerdesType(ValueSerdesType)
+                .setPartitionerType(PartitionerType).
+                build();
     }

     public String toString() {
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java
index bb2f25a37d..47dec9bf25 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPair.java
@@ -20,7 +20,7 @@ public class RollPair {
     public static final String PUT_BATCH = "putBatch";
     public static final String ROLL_PAIR_URI_PREFIX = "v1/roll-pair";
-    public static final String EGG_PAIR_URI_PREFIX = "v1/egg-pair";
+    public static final String EGG_PAIR_URI_PREFIX = "v1/eggs-pair";
     public static final String RUN_JOB = "runJob";
     public static final String RUN_TASK = "runTask";
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java
index 4dac155b60..9bdffd38d8 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/RollPairContext.java
@@ -43,8 +43,7 @@ public RollPair load(String namespace, String name, Map<String, String> options) {
         String storeType = options.getOrDefault(Dict.STORE_TYPE, options.getOrDefault(Dict.STORE_TYPE_SNAKECASE, defaultStoreTypeValue));
         int totalPartitions = Integer.parseInt(options.getOrDefault(Dict.TOTAL_PARTITIONS, options.getOrDefault(Dict.TOTAL_PARTITIONS_SNAKECASE, "1")));
         ErStoreLocator erStoreLocator = new ErStoreLocator(namespace, name, Dict.EMPTY, storeType, totalPartitions,
-                options.getOrDefault(Dict.PARTITIONER, PartitionerTypes.BYTESTRING_HASH.name()),
-                options.getOrDefault(Dict.SERDES, defaultSerdesType));
+                0, 0, 0);
         ErStore store = new ErStore(erStoreLocator, Lists.newArrayList(), options);
         ErStore loaded = erSession.clusterManagerClient.getOrCreateStore(store);
         return new RollPair(loaded, this, Maps.newHashMap());
diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java
index 591232a662..b468c01f0e 100644
--- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java
+++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java
@@ -208,8 +208,8 @@ private void initEggroll(OsxContext context, Proxy.Packet firstRequest) {
         ErJob job = new ErJob(
                 jobId,
                 RollPair.PUT_BATCH,
-                Lists.newArrayList(rp.getStore()),
-                Lists.newArrayList(rp.getStore()),
+                Lists.newArrayList(new ErJobIO(rp.getStore(), new ErSerdes(0), new ErSerdes(0), new ErPartitioner(0))),
+                Lists.newArrayList(new ErJobIO(rp.getStore(), new ErSerdes(0), new ErSerdes(0), new ErPartitioner(0))),
                 Lists.newArrayList(),
                 jobOptions);

@@ -220,18 +220,17 @@ private void initEggroll(OsxContext context, Proxy.Packet firstRequest) {
                 job);

         Future<ErTask> commandFuture = RollPairContext.executor.submit(() -> {
-            CommandClient commandClient = new CommandClient(egg.getCommandEndpoint());
-            Command.CommandResponse commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task);
             long begin = System.currentTimeMillis();
             try {
+                CommandClient commandClient = new CommandClient(egg.getCommandEndpoint());
+                Command.CommandResponse commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task);
                 Meta.Task taskMeta = Meta.Task.parseFrom(commandResponse.getResultsList().get(0));
                 ErTask erTask = ErTask.parseFromPb(taskMeta);
                 long now = System.currentTimeMillis();
                 return erTask;
-            } catch (InvalidProtocolBufferException igore) {
-
+            } catch (Exception e) {
+                logger.error("submit putBatch task error", e);
+                throw e;
             }
-            return null;
         });
         RouterInfo routerInfo = new RouterInfo();
         routerInfo.setProtocol(Protocol.grpc);
diff --git a/java/osx/proto/eggroll/command.proto b/java/osx/proto/eggroll/command.proto
index 93ee04a5a2..9f19c7e86f 100644
--- a/java/osx/proto/eggroll/command.proto
+++ b/java/osx/proto/eggroll/command.proto
@@ -29,8 +29,11 @@ message CommandResponse {
     string id = 1;
     CommandRequest request = 2;
     repeated bytes results = 3;
+    string code = 4;
+    string msg = 5;
 }

 service CommandService {
     rpc call (CommandRequest) returns (CommandResponse);
-}
\ No newline at end of file
+    rpc callStream (stream CommandRequest) returns (stream CommandResponse);
+}
diff --git a/java/osx/proto/eggroll/meta.proto b/java/osx/proto/eggroll/meta.proto
index c97667ead7..b8ff679bf2 100644
--- a/java/osx/proto/eggroll/meta.proto
+++ b/java/osx/proto/eggroll/meta.proto
@@ -32,6 +32,7 @@ message ServerNode {
     Endpoint endpoint = 4;
     string nodeType = 5;
     string status = 6;
+    repeated Resource resources = 7;
 }

 message ServerCluster {
@@ -67,6 +68,15 @@ message Functor {
     bytes body = 3;
     map<string, string> options = 4;
 }
+message Partitioner {
+    int32 type = 1;
+    bytes body = 2;
+}
+
+message Serdes {
+    int32 type = 1;
+    bytes body = 2;
+}

 message Pair {
     bytes key = 1;
@@ -79,13 +89,14 @@ message PairBatch {

 message StoreLocator {
     int64 id = 1;
-    string storeType = 2;
+    string store_type = 2;
     string namespace = 3;
     string name = 4;
     string path = 5;
-    int32 totalPartitions = 6;
-    string partitioner = 7;
-    string serdes = 8;
+    int32 total_partitions = 6;
+    int32 key_serdes_type = 7;
+    int32 value_serdes_type = 8;
+    int32 partitioner_type = 9;
 }

 message Store {
@@ -95,7 +106,7 @@
 }

 message StoreList {
-    repeated Store stores = 1;
+    repeated Store stores = 1;
 }

 message Partition {
@@
-110,12 +121,18 @@ message CallInfo { string callSeq = 1; } -// todo: add job / task status +message JobIO { + Store store = 1; + Serdes key_serdes = 2; + Serdes value_serdes = 3; + Partitioner partitioner = 4; +} +// todo: add / task status message Job { string id = 1; string name = 2; - repeated Store inputs = 3; - repeated Store outputs = 4; + repeated JobIO inputs = 3; + repeated JobIO outputs = 4; repeated Functor functors = 5; map options = 6; } @@ -135,4 +152,29 @@ message SessionMeta { string tag = 4; repeated Processor processors = 5; map options = 6; -} \ No newline at end of file +} + +message ResourceAllocation { + int64 serverNodeId = 1; + string status = 2; + string sessionId = 3; + string operateType = 4; + repeated Resource resources = 5; + +} + +message Resource { + string type = 1; + int64 total = 2; + int64 used = 3; + int64 allocated = 4; +} + +message NodeHeartbeat { + uint64 id = 1; + ServerNode node = 2; + string code = 3; + string msg = 4; + repeated int32 gpuProcessors=5; + repeated int32 cpuProcessors=6; +} diff --git a/java/osx/proto/eggroll/transfer.proto b/java/osx/proto/eggroll/transfer.proto index 60d3c47776..408b6983a3 100644 --- a/java/osx/proto/eggroll/transfer.proto +++ b/java/osx/proto/eggroll/transfer.proto @@ -44,3 +44,43 @@ service TransferService { rpc sendRecv (stream TransferBatch) returns (stream TransferBatch); } +message RollSitePullGetHeaderRequest { + string tag = 1; + double timeout = 2; +} + +message RollSitePullGetHeaderResponse { + RollSiteHeader header = 1; +} + +message RollSitePullGetPartitionStatusRequest { + string tag = 1; + double timeout = 2; +} + +message RollSitePullGetPartitionStatusResponse { + message IntKeyIntValuePair { + int64 key = 1; + int64 value = 2; + } + message RollSitePullGetPartitionStatusResponseStatus { + string tag = 1; + bool is_finished = 2; + int64 total_batches = 3; + repeated IntKeyIntValuePair batch_seq_to_pair_counter = 4; + int64 total_streams = 5; + repeated IntKeyIntValuePair stream_seq_to_pair_counter = 6; + repeated IntKeyIntValuePair stream_seq_to_batch_seq = 7; + int64 total_pairs = 8; + string data_type = 9; + } + int64 partition_id = 1; + RollSitePullGetPartitionStatusResponseStatus status = 2; +} + +message RollSitePullClearStatusRequest { + string tag = 1; +} + +message RollSitePullClearStatusResponse { +} diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/eggroll/_csession.py index 65c5313424..854d1be95a 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/eggroll/_csession.py @@ -17,14 +17,13 @@ import logging from fate.arch.computing.table import KVTableContext - -from ...unify import URI, uuid -from .._profile import computing_profile from ._table import Table +from .._profile import computing_profile +from ...unify import URI, uuid try: - from eggroll.core.session import session_init - from eggroll.roll_pair.roll_pair import runtime_init + from eggroll.session import session_init + from eggroll.computing import runtime_init except ImportError: raise EnvironmentError("eggroll not found in pythonpath") @@ -35,13 +34,9 @@ class CSession(KVTableContext): def __init__(self, session_id, options: dict = None): if options is None: options = {} - if "eggroll.session.deploy.mode" not in options: - options["eggroll.session.deploy.mode"] = "cluster" - if "eggroll.rollpair.inmemory_output" not in options: - options["eggroll.rollpair.inmemory_output"] = True - self._rp_session = session_init(session_id=session_id, 
options=options) - self._rpc = runtime_init(session=self._rp_session) - self._session_id = self._rp_session.get_session_id() + self._eggroll_session = session_init(session_id=session_id, options=options) + self._rpc = runtime_init(session=self._eggroll_session) + self._session_id = self._eggroll_session.get_session_id() def get_rpc(self): return self._rpc @@ -104,10 +99,10 @@ def cleanup(self, name, namespace): self._rpc.cleanup(name=name, namespace=namespace) def stop(self): - return self._rp_session.stop() + return self._eggroll_session.stop() def kill(self): - return self._rp_session.kill() + return self._eggroll_session.kill() def destroy(self): try: diff --git a/python/fate/arch/computing/eggroll/_table.py b/python/fate/arch/computing/eggroll/_table.py index 690bdb7521..25ef8fd1e7 100644 --- a/python/fate/arch/computing/eggroll/_table.py +++ b/python/fate/arch/computing/eggroll/_table.py @@ -21,7 +21,7 @@ from ...unify import URI from .._type import ComputingEngine from ..table import KVTable -from eggroll.roll_pair.roll_pair import RollPair +from eggroll.computing import RollPair LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/_federation.py b/python/fate/arch/federation/_federation.py index ee354c3aab..024d8f9e12 100644 --- a/python/fate/arch/federation/_federation.py +++ b/python/fate/arch/federation/_federation.py @@ -685,6 +685,7 @@ def _partition_receive( mq, conf: dict, ): + topic_pair = topic_infos[index][1] channel_info = self._get_channel( topic_pair=topic_pair, diff --git a/python/fate/arch/federation/eggroll/_federation.py b/python/fate/arch/federation/eggroll/_federation.py index 8f4cbef5bd..71378c872c 100644 --- a/python/fate/arch/federation/eggroll/_federation.py +++ b/python/fate/arch/federation/eggroll/_federation.py @@ -21,10 +21,9 @@ import typing from typing import List -from eggroll.roll_pair.roll_pair import RollPair -from eggroll.roll_site.roll_site import RollSiteContext +from eggroll.computing import RollPair +from eggroll.federation import RollSiteContext from fate.arch.federation.federation import Federation, PartyMeta - from ...computing.eggroll import Table logger = logging.getLogger(__name__) @@ -40,15 +39,14 @@ def __init__( proxy_endpoint, ): super().__init__(rs_session_id, party, parties) + proxy_endpoint_host, proxy_endpoint_port = proxy_endpoint.split(":") self._rp_ctx = rp_ctx self._rsc = RollSiteContext( rs_session_id, rp_ctx=rp_ctx, - options={ - "self_role": party[0], - "self_party_id": party[1], - "proxy_endpoint": proxy_endpoint, - }, + party=party, + proxy_endpoint_host=proxy_endpoint_host.strip(), + proxy_endpoint_port=int(proxy_endpoint_port.strip()), ) def _pull_table( @@ -91,12 +89,16 @@ def _pull_bytes(self, name: str, tag: str, parties: List[PartyMeta]): def _push_table(self, table: Table, name: str, tag: str, parties: List[PartyMeta]): rs = self._rsc.load(name=name, tag=tag) futures = rs.push_rp(table._rp, parties=parties) - # concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + for future in done: + future.result() def _push_bytes(self, v: bytes, name: str, tag: str, parties: List[PartyMeta]): rs = self._rsc.load(name=name, tag=tag) futures = rs.push_bytes(v, parties=parties) - # concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + for future in done: 
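+                # result() re-raises the first exception captured by a push future, so a failed send surfaces here instead of being silently dropped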
+ future.result() def destroy(self): self._rp_ctx.cleanup(name="*", namespace=self._session_id) diff --git a/python/fate/arch/federation/federation.py b/python/fate/arch/federation/federation.py index 35420e0de6..b4f4b088bb 100644 --- a/python/fate/arch/federation/federation.py +++ b/python/fate/arch/federation/federation.py @@ -18,7 +18,12 @@ from typing import List from fate.arch.abc import PartyMeta -from fate.arch.utils.trace import federation_auto_trace +from fate.arch.utils.trace import ( + federation_push_table_trace, + federation_pull_table_trace, + federation_push_bytes_trace, + federation_pull_bytes_trace, +) from ._gc import GarbageCollector if typing.TYPE_CHECKING: @@ -90,7 +95,7 @@ def _push_bytes( ): raise NotImplementedError(f"push bytes is not supported in {self.__class__.__name__}") - @federation_auto_trace + @federation_push_table_trace def push_table( self, table: "KVTable", @@ -111,7 +116,7 @@ def push_table( parties=parties, ) - @federation_auto_trace + @federation_push_bytes_trace def push_bytes( self, v: bytes, @@ -131,7 +136,7 @@ def push_bytes( parties=parties, ) - @federation_auto_trace + @federation_pull_table_trace def pull_table( self, name: str, @@ -152,7 +157,7 @@ def pull_table( self._get_gc.register_clean_action(name, tag, table, "destroy", {}) return tables - @federation_auto_trace + @federation_pull_bytes_trace def pull_bytes( self, name: str, diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index bb38f6f614..0c63ac11e3 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -63,19 +63,27 @@ def init_eggroll_context(): from fate.arch.computing.eggroll import CSession from fate.arch.federation.osx import OSXFederation + from fate.arch.federation.eggroll import EggrollFederation from fate.arch.context import Context args = HfArgumentParser(LauncherEggrollContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[0] parties = get_parties(args.parties) party = parties[args.rank] computing_session = CSession(session_id=args.csession_id) - federation_session = OSXFederation.from_conf( - federation_session_id=args.federation_session_id, - computing_session=computing_session, + federation_session = EggrollFederation( + rp_ctx=computing_session.get_rpc(), + rs_session_id=args.federation_session_id, party=party, parties=parties, - host=args.host, - port=args.port, + proxy_endpoint=f"{args.host}:{args.port}", ) + # federation_session = OSXFederation.from_conf( + # federation_session_id=args.federation_session_id, + # computing_session=computing_session, + # party=party, + # parties=parties, + # host=args.host, + # port=args.port, + # ) context = Context(computing=computing_session, federation=federation_session) return context diff --git a/python/fate/arch/launchers/multiprocess_launcher.py b/python/fate/arch/launchers/multiprocess_launcher.py index abfaf8db18..7d9d0e9eec 100644 --- a/python/fate/arch/launchers/multiprocess_launcher.py +++ b/python/fate/arch/launchers/multiprocess_launcher.py @@ -138,6 +138,7 @@ def _run_process( parties = args.get_parties() party = parties[args.rank] csession_id = f"{args.federation_session_id}_{party[0]}_{party[1]}" + argv.extend(["--csession_id", csession_id]) # set up logging set_up_logging(args.rank, args.log_level) @@ -188,7 +189,7 @@ def wait(self) -> int: def terminate(self): self.safe_to_exit.set() - time.sleep(5) # wait for 1 second to let all processes has a chance to exit + time.sleep(10) # wait 
for all processes to get a chance to exit
         for process in self.processes:
             if process.is_alive():
                 process.terminate()
diff --git a/python/fate/arch/utils/trace.py b/python/fate/arch/utils/trace.py
index 6202438417..149483ce11 100644
--- a/python/fate/arch/utils/trace.py
+++ b/python/fate/arch/utils/trace.py
@@ -1,8 +1,17 @@
 import functools
+import logging
 import os
+import time
+import typing
+from typing import List

 from opentelemetry import trace, context

+if typing.TYPE_CHECKING:
+    from fate.arch.federation.federation import PartyMeta
+    from fate.arch.computing.table import KVTable
+
+logger = logging.getLogger(__name__)
+
 _ENABLE_TRACING = None
 _ENABLE_TRACING_DEFAULT = True
@@ -51,11 +60,16 @@ def wrapper(*args, **kwargs):


 def _trace_func(func, args, kwargs, span_name=None):
-    if not _is_tracing_enabled():
-        return func(*args, **kwargs)
-
     module_name = func.__module__
     qualname = func.__qualname__
+
+    if not _is_tracing_enabled():
+        start = time.time()
+        out = func(*args, **kwargs)
+        elapsed = time.time() - start
+        logger.debug(f"{module_name}:{qualname} takes: {elapsed}")
+        return out
+
     if span_name is None:
         span_name = qualname
     tracer = get_tracer(module_name)
@@ -83,9 +97,81 @@ def extract_carrier(carrier):
     return TraceContextTextMapPropagator().extract(carrier)


+def federation_push_table_trace(func):
+    @functools.wraps(func)
+    def wrapper(
+        self,
+        table: "KVTable",
+        name: str,
+        tag: str,
+        parties: List["PartyMeta"],
+    ):
+        logger.debug(f"function {func.__qualname__} is calling on name={name}, tag={tag}, parties={parties}")
+        out = func(self, table, name, tag, parties)
+        logger.debug(f"function {func.__qualname__} is called on name={name}, tag={tag}, parties={parties}")
+        return out
+
+    return wrapper
+
+
+def federation_push_bytes_trace(func):
+    @functools.wraps(func)
+    def wrapper(
+        self,
+        v: bytes,
+        name: str,
+        tag: str,
+        parties: List["PartyMeta"],
+    ):
+        logger.debug(f"function {func.__qualname__} is calling on name={name}, tag={tag}, parties={parties}")
+        out = func(self, v, name, tag, parties)
+        logger.debug(f"function {func.__qualname__} is called on name={name}, tag={tag}, parties={parties}")
+        return out
+
+    return wrapper
+
+
+def federation_pull_table_trace(func):
+    @functools.wraps(func)
+    def wrapper(
+        self,
+        name: str,
+        tag: str,
+        parties: List["PartyMeta"],
+    ):
+        logger.debug(f"function {func.__qualname__} is calling on name={name}, tag={tag}, parties={parties}")
+        out = func(self, name, tag, parties)
+        logger.debug(f"function {func.__qualname__} is called on name={name}, tag={tag}, parties={parties}")
+        return out
+
+    return wrapper
+
+
+def federation_pull_bytes_trace(func):
+    @functools.wraps(func)
+    def wrapper(
+        self,
+        name: str,
+        tag: str,
+        parties: List["PartyMeta"],
+    ):
+        logger.debug(f"function {func.__qualname__} is calling on name={name}, tag={tag}, parties={parties}")
+        out = func(self, name, tag, parties)
+        logger.debug(f"function {func.__qualname__} is called on name={name}, tag={tag}, parties={parties}")
+        return out
+
+    return wrapper
+
+
 def federation_auto_trace(func):
     if not _is_tracing_enabled():
-        return func
+
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            logger.debug(f"function {func.__qualname__} is called on {args} {kwargs}")
+            return out
+
+        return wrapper

     @functools.wraps(func)
     def wrapper(*args, **kwargs):

From aad302f3de05f9826f32ac7d6d7e9c10b697f148 Mon Sep 17 00:00:00 2001
From: sagewe
Date: Wed, 6 Dec 2023 16:05:55 +0800
Subject: [PATCH 09/42] fix stream federation and add global
federation init Signed-off-by: sagewe --- .../fate/arch/computing/eggroll/_csession.py | 12 +- python/fate/arch/config/_config.py | 15 +- python/fate/arch/federation/__init__.py | 3 +- python/fate/arch/federation/_builder.py | 171 ++++++++++++++++++ .../arch/federation/eggroll/_federation.py | 27 +-- .../fate/arch/federation/osx/_federation.py | 18 +- python/fate/arch/launchers/context_helper.py | 55 +++--- .../fate/components/core/_load_federation.py | 115 ++++++------ .../fate/components/core/spec/federation.py | 1 + 9 files changed, 300 insertions(+), 117 deletions(-) create mode 100644 python/fate/arch/federation/_builder.py diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/eggroll/_csession.py index 854d1be95a..ab64350d50 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/eggroll/_csession.py @@ -31,10 +31,18 @@ class CSession(KVTableContext): - def __init__(self, session_id, options: dict = None): + def __init__( + self, session_id, options: dict = None, config=None, config_options=None, config_properties_file=None + ): if options is None: options = {} - self._eggroll_session = session_init(session_id=session_id, options=options) + self._eggroll_session = session_init( + session_id=session_id, + options=options, + config=config, + config_options=config_options, + config_properties_file=config_properties_file, + ) self._rpc = runtime_init(session=self._eggroll_session) self._session_id = self._eggroll_session.get_session_id() diff --git a/python/fate/arch/config/_config.py b/python/fate/arch/config/_config.py index a3d325a4ae..4381ba8fa2 100644 --- a/python/fate/arch/config/_config.py +++ b/python/fate/arch/config/_config.py @@ -7,8 +7,8 @@ import os from contextlib import contextmanager -from ruamel import yaml from omegaconf import OmegaConf +from ruamel import yaml class Config(object): @@ -57,7 +57,12 @@ def temp_override(self, override_dict): finally: self.config = old_config - -if __name__ == "__main__": - config = Config() - print(config.mpc.provider) + def get_option(self, options, key, default=...): + if key in options: + return options[key] + elif self.config.get(key, None) is not None: + return self.config[key] + elif default is ...: + raise ValueError(f"{key} not in {options} or {self.config}") + else: + return default diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index 8e924d12a0..4aa278575d 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ._builder import FederationBuilder, FederationMode from ._type import FederationDataType -__all__ = ["FederationDataType"] +__all__ = ["FederationDataType", "FederationBuilder", "FederationMode"] diff --git a/python/fate/arch/federation/_builder.py b/python/fate/arch/federation/_builder.py new file mode 100644 index 0000000000..9afdce68d3 --- /dev/null +++ b/python/fate/arch/federation/_builder.py @@ -0,0 +1,171 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from enum import Enum + +from fate.arch.abc import PartyMeta +from fate.arch.config import cfg + + +class FederationType(Enum): + STANDALONE = "standalone" + OSX = "osx" + RABBITMQ = "rabbitmq" + PULSAR = "pulsar" + + @classmethod + def from_str(cls, s: str): + for t in cls: + if t.value == s: + return t + raise ValueError(f"{s} not in {cls}") + + def __str__(self): + return self.value + + +class FederationMode(Enum): + STREAM = "stream" + MESSAGE_QUEUE = "message_queue" + + @classmethod + def from_str(cls, s: str): + if isinstance(s, cls): + return s + for t in cls: + if t.value == s: + return t + raise ValueError(f"{s} not in {cls}") + + +class FederationBuilder: + def __init__( + self, + federation_id: str, + party: PartyMeta, + parties: typing.List[PartyMeta], + ): + self._federation_id = federation_id + self._party = party + self._parties = parties + + def build(self, computing_session, t: FederationType, conf: dict): + if t == FederationType.STANDALONE: + return self.build_standalone(computing_session) + elif t == FederationType.OSX: + host = cfg.get_option(conf, "federation.osx.host") + port = cfg.get_option(conf, "federation.osx.port") + mode = FederationMode.from_str(cfg.get_option(conf, "federation.osx.mode", FederationMode.MESSAGE_QUEUE)) + return self.build_osx(computing_session, host=host, port=port, mode=mode) + elif t == FederationType.RABBITMQ: + host = cfg.get_option(conf, "federation.rabbitmq.host") + port = cfg.get_option(conf, "federation.rabbitmq.port") + options = cfg.get_option(conf, "federation.rabbitmq") + return self.build_rabbitmq(computing_session, host=host, port=port, options=options) + elif t == FederationType.PULSAR: + host = cfg.get_option(conf, "federation.pulsar.host") + port = cfg.get_option(conf, "federation.pulsar.port") + options = cfg.get_option(conf, "federation.pulsar") + return self.build_pulsar(computing_session, host=host, port=port, options=options) + else: + raise ValueError(f"{t} not in {FederationType}") + + def build_standalone(self, computing_session): + from fate.arch.federation.standalone import StandaloneFederation + + return StandaloneFederation( + standalone_session=computing_session, + federation_session_id=self._federation_id, + party=self._party, + parties=self._parties, + ) + + def build_osx( + self, computing_session, host: str, port: int, mode=FederationMode.MESSAGE_QUEUE, options: dict = None + ): + if options is None: + options = {} + if mode == FederationMode.MESSAGE_QUEUE: + from fate.arch.federation.osx import OSXFederation + + return OSXFederation.from_conf( + federation_session_id=self._federation_id, + computing_session=computing_session, + party=self._party, + parties=self._parties, + host=host, + port=port, + max_message_size=options.get("max_message_size"), + ) + else: + from fate.arch.computing.eggroll import CSession + from fate.arch.federation.eggroll import EggrollFederation + + if not isinstance(computing_session, CSession): + raise RuntimeError( + f"Eggroll federation type requires Eggroll computing type, `{type(computing_session)}` found" + ) + + return EggrollFederation( + 
computing_session=computing_session, + federation_session_id=self._federation_id, + party=self._party, + parties=self._parties, + proxy_endpoint=f"{host}:{port}", + ) + + def build_rabbitmq(self, computing_session, host: str, port: int, options: dict): + from fate.arch.federation.rabbitmq import RabbitmqFederation + + return RabbitmqFederation.from_conf( + federation_session_id=self._federation_id, + computing_session=computing_session, + party=self._party, + parties=self._parties, + host=host, + port=port, + route_table=options["route_table"], + mng_port=options["mng_port"], + base_user=options["base_user"], + base_password=options["base_password"], + mode=options["mode"], + max_message_size=options["max_message_size"], + rabbitmq_run=options["rabbitmq_run"], + connection=options["connection"], + ) + + def build_pulsar(self, computing_session, host: str, port: int, options: dict): + from fate.arch.federation.pulsar import PulsarFederation + + return PulsarFederation.from_conf( + federation_session_id=self._federation_id, + computing_session=computing_session, + party=self._party, + parties=self._parties, + host=host, + port=port, + route_table=options["route_table"], + mode=options["mode"], + mng_port=options["mng_port"], + base_user=options["base_user"], + base_password=options["base_password"], + max_message_size=options["max_message_size"], + topic_ttl=options["topic_ttl"], + cluster=options["cluster"], + tenant=options["tenant"], + pulsar_run=options["pulsar_run"], + connection=options["connection"], + ) diff --git a/python/fate/arch/federation/eggroll/_federation.py b/python/fate/arch/federation/eggroll/_federation.py index 71378c872c..f7eabe0098 100644 --- a/python/fate/arch/federation/eggroll/_federation.py +++ b/python/fate/arch/federation/eggroll/_federation.py @@ -28,22 +28,25 @@ logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + from fate.arch.computing.eggroll import CSession + class EggrollFederation(Federation): def __init__( self, - rp_ctx, - rs_session_id, + computing_session: "CSession", + federation_session_id, party: PartyMeta, parties: List[PartyMeta], proxy_endpoint, ): - super().__init__(rs_session_id, party, parties) + super().__init__(federation_session_id, party, parties) proxy_endpoint_host, proxy_endpoint_port = proxy_endpoint.split(":") - self._rp_ctx = rp_ctx + self._rp_ctx = computing_session.get_rpc() self._rsc = RollSiteContext( - rs_session_id, - rp_ctx=rp_ctx, + federation_session_id, + rp_ctx=self._rp_ctx, party=party, proxy_endpoint_host=proxy_endpoint_host.strip(), proxy_endpoint_port=int(proxy_endpoint_port.strip()), @@ -89,16 +92,16 @@ def _pull_bytes(self, name: str, tag: str, parties: List[PartyMeta]): def _push_table(self, table: Table, name: str, tag: str, parties: List[PartyMeta]): rs = self._rsc.load(name=name, tag=tag) futures = rs.push_rp(table._rp, parties=parties) - done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) - for future in done: - future.result() + # done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) + # for future in done: + # future.result() def _push_bytes(self, v: bytes, name: str, tag: str, parties: List[PartyMeta]): rs = self._rsc.load(name=name, tag=tag) futures = rs.push_bytes(v, parties=parties) - done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION) - for future in done: - future.result() + # done, not_done = concurrent.futures.wait(futures, 
return_when=concurrent.futures.FIRST_EXCEPTION) + # for future in done: + # future.result() def destroy(self): self._rp_ctx.cleanup(name="*", namespace=self._session_id) diff --git a/python/fate/arch/federation/osx/_federation.py b/python/fate/arch/federation/osx/_federation.py index 0c39f67b68..5ba0bbf6c2 100644 --- a/python/fate/arch/federation/osx/_federation.py +++ b/python/fate/arch/federation/osx/_federation.py @@ -19,10 +19,8 @@ from fate.arch.abc import PartyMeta from fate.arch.federation.osx import osx_pb2 - -from .._federation import FederationBase -from .._nretry import nretry from ._mq_channel import MQChannel +from .._federation import FederationBase LOGGER = getLogger(__name__) # default message max size in bytes = 1MB @@ -67,7 +65,7 @@ def from_conf( mq = MQ(host, port) return OSXFederation( - session_id=federation_session_id, + federation_session_id=federation_session_id, computing_session=computing_session, party=party, parties=parties, @@ -76,10 +74,16 @@ def from_conf( ) def __init__( - self, session_id, computing_session, party: PartyMeta, parties: typing.List[PartyMeta], max_message_size, mq + self, + federation_session_id, + computing_session, + party: PartyMeta, + parties: typing.List[PartyMeta], + max_message_size, + mq, ): super().__init__( - session_id=session_id, + session_id=federation_session_id, computing_session=computing_session, party=party, parties=parties, @@ -182,7 +186,7 @@ def _get_consume_message(self, channel_info): response = channel_info.consume() # LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}, response={response}") if response.code == "E0000000601": - raise LookupError(f"{response}") + raise LookupError(f"{response}") message = osx_pb2.Message() message.ParseFromString(response.payload) # offset = response.metadata["MessageOffSet"] diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index 0c63ac11e3..9be2dab636 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -1,6 +1,7 @@ import os -from typing import List from dataclasses import dataclass, field +from typing import List + from .argparser import HfArgumentParser, get_parties @@ -14,13 +15,15 @@ class LauncherStandaloneContextArgs: @dataclass -class LauncherEggrollContextArgs: +class LauncherDistributedContextArgs: federation_session_id: str = field() parties: List[str] = field() rank: int = field() + config_properties: str = field() csession_id: str = field(default=None) host: str = field(default="127.0.0.1") port: int = field(default=9377) + federation_mode: str = field(default="message_queue") @dataclass @@ -30,10 +33,10 @@ class LauncherContextArguments: def init_context(): args = HfArgumentParser(LauncherContextArguments).parse_known_args()[0] - if args.context_type == "standalone": + if args.context_type == "local": return init_standalone_context() - elif args.context_type == "eggroll": - return init_eggroll_context() + elif args.context_type == "cluster": + return init_distributed_context() else: raise ValueError(f"unknown context type: {args.context_type}") @@ -41,7 +44,7 @@ def init_context(): def init_standalone_context(): from fate.arch.utils.paths import get_base_dir from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.federation import FederationBuilder from fate.arch.context import Context args = 
HfArgumentParser(LauncherStandaloneContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ @@ -54,36 +57,36 @@ def init_standalone_context(): computing_session = CSession(session_id=args.csession_id, data_dir=data_dir) parties = get_parties(args.parties) party = parties[args.rank] - federation_session = StandaloneFederation(computing_session, args.federation_session_id, party, parties) + federation_session = FederationBuilder( + federation_id=args.federation_session_id, party=party, parties=parties + ).build_standalone( + computing_session, + ) context = Context(computing=computing_session, federation=federation_session) return context -def init_eggroll_context(): +def init_distributed_context(): + from fate.arch.federation import FederationBuilder, FederationMode from fate.arch.computing.eggroll import CSession - from fate.arch.federation.osx import OSXFederation - from fate.arch.federation.eggroll import EggrollFederation from fate.arch.context import Context - args = HfArgumentParser(LauncherEggrollContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[0] + args = HfArgumentParser(LauncherDistributedContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ + 0 + ] parties = get_parties(args.parties) party = parties[args.rank] - computing_session = CSession(session_id=args.csession_id) - federation_session = EggrollFederation( - rp_ctx=computing_session.get_rpc(), - rs_session_id=args.federation_session_id, - party=party, - parties=parties, - proxy_endpoint=f"{args.host}:{args.port}", + computing_session = CSession(session_id=args.csession_id, config_properties_file=args.config_properties) + federation_mode = FederationMode.from_str(args.federation_mode) + + federation_session = FederationBuilder( + federation_id=args.federation_session_id, party=party, parties=parties + ).build_osx( + computing_session=computing_session, + host=args.host, + port=args.port, + mode=federation_mode, ) - # federation_session = OSXFederation.from_conf( - # federation_session_id=args.federation_session_id, - # computing_session=computing_session, - # party=party, - # parties=parties, - # host=args.host, - # port=args.port, - # ) context = Context(computing=computing_session, federation=federation_session) return context diff --git a/python/fate/components/core/_load_federation.py b/python/fate/components/core/_load_federation.py index 4419b36996..f98fabe959 100644 --- a/python/fate/components/core/_load_federation.py +++ b/python/fate/components/core/_load_federation.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
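+# all federation spec branches below delegate construction to the shared FederationBuilder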
def load_federation(federation, computing): + from fate.arch.federation import FederationBuilder, FederationMode from fate.components.core.spec.federation import ( OSXFederationSpec, PulsarFederationSpec, @@ -21,90 +22,76 @@ def load_federation(federation, computing): StandaloneFederationSpec, ) - if isinstance(federation, StandaloneFederationSpec): - from fate.arch.federation.standalone import StandaloneFederation + builder = FederationBuilder( + federation_id=federation.metadata.federation_id, + party=federation.metadata.parties.local.tuple(), + parties=[p.tuple() for p in federation.metadata.parties.parties], + ) - return StandaloneFederation( - computing, - federation.metadata.federation_id, - federation.metadata.parties.local.tuple(), - [p.tuple() for p in federation.metadata.parties.parties], + if isinstance(federation, StandaloneFederationSpec): + return builder.build_standalone( + computing_session=computing, ) - if isinstance(federation, RollSiteFederationSpec): - from fate.arch.computing.eggroll import CSession - from fate.arch.federation.eggroll import EggrollFederation - - if not isinstance(computing, CSession): - raise RuntimeError(f"Eggroll federation type requires Eggroll computing type, `{type(computing)}` found") - - return EggrollFederation( - rp_ctx=computing.get_rpc(), - rs_session_id=federation.metadata.federation_id, - party=federation.metadata.parties.local.tuple(), - parties=[p.tuple() for p in federation.metadata.parties.parties], - proxy_endpoint=f"{federation.metadata.rollsite_config.host}:{federation.metadata.rollsite_config.port}", + if isinstance(federation, (OSXFederationSpec, RollSiteFederationSpec)): + if isinstance(federation, OSXFederationSpec): + mode = FederationMode.from_str(federation.metadata.osx_config.mode) + options = dict(max_message_size=federation.metadata.osx_config.max_message_size) + else: + mode = FederationMode.STREAM + options = {} + return builder.build_osx( + computing_session=computing, + host=federation.metadata.osx_config.host, + port=federation.metadata.osx_config.port, + mode=mode, + options=options, ) - if isinstance(federation, RabbitMQFederationSpec): - from fate.arch.federation.rabbitmq import RabbitmqFederation - - return RabbitmqFederation.from_conf( - federation_session_id=federation.metadata.federation_id, + return builder.build_rabbitmq( computing_session=computing, - party=federation.metadata.parties.local.tuple(), - parties=[p.tuple() for p in federation.metadata.parties.parties], - route_table={k: v.dict() for k, v in federation.metadata.route_table.items()}, host=federation.metadata.rabbitmq_config.host, port=federation.metadata.rabbitmq_config.port, - mng_port=federation.metadata.rabbitmq_config.mng_port, - base_user=federation.metadata.rabbitmq_config.user, - base_password=federation.metadata.rabbitmq_config.password, - mode=federation.metadata.rabbitmq_config.mode, - max_message_size=federation.metadata.rabbitmq_config.max_message_size, - rabbitmq_run=federation.metadata.rabbitmq_run, - connection=federation.metadata.connection, + options=dict( + route_table={k: v.dict() for k, v in federation.metadata.route_table.items()}, + host=federation.metadata.rabbitmq_config.host, + port=federation.metadata.rabbitmq_config.port, + mng_port=federation.metadata.rabbitmq_config.mng_port, + base_user=federation.metadata.rabbitmq_config.user, + base_password=federation.metadata.rabbitmq_config.password, + mode=federation.metadata.rabbitmq_config.mode, + max_message_size=federation.metadata.rabbitmq_config.max_message_size, + 
rabbitmq_run=federation.metadata.rabbitmq_run, + connection=federation.metadata.connection, + ), ) if isinstance(federation, PulsarFederationSpec): - from fate.arch.federation.pulsar import PulsarFederation - route_table = {} for k, v in federation.metadata.route_table.route.items(): route_table.update({k: v.dict()}) if (default := federation.metadata.route_table.default) is not None: route_table.update({"default": default.dict()}) - return PulsarFederation.from_conf( - federation_session_id=federation.metadata.federation_id, + return builder.build_pulsar( computing_session=computing, - party=federation.metadata.parties.local.tuple(), - parties=[p.tuple() for p in federation.metadata.parties.parties], - route_table=route_table, - mode=federation.metadata.pulsar_config.mode, host=federation.metadata.pulsar_config.host, port=federation.metadata.pulsar_config.port, - mng_port=federation.metadata.pulsar_config.mng_port, - base_user=federation.metadata.pulsar_config.user, - base_password=federation.metadata.pulsar_config.password, - max_message_size=federation.metadata.pulsar_config.max_message_size, - topic_ttl=federation.metadata.pulsar_config.topic_ttl, - cluster=federation.metadata.pulsar_config.cluster, - tenant=federation.metadata.pulsar_config.tenant, - pulsar_run=federation.metadata.pulsar_run, - connection=federation.metadata.connection, + options=dict( + route_table=route_table, + mode=federation.metadata.pulsar_config.mode, + host=federation.metadata.pulsar_config.host, + port=federation.metadata.pulsar_config.port, + mng_port=federation.metadata.pulsar_config.mng_port, + base_user=federation.metadata.pulsar_config.user, + base_password=federation.metadata.pulsar_config.password, + max_message_size=federation.metadata.pulsar_config.max_message_size, + topic_ttl=federation.metadata.pulsar_config.topic_ttl, + cluster=federation.metadata.pulsar_config.cluster, + tenant=federation.metadata.pulsar_config.tenant, + pulsar_run=federation.metadata.pulsar_run, + connection=federation.metadata.connection, + ), ) - if isinstance(federation, OSXFederationSpec): - from fate.arch.federation.osx import OSXFederation - - return OSXFederation.from_conf( - federation_session_id=federation.metadata.federation_id, - computing_session=computing, - party=federation.metadata.parties.local.tuple(), - parties=[p.tuple() for p in federation.metadata.parties.parties], - host=federation.metadata.osx_config.host, - port=federation.metadata.osx_config.port, - max_message_size=federation.metadata.osx_config.max_message_size, - ) # TODO: load from plugin raise ValueError(f"conf.federation={federation} not support") diff --git a/python/fate/components/core/spec/federation.py b/python/fate/components/core/spec/federation.py index 0eee05a157..ad4bbf1867 100644 --- a/python/fate/components/core/spec/federation.py +++ b/python/fate/components/core/spec/federation.py @@ -125,6 +125,7 @@ class MetadataSpec(pydantic.BaseModel): class OSXConfig(pydantic.BaseModel): host: str port: int + mode: str = "message_queue" max_message_size: Optional[int] = None federation_id: str From 898ecc4b50485c3d8e3a008f0462d5f068a5021d Mon Sep 17 00:00:00 2001 From: sagewe Date: Wed, 6 Dec 2023 16:06:15 +0800 Subject: [PATCH 10/42] fix split error Signed-off-by: sagewe --- .../java/org/fedai/osx/broker/queue/TransferQueueManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java 
b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java index 0faa2122a4..996e0fec31 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/queue/TransferQueueManager.java @@ -199,7 +199,7 @@ public CreateQueueResult createNewQueue(String sessionId, String topic, boolean AbstractQueue queue = this.getQueue(sessionId, topic); if (queue != null) { createQueueResult.setQueue(queue); - String[] elements = MetaInfo.INSTANCE_ID.split(":"); + String[] elements = MetaInfo.INSTANCE_ID.split("_"); createQueueResult.setPort(Integer.parseInt(elements[1])); createQueueResult.setRedirectIp(elements[0]); return createQueueResult; From f927fd5585a05e60041a2cc527d187f8b8c0abcf Mon Sep 17 00:00:00 2001 From: sagewe Date: Thu, 7 Dec 2023 13:04:50 +0800 Subject: [PATCH 11/42] add host and port params to eggroll Signed-off-by: sagewe --- python/fate/arch/computing/eggroll/_csession.py | 11 ++++++++++- python/fate/components/core/_load_computing.py | 9 ++++++++- python/fate/components/core/spec/computing.py | 5 +++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/eggroll/_csession.py index ab64350d50..0e28b39f50 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/eggroll/_csession.py @@ -32,12 +32,21 @@ class CSession(KVTableContext): def __init__( - self, session_id, options: dict = None, config=None, config_options=None, config_properties_file=None + self, + session_id, + host: str = None, + port: int = None, + options: dict = None, + config=None, + config_options=None, + config_properties_file=None, ): if options is None: options = {} self._eggroll_session = session_init( session_id=session_id, + host=host, + port=port, options=options, config=config, config_options=config_options, diff --git a/python/fate/components/core/_load_computing.py b/python/fate/components/core/_load_computing.py index 40bfebcd61..b9d929394a 100644 --- a/python/fate/components/core/_load_computing.py +++ b/python/fate/components/core/_load_computing.py @@ -31,7 +31,14 @@ def load_computing(computing, logger_config=None): if isinstance(computing, EggrollComputingSpec): from fate.arch.computing.eggroll import CSession - return CSession(computing.metadata.computing_id, options=computing.metadata.options) + return CSession( + computing.metadata.computing_id, + host=computing.metadata.host, + port=computing.metadata.port, + config_options=computing.metadata.config_options, + config_properties_file=computing.metadata.config_properties_file, + options=computing.metadata.options, + ) if isinstance(computing, SparkComputingSpec): from fate.arch.computing.spark import CSession diff --git a/python/fate/components/core/spec/computing.py b/python/fate/components/core/spec/computing.py index aec4ebc860..d1fdeab12e 100644 --- a/python/fate/components/core/spec/computing.py +++ b/python/fate/components/core/spec/computing.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
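+# typing provides the Optional annotations for the new eggroll host/port/config fields below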
+import typing from typing import Literal import pydantic @@ -29,6 +30,10 @@ class MetadataSpec(pydantic.BaseModel): class EggrollComputingSpec(pydantic.BaseModel): class MetadataSpec(pydantic.BaseModel): computing_id: str + host: typing.Optional[str] = None + port: typing.Optional[int] = None + config_options: typing.Optional[dict] = None + config_properties_file: typing.Optional[str] = None options: dict = {} type: Literal["eggroll"] From d7fd80791b40ddeb64780d38263ce3eac6f47086 Mon Sep 17 00:00:00 2001 From: sagewe Date: Thu, 7 Dec 2023 14:33:13 +0800 Subject: [PATCH 12/42] fix launcher Signed-off-by: sagewe --- python/fate/arch/launchers/context_helper.py | 47 +++++++++++-------- .../arch/launchers/multiprocess_launcher.py | 11 ++--- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index 9be2dab636..ae60b4ed1e 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -6,8 +6,7 @@ @dataclass -class LauncherStandaloneContextArgs: - federation_session_id: str = field() +class LauncherLocalContextArgs: parties: List[str] = field() rank: int = field() csession_id: str = field(default=None) @@ -15,14 +14,13 @@ class LauncherStandaloneContextArgs: @dataclass -class LauncherDistributedContextArgs: - federation_session_id: str = field() +class LauncherClusterContextArgs: parties: List[str] = field() rank: int = field() config_properties: str = field() csession_id: str = field(default=None) - host: str = field(default="127.0.0.1") - port: int = field(default=9377) + federation_address: str = field(default="127.0.0.1:9377") + cluster_address: str = field(default="127.0.0.1:4670") federation_mode: str = field(default="message_queue") @@ -31,34 +29,35 @@ class LauncherContextArguments: context_type: str = field(default="standalone") -def init_context(): +def init_context(computing_session_id: str, federation_session_id: str): args = HfArgumentParser(LauncherContextArguments).parse_known_args()[0] if args.context_type == "local": - return init_standalone_context() + return init_local_context(computing_session_id, federation_session_id) elif args.context_type == "cluster": - return init_distributed_context() + return init_cluster_context(computing_session_id, federation_session_id) else: raise ValueError(f"unknown context type: {args.context_type}") -def init_standalone_context(): +def init_local_context(computing_session_id: str, federation_session_id: str): from fate.arch.utils.paths import get_base_dir from fate.arch.computing.standalone import CSession from fate.arch.federation import FederationBuilder from fate.arch.context import Context - args = HfArgumentParser(LauncherStandaloneContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ + args = HfArgumentParser(LauncherLocalContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ 0 ] data_dir = args.data_dir if not data_dir: data_dir = os.path.join(get_base_dir(), "data") - computing_session = CSession(session_id=args.csession_id, data_dir=data_dir) + computing_session = CSession(session_id=computing_session_id, data_dir=data_dir) + parties = get_parties(args.parties) party = parties[args.rank] federation_session = FederationBuilder( - federation_id=args.federation_session_id, party=party, parties=parties + federation_id=federation_session_id, party=party, parties=parties ).build_standalone( computing_session, ) @@ -66,27 +65,35 @@ def 
init_standalone_context(): return context -def init_distributed_context(): +def init_cluster_context(computing_session_id: str, federation_session_id: str): from fate.arch.federation import FederationBuilder, FederationMode from fate.arch.computing.eggroll import CSession from fate.arch.context import Context - args = HfArgumentParser(LauncherDistributedContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ + args = HfArgumentParser(LauncherClusterContextArgs).parse_args_into_dataclasses(return_remaining_strings=True)[ 0 ] + + cluster_host, cluster_port = args.cluster_address.split(":") + computing_session = CSession( + session_id=computing_session_id, + host=cluster_host.strip(), + port=int(cluster_port.strip()), + ) + parties = get_parties(args.parties) party = parties[args.rank] - computing_session = CSession(session_id=args.csession_id, config_properties_file=args.config_properties) federation_mode = FederationMode.from_str(args.federation_mode) - + federation_host, federation_port = args.federation_address.split(":") federation_session = FederationBuilder( - federation_id=args.federation_session_id, party=party, parties=parties + federation_id=federation_session_id, party=party, parties=parties ).build_osx( computing_session=computing_session, - host=args.host, - port=args.port, + host=federation_host.strip(), + port=int(federation_port.strip()), mode=federation_mode, ) + context = Context(computing=computing_session, federation=federation_session) return context diff --git a/python/fate/arch/launchers/multiprocess_launcher.py b/python/fate/arch/launchers/multiprocess_launcher.py index 7d9d0e9eec..aaac027869 100644 --- a/python/fate/arch/launchers/multiprocess_launcher.py +++ b/python/fate/arch/launchers/multiprocess_launcher.py @@ -23,13 +23,13 @@ import logging import multiprocessing import signal +import sys import time import uuid from argparse import Namespace from dataclasses import dataclass, field from multiprocessing import Queue, Event from typing import List -import sys import rich import rich.console @@ -80,7 +80,6 @@ def start(self, f, carrier=None): safe_to_exit = self.safe_to_exit width = self.console.width argv = sys.argv.copy() - argv.extend(["--federation_session_id", self.federation_session_id]) argv.extend(["--rank", str(rank)]) process = multiprocessing.Process( target=self.__class__._run_process, @@ -90,6 +89,7 @@ def start(self, f, carrier=None): output_or_exception_q, safe_to_exit, width, + self.federation_session_id, argv, f, ), @@ -123,6 +123,7 @@ def _run_process( output_or_exception_q: Queue, safe_to_exit: Event, width, + federation_session_id, argv, f, ): @@ -137,8 +138,7 @@ def _run_process( raise ValueError(f"rank {args.rank} is out of range {len(args.parties)}") parties = args.get_parties() party = parties[args.rank] - csession_id = f"{args.federation_session_id}_{party[0]}_{party[1]}" - argv.extend(["--csession_id", csession_id]) + csession_id = f"{federation_session_id}_{party[0]}_{party[1]}" # set up logging set_up_logging(args.rank, args.log_level) @@ -149,7 +149,7 @@ def _run_process( tracer = trace.get_tracer(__name__) with tracer.start_as_current_span(name=csession_id, context=trace.extract_carrier(carrier)) as span: - ctx = init_context() + ctx = init_context(computing_session_id=csession_id, federation_session_id=federation_session_id) try: profile_start() @@ -238,7 +238,6 @@ class LauncherArguments: @dataclass class LauncherProcessArguments: - federation_session_id: str = field() log_level: str = field() rank: int = field() 
parties: List[str] = field(metadata={"required": True}) From a272a6b6d249131b04ce2c9aca0f0a99f1747b3e Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Thu, 7 Dec 2023 16:34:11 +0800 Subject: [PATCH 13/42] add sshe lr & sshe linr (#5227) Signed-off-by: Yu Wu --- .../sshe_linr/sshe_linr_testsuite.yaml | 42 +++ examples/pipeline/sshe_linr/test_linr.py | 88 ++++++ examples/pipeline/sshe_linr/test_linr_cv.py | 67 +++++ .../sshe_linr/test_linr_warm_start.py | 100 ++++++ .../pipeline/sshe_lr/sshe_lr_testsuite.yaml | 80 +++++ examples/pipeline/sshe_lr/test_lr.py | 92 ++++++ examples/pipeline/sshe_lr/test_lr_cv.py | 67 +++++ .../pipeline/sshe_lr/test_lr_multi_class.py | 94 ++++++ .../sshe_lr/test_lr_predict_w_torch.py | 100 ++++++ examples/pipeline/sshe_lr/test_lr_validate.py | 80 +++++ .../pipeline/sshe_lr/test_lr_warm_start.py | 100 ++++++ launchers/sshe_linr_launcher.py | 15 +- launchers/sshe_lr_launcher.py | 11 +- python/fate/arch/protocol/mpc/mpc.py | 15 + .../arch/protocol/mpc/nn/sshe/linr_layer.py | 5 +- python/fate/components/components/__init__.py | 12 + .../fate/components/components/sshe_linr.py | 258 ++++++++++++++++ python/fate/components/components/sshe_lr.py | 23 +- python/fate/ml/glm/hetero/sshe/__init__.py | 18 ++ python/fate/ml/glm/hetero/sshe/sshe_linr.py | 284 ++++++++++++++++++ .../ml/{mpc => glm/hetero/sshe}/sshe_lr.py | 125 +++++--- python/fate/ml/mpc/sshe_linr.py | 43 --- python/fate/ml/utils/_convergence.py | 10 +- python/fate/ml/utils/_model_param.py | 22 ++ 24 files changed, 1649 insertions(+), 102 deletions(-) create mode 100644 examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml create mode 100644 examples/pipeline/sshe_linr/test_linr.py create mode 100644 examples/pipeline/sshe_linr/test_linr_cv.py create mode 100644 examples/pipeline/sshe_linr/test_linr_warm_start.py create mode 100644 examples/pipeline/sshe_lr/sshe_lr_testsuite.yaml create mode 100644 examples/pipeline/sshe_lr/test_lr.py create mode 100644 examples/pipeline/sshe_lr/test_lr_cv.py create mode 100644 examples/pipeline/sshe_lr/test_lr_multi_class.py create mode 100644 examples/pipeline/sshe_lr/test_lr_predict_w_torch.py create mode 100644 examples/pipeline/sshe_lr/test_lr_validate.py create mode 100644 examples/pipeline/sshe_lr/test_lr_warm_start.py create mode 100644 python/fate/components/components/sshe_linr.py create mode 100644 python/fate/ml/glm/hetero/sshe/__init__.py create mode 100644 python/fate/ml/glm/hetero/sshe/sshe_linr.py rename python/fate/ml/{mpc => glm/hetero/sshe}/sshe_lr.py (74%) delete mode 100644 python/fate/ml/mpc/sshe_linr.py diff --git a/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml b/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml new file mode 100644 index 0000000000..14366bff98 --- /dev/null +++ b/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml @@ -0,0 +1,42 @@ +data: + - file: examples/data/motor_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: float64 + label_name: motor_speed + match_id_name: idx + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/motor_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: idx + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: motor_hetero_host + 
namespace: experiment + role: host_0 +tasks: + normal-linr: + script: test_linr.py + linr-cv: + script: test_linr_cv.py + linr-warm-start: + script: test_linr_warm_start.py diff --git a/examples/pipeline/sshe_linr/test_linr.py b/examples/pipeline/sshe_linr/test_linr.py new file mode 100644 index 0000000000..f990f17cce --- /dev/null +++ b/examples/pipeline/sshe_linr/test_linr.py @@ -0,0 +1,88 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = SSHELinR("linr_0", + epochs=10, + batch_size=100, + init_param={"fit_intercept": True}, + train_data=psi_0.outputs["output_data"], + reveal_every_epoch=False, + early_stop="diff", + reveal_loss_freq=1, + learning_rate=0.1) + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, linr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.task_setting( + input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].task_setting( + input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_linr/test_linr_cv.py 
b/examples/pipeline/sshe_linr/test_linr_cv.py new file mode 100644 index 0000000000..5b61fef983 --- /dev/null +++ b/examples/pipeline/sshe_linr/test_linr_cv.py @@ -0,0 +1,67 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import SSHELinR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = SSHELinR("linr_0", + epochs=10, + batch_size=None, + learning_rate=0.05, + init_param={"fit_intercept": True}, + cv_data=psi_0.outputs["output_data"], + cv_param={"n_splits": 3}, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_linr/test_linr_warm_start.py b/examples/pipeline/sshe_linr/test_linr_warm_start.py new file mode 100644 index 0000000000..0e9a5a302e --- /dev/null +++ b/examples/pipeline/sshe_linr/test_linr_warm_start.py @@ -0,0 +1,100 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
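The cv_param={"n_splits": 3} used above is handed to fate.arch.dataframe.KFold. A plain-Python sketch of the split semantics — illustrative only, the real KFold operates on distributed DataFrames; this just shows what n_splits, shuffle, and random_state mean:

import random


def kfold_indices(n_samples, n_splits, shuffle=False, random_state=None):
    # yields (train_idx, validate_idx) once per fold
    idx = list(range(n_samples))
    if shuffle:
        random.Random(random_state).shuffle(idx)
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        validate = idx[start:start + size]
        train = idx[:start] + idx[start + size:]
        yield train, validate
        start += size


for train_idx, validate_idx in kfold_indices(10, n_splits=3):
    print(len(train_idx), len(validate_idx))  # 6 4 / 7 3 / 7 3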
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELinR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + linr_0 = SSHELinR("linr_0", + epochs=4, + batch_size=None, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate=0.05, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + linr_1 = SSHELinR("linr_1", train_data=psi_0.outputs["output_data"], + warm_start_model=linr_0.outputs["output_model"], + epochs=2, + batch_size=None, + learning_rate=0.05, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + + linr_2 = SSHELinR("linr_2", epochs=6, + batch_size=None, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + learning_rate=0.05, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=[linr_1.outputs["train_output_data"], linr_2.outputs["train_output_data"]]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + pipeline.add_task(linr_1) + pipeline.add_task(linr_2) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + # print(f"linr_1 model: {pipeline.get_task_info('linr_1').get_output_model()}") + # print(f"train linr_1 data: {pipeline.get_task_info('linr_1').get_output_data()}") + + # print(f"linr_2 model: {pipeline.get_task_info('linr_2').get_output_model()}") + # print(f"train linr_2 data: {pipeline.get_task_info('linr_2').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/sshe_lr_testsuite.yaml b/examples/pipeline/sshe_lr/sshe_lr_testsuite.yaml new file mode 100644 index 0000000000..ef5899db6a --- /dev/null +++ b/examples/pipeline/sshe_lr/sshe_lr_testsuite.yaml @@ -0,0 +1,80 @@ +data: + - file: examples/data/breast_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + label_type: int64 + label_name: y + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_guest + namespace: experiment + role: guest_0 + - file: 
examples/data/breast_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + partitions: 4 + head: true + extend_sid: true + table_name: breast_hetero_host + namespace: experiment + role: host_0 + - file: examples/data/vehicle_scale_hetero_guest.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + label_type: int64 + label_name: y + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: vehicle_scale_hetero_guest + namespace: experiment + role: guest_0 + - file: examples/data/vehicle_scale_hetero_host.csv + meta: + delimiter: "," + dtype: float64 + input_format: dense + match_id_name: id + match_id_range: 0 + tag_value_delimiter: ":" + tag_with_value: false + weight_type: float64 + head: true + partitions: 4 + extend_sid: true + table_name: vehicle_scale_hetero_host + namespace: experiment + role: host_0 +tasks: + normal-lr: + script: test_lr.py + lr-cv: + script: test_lr_cv.py + lr-validate: + script: test_lr_validate.py + lr-warm-start: + script: test_lr_warm_start.py + lr-multi-class: + script: test_lr_multi_class.py diff --git a/examples/pipeline/sshe_lr/test_lr.py b/examples/pipeline/sshe_lr/test_lr.py new file mode 100644 index 0000000000..4a5f47263f --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr.py @@ -0,0 +1,92 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
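The meta blocks in the testsuite yaml above describe how each CSV is parsed at upload time. A rough pandas mirror of those parsing rules for breast_hetero_guest — hypothetical, not the fate_client upload path itself; extend_sid and partitions are ignored here:

import pandas as pd

# meta copied from the testsuite yaml above
meta = {"delimiter": ",", "label_name": "y", "label_type": "int64",
        "match_id_name": "id", "dtype": "float64"}

df = pd.read_csv("examples/data/breast_hetero_guest.csv", sep=meta["delimiter"])
label = df.pop(meta["label_name"]).astype(meta["label_type"])
match_id = df.pop(meta["match_id_name"])
features = df.astype(meta["dtype"])
print(features.shape, label.value_counts().to_dict())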
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = SSHELR("lr_0", + learning_rate=0.05, + epochs=10, + batch_size=300, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, ) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + pipeline.deploy([psi_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.task_setting( + input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].task_setting( + input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/test_lr_cv.py b/examples/pipeline/sshe_lr/test_lr_cv.py new file mode 100644 index 0000000000..fabbd10b19 --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr_cv.py @@ -0,0 +1,67 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
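SSHELR above exposes threshold (default 0.5) for binary decisions. A short torch sketch of how a revealed linear score becomes a probability and a 0/1 label — plain torch for illustration, not FATE's actual predict path:

import torch


def binary_predict(scores, threshold=0.5):
    # scores: the linear part w.x + b revealed to guest; threshold as in SSHELR above
    prob = torch.sigmoid(scores)
    return prob, (prob > threshold).long()


prob, label = binary_predict(torch.tensor([-2.0, 0.1, 3.5]))
print(prob.tolist(), label.tolist())  # probabilities and 0/1 decisions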
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = SSHELR("lr_0", + learning_rate=0.15, + epochs=2, + batch_size=None, + init_param={"fit_intercept": True}, + cv_data=psi_0.outputs["output_data"], + cv_param={"n_splits": 3}, + reveal_every_epoch=False, + early_stop="diff", + reveal_loss_freq=1, + ) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/test_lr_multi_class.py b/examples/pipeline/sshe_lr/test_lr_multi_class.py new file mode 100644 index 0000000000..5a64db054f --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr_multi_class.py @@ -0,0 +1,94 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
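The early_stop parameter used throughout these scripts chooses among three criteria: "abs" (loss magnitude), "diff" (loss change between epochs), and "weight_diff" (weight movement between epochs). A sketch of those semantics — the concrete implementation lives in fate.ml.utils._convergence via converge_func_factory; this is only an illustration:

def has_converged(criterion, loss, prev_loss, weight_diff_norm, tol):
    if criterion == "abs":            # stop when the loss itself is near zero
        return abs(loss) < tol
    if criterion == "diff":           # stop when the loss has plateaued
        return prev_loss is not None and abs(loss - prev_loss) < tol
    if criterion == "weight_diff":    # stop when the weights stop moving
        return weight_diff_norm < tol
    raise ValueError(f"unknown early_stop criterion: {criterion}")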
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = SSHELR("lr_0", + learning_rate=0.15, + epochs=10, + batch_size=None, + reveal_every_epoch=True, + early_stop="weight_diff", + reveal_loss_freq=1, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=psi_0.outputs["output_data"]) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="multi", + predict_column_name='predict_result', + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, lr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.task_setting( + input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", + namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].task_setting( + input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + # print(f"predict lr_0 data: {pipeline.get_task_info('lr_0').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py new file mode 100644 index 0000000000..eb8a67382f --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py @@ -0,0 +1,100 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +class LogisticRegression(torch.nn.Module): + def __init__(self, coefficients): + super(LogisticRegression, self).__init__() + self.linear = torch.nn.Linear(coefficients.shape[1], 1) + + def forward(self, x): + y_pred = torch.sigmoid(self.linear(x)) + return y_pred + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = SSHELR("lr_0", + epochs=10, + batch_size=300, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + ) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + pipeline.fit() + + lr_model = pipeline.get_task_info('lr_0').get_output_model() + param = lr_model['output_model']['data']['estimator']['param'] + dtype = getattr(torch, param['dtype']) + coef = torch.transpose(torch.tensor(param['coef_'], dtype=dtype), 0, 1) + intercept = torch.tensor(param["intercept_"], dtype=dtype) + + import pandas as pd + + input_data = pd.read_csv("../../data/breast_hetero_guest.csv", index_col="id") + input_data.drop(['y'], axis=1, inplace=True) + input_data = torch.tensor(input_data.values, dtype=dtype) + + pytorch_model = LogisticRegression(coef) + with torch.no_grad(): + pytorch_model.linear.weight.copy_(coef) + pytorch_model.linear.bias.copy_(intercept) + predict_result = pytorch_model(input_data) + print(f"predictions shape: {predict_result.shape}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/test_lr_validate.py b/examples/pipeline/sshe_lr/test_lr_validate.py new file mode 100644 index 0000000000..e1bcc5d154 --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr_validate.py @@ -0,0 +1,80 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI, DataSplit +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + data_split_0 = DataSplit("data_split_0", + train_size=0.8, + validate_size=0.2, + input_data=psi_0.outputs["output_data"]) + lr_0 = SSHELR("lr_0", + epochs=10, + batch_size=300, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + init_param={"fit_intercept": True, "method": "random_uniform"}, + train_data=data_split_0.outputs["train_output_data"], + validate_data=data_split_0.outputs["validate_output_data"], + ) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(data_split_0) + pipeline.add_task(lr_0) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/examples/pipeline/sshe_lr/test_lr_warm_start.py b/examples/pipeline/sshe_lr/test_lr_warm_start.py new file mode 100644 index 0000000000..55cdce9ebf --- /dev/null +++ b/examples/pipeline/sshe_lr/test_lr_warm_start.py @@ -0,0 +1,100 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
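The DataSplit task in test_lr_validate.py above carves the PSI output into 80% train and 20% validate. A plain-Python sketch of that row split — the real DataSplit is federated and also supports a test partition; this only illustrates the ratios:

import random


def split_indices(n_rows, train_size=0.8, validate_size=0.2, seed=None):
    # sketch of an 80/20 split consistent with the DataSplit task above
    assert abs(train_size + validate_size - 1.0) < 1e-9
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)
    cut = int(n_rows * train_size)
    return idx[:cut], idx[cut:]


train_idx, validate_idx = split_indices(1000, seed=42)
print(len(train_idx), len(validate_idx))  # 800 200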
+ +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", + namespace=f"experiment{namespace}")) + lr_0 = SSHELR("lr_0", + epochs=4, + batch_size=None, + learning_rate=0.05, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + lr_1 = SSHELR("lr_1", train_data=psi_0.outputs["output_data"], + warm_start_model=lr_0.outputs["output_model"], + epochs=2, + batch_size=None, + learning_rate=0.05, + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + + lr_2 = SSHELR("lr_2", epochs=6, + batch_size=None, + learning_rate=0.05, + init_param={"fit_intercept": True, "method": "zeros"}, + train_data=psi_0.outputs["output_data"], + reveal_every_epoch=True, + early_stop="diff", + reveal_loss_freq=1, + ) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="binary", + input_data=[lr_1.outputs["train_output_data"], lr_2.outputs["train_output_data"]]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(lr_2) + pipeline.add_task(evaluation_0) + + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + # print(f"lr_1 model: {pipeline.get_task_info('lr_1').get_output_model()}") + # print(f"train lr_1 data: {pipeline.get_task_info('lr_1').get_output_data()}") + + # print(f"lr_2 model: {pipeline.get_task_info('lr_2').get_output_model()}") + # print(f"train lr_2 data: {pipeline.get_task_info('lr_2').get_output_data()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) diff --git a/launchers/sshe_linr_launcher.py b/launchers/sshe_linr_launcher.py index 59fbd980bf..7346a023fc 100644 --- a/launchers/sshe_linr_launcher.py +++ b/launchers/sshe_linr_launcher.py @@ -19,18 +19,20 @@ class SSHEArguments: def run_sshe_linr(ctx: "Context"): - from fate.ml.mpc.sshe_linr import SSHELinearRegression + from fate.ml.glm.hetero.sshe import SSHELinearRegression from fate.arch import dataframe ctx.mpc.init() args, _ = HfArgumentParser(SSHEArguments).parse_args_into_dataclasses(return_remaining_strings=True) - inst = SSHELinearRegression(args.lr) + inst = SSHELinearRegression(epochs=5, batch_size=300, tol=0.01, early_stop='diff', learning_rate=0.15, + init_param={"method": 
"random_uniform", "fit_intercept": True, "random_state": 1}, + reveal_every_epoch=False, reveal_loss_freq=2, threshold=0.5) if ctx.is_on_guest: kwargs = { "sample_id_name": None, - "match_id_name": "id", + "match_id_name": "idx", "delimiter": ",", - "label_name": "y", + "label_name": "motor_speed", "label_type": "float32", "dtype": "float32", } @@ -38,12 +40,13 @@ def run_sshe_linr(ctx: "Context"): else: kwargs = { "sample_id_name": None, - "match_id_name": "id", + "match_id_name": "idx", "delimiter": ",", "dtype": "float32", } input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.host_data) - inst.fit(ctx, input_data=input_data) + inst.fit(ctx, train_data=input_data) + print(f"model: {inst.get_model()}") if __name__ == "__main__": diff --git a/launchers/sshe_lr_launcher.py b/launchers/sshe_lr_launcher.py index 911d4a28dc..202e2839f3 100644 --- a/launchers/sshe_lr_launcher.py +++ b/launchers/sshe_lr_launcher.py @@ -19,19 +19,21 @@ class SSHEArguments: def run_sshe_lr(ctx: "Context"): - from fate.ml.mpc.sshe_lr import SSHELogisticRegression + from fate.ml.glm.hetero.sshe import SSHELogisticRegression from fate.arch import dataframe ctx.mpc.init() args, _ = HfArgumentParser(SSHEArguments).parse_args_into_dataclasses(return_remaining_strings=True) - inst = SSHELogisticRegression(args.lr) + inst = SSHELogisticRegression(epochs=5, batch_size=300, tol=0.01, early_stop='diff', learning_rate=0.15, + init_param={"method": "random_uniform", "fit_intercept": True, "random_state": 1}, + reveal_every_epoch=False, reveal_loss_freq=2, threshold=0.5) if ctx.is_on_guest: kwargs = { "sample_id_name": None, "match_id_name": "id", "delimiter": ",", "label_name": "y", - "label_type": "float32", + "label_type": "int32", "dtype": "float32", } input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.guest_data) @@ -43,7 +45,8 @@ def run_sshe_lr(ctx: "Context"): "dtype": "float32", } input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.host_data) - inst.fit(ctx, input_data=input_data) + inst.fit(ctx, train_data=input_data) + print(inst.get_model()) if __name__ == "__main__": diff --git a/python/fate/arch/protocol/mpc/mpc.py b/python/fate/arch/protocol/mpc/mpc.py index 997848b16e..f69b7428ef 100644 --- a/python/fate/arch/protocol/mpc/mpc.py +++ b/python/fate/arch/protocol/mpc/mpc.py @@ -5,6 +5,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from . 
import communicator as comm diff --git a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py index 31404c59db..e9650a0700 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py @@ -25,7 +25,8 @@ def __init__( self.ctx = ctx self.rank_a = rank_a self.rank_b = rank_b - self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_aggregator_layer") + self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], + f"{ctx.namespace.federation_tag}.sshe_aggregator_layer") if sync_shape: ctx.mpc.option_assert(in_features_a is not None, "in_features_a must be specified", dst=rank_a) @@ -146,7 +147,7 @@ def __call__(self, dz): class SSHELinearRegressionLossLayer: def __init__(self, ctx: Context, rank_a, rank_b): self.ctx = ctx - self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_loss_layer") + self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], f"{ctx.namespace.federation_tag}.sshe_loss_layer") self.rank_a = rank_a self.rank_b = rank_b self.phe_cipher = ctx.cipher.phe.setup() diff --git a/python/fate/components/components/__init__.py b/python/fate/components/components/__init__.py index 346bdc8929..bfcc6e0e63 100644 --- a/python/fate/components/components/__init__.py +++ b/python/fate/components/components/__init__.py @@ -150,6 +150,18 @@ def data_split(self): return data_split + @_lazy_cpn + def sshe_lr(self): + from .sshe_lr import sshe_lr + + return sshe_lr + + @_lazy_cpn + def sshe_linr(self): + from .sshe_linr import sshe_linr + + return sshe_linr + @_lazy_cpn def toy_example(self): from .toy_example import toy_example diff --git a/python/fate/components/components/sshe_linr.py b/python/fate/components/components/sshe_linr.py new file mode 100644 index 0000000000..eeb546f887 --- /dev/null +++ b/python/fate/components/components/sshe_linr.py @@ -0,0 +1,258 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from fate.arch import Context +from fate.arch.dataframe import DataFrame +from fate.components.components.utils import consts, tools +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.glm.hetero.sshe import SSHELinearRegression + +logger = logging.getLogger(__name__) + + +@cpn.component(roles=[GUEST, HOST], provider="fate") +def sshe_linr(ctx, role): + ... 
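The empty body above is intentional: sshe_linr is only a registration point, and the real entry points are attached by the train/predict/cross_validation decorators further down the file. A rough sketch of that registration pattern — plain Python, not the actual fate.components.core.cpn machinery:

class ComponentSketch:
    def __init__(self, fn):
        self.fn = fn
        self.stages = {}

    def _stage(self, name):
        def register(f):
            self.stages[name] = f
            return f
        return register

    def train(self):
        return self._stage("train")

    def predict(self):
        return self._stage("predict")

    def cross_validation(self):
        return self._stage("cross_validation")


def component(**kwargs):
    # stand-in for cpn.component(roles=..., provider=...)
    return lambda fn: ComponentSketch(fn)


@component(roles=["guest", "host"])
def demo(ctx, role):
    ...


@demo.train()
def demo_train(ctx, role):
    print("train stage dispatched")


demo.stages["train"](ctx=None, role=None)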
+ + +@sshe_linr.train() +def train( + ctx: Context, + role: Role, + train_data: cpn.dataframe_input(roles=[GUEST, HOST]), + validate_data: cpn.dataframe_input(roles=[GUEST, HOST], optional=True), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs}", + ), + learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"), + reveal_every_epoch: cpn.parameter(type=bool, default=False, + desc="whether reveal encrypted result every epoch, " + "if False, only reveal at the end of training"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1, + desc="rounds to reveal training loss, " + "only effective if `early_stop` is 'loss'"), + train_output_data: cpn.dataframe_output(roles=[GUEST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True)): + logger.info(f"enter sshe linr train") + # temp code start + init_param = init_param.dict() + ctx.mpc.init() + + train_model( + ctx, + role, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + learning_rate, + tol, + early_stop, + init_param, + reveal_every_epoch, + reveal_loss_freq, + threshold, + warm_start_model + ) + + +@sshe_linr.predict() +def predict( + ctx, + role: Role, + # threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5), + test_data: cpn.dataframe_input(roles=[GUEST, HOST]), + input_model: cpn.json_model_input(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST]), +): + ctx.mpc.init() + predict_from_model(ctx, role, input_model, test_data, test_output_data) + + +@sshe_linr.cross_validation() +def cross_validation( + ctx: Context, + role: Role, + cv_data: cpn.dataframe_input(roles=[GUEST, HOST]), + epochs: cpn.parameter(type=params.conint(gt=0), default=20, desc="max iteration num"), + batch_size: cpn.parameter( + type=params.conint(ge=10), + default=None, desc="batch size, None means full batch, otherwise should be no less than 10, default None" + ), + tol: cpn.parameter(type=params.confloat(ge=0), default=1e-4), + early_stop: cpn.parameter( + type=params.string_choice(["weight_diff", "diff", "abs"]), + default="diff", + desc="early stopping criterion, choose from {weight_diff, diff, abs}", + ), + learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"), + init_param: cpn.parameter( + type=params.init_param(), + default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None), + desc="Model param init setting.", + ), + threshold: cpn.parameter( + type=params.confloat(ge=0.0, le=1.0), default=0.5, desc="predict threshold for binary data" + ), + reveal_every_epoch: cpn.parameter(type=bool, default=False, + 
desc="whether reveal encrypted result every epoch, " + "if False, only reveal at the end of training"), + reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1, + desc="rounds to reveal training loss, " + "only effective if `early_stop` is 'loss'"), + cv_param: cpn.parameter(type=params.cv_param(), + default=params.CVParam(n_splits=5, shuffle=False, random_state=None), + desc="cross validation param"), + metrics: cpn.parameter(type=params.metrics_param(), default=["auc"]), + output_cv_data: cpn.parameter(type=bool, default=True, desc="whether output prediction result per cv fold"), + cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), +): + init_param = init_param.dict() + ctx.mpc.init() + + from fate.arch.dataframe import KFold + kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) + i = 0 + for fold_ctx, (train_data, validate_data) in ctx.on_cross_validations.ctxs_zip(kf.split(cv_data.read())): + logger.info(f"enter fold {i}") + module = SSHELinearRegression( + epochs=epochs, + batch_size=batch_size, + learning_rate=learning_rate, + tol=tol, + early_stop=early_stop, + init_param=init_param, + threshold=threshold, + reveal_every_epoch=reveal_every_epoch, + reveal_loss_freq=reveal_loss_freq + ) + module.fit(fold_ctx, train_data, validate_data) + if output_cv_data: + if role.is_guest: + sub_ctx = fold_ctx.sub_ctx("predict_train") + predict_df = module.predict(sub_ctx, train_data) + train_predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + predict_df = module.predict(sub_ctx, validate_data) + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([train_predict_result, validate_predict_result]) + next(cv_output_datas).write(df=predict_result) + elif role.is_host: + sub_ctx = fold_ctx.sub_ctx("predict_train") + module.predict(sub_ctx, train_data) + sub_ctx = fold_ctx.sub_ctx("predict_validate") + module.predict(sub_ctx, validate_data) + i += 1 + + +def train_model( + ctx, + role, + train_data, + validate_data, + train_output_data, + output_model, + epochs, + batch_size, + learning_rate, + tol, + early_stop, + init_param, + reveal_every_epoch, + reveal_loss_freq, + threshold, + input_model +): + if input_model is not None: + logger.info(f"warm start model provided") + model = input_model.read() + module = SSHELinearRegression.from_model(model) + module.set_epochs(epochs) + module.set_batch_size(batch_size) + + else: + module = SSHELinearRegression( + epochs=epochs, + batch_size=batch_size, + tol=tol, + early_stop=early_stop, + learning_rate=learning_rate, + init_param=init_param, + threshold=threshold, + reveal_every_epoch=reveal_every_epoch, + reveal_loss_freq=reveal_loss_freq + ) + # optimizer = optimizer_factory(optimizer_param) + logger.info(f"sshe linr guest start train") + sub_ctx = ctx.sub_ctx("train") + train_data = train_data.read() + + if validate_data is not None: + logger.info(f"validate data provided") + validate_data = validate_data.read() + + module.fit(sub_ctx, train_data, validate_data) + model = module.get_model() + output_model.write(model, metadata={}) + + sub_ctx = ctx.sub_ctx("predict") + + predict_df = module.predict(sub_ctx, train_data) + + if role.is_guest: + predict_result = tools.add_dataset_type(predict_df, consts.TRAIN_SET) + if validate_data is not None: + sub_ctx = ctx.sub_ctx("validate_predict") + predict_df = 
module.predict(sub_ctx, validate_data) + if role.is_guest: + validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) + predict_result = DataFrame.vstack([predict_result, validate_predict_result]) + if role.is_guest: + train_output_data.write(predict_result) + + +def predict_from_model(ctx, role, input_model, test_data, test_output_data): + logger.info(f"sshe linr guest start predict") + sub_ctx = ctx.sub_ctx("predict") + model = input_model.read() + module = SSHELinearRegression.from_model(model) + # if module.threshold != 0.5: + # module.threshold = threshold + test_data = test_data.read() + predict_df = module.predict(sub_ctx, test_data) + if role.is_guest: + predict_result = tools.add_dataset_type(predict_df, consts.TEST_SET) + test_output_data.write(predict_result) diff --git a/python/fate/components/components/sshe_lr.py b/python/fate/components/components/sshe_lr.py index 0cfb4bc5dd..f7cc04a5a7 100644 --- a/python/fate/components/components/sshe_lr.py +++ b/python/fate/components/components/sshe_lr.py @@ -18,13 +18,13 @@ from fate.arch import Context from fate.arch.dataframe import DataFrame from fate.components.components.utils import consts, tools -from fate.components.core import ARBITER, GUEST, HOST, Role, cpn, params -from fate.ml.mpc.sshe_lr import SSHELogisticRegression +from fate.components.core import GUEST, HOST, Role, cpn, params +from fate.ml.glm.hetero.sshe import SSHELogisticRegression logger = logging.getLogger(__name__) -@cpn.component(roles=[GUEST, HOST, ARBITER], provider="fate") +@cpn.component(roles=[GUEST, HOST], provider="fate") def sshe_lr(ctx, role): ... @@ -61,12 +61,12 @@ def train( reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1, desc="rounds to reveal training loss, " "only effective if `early_stop` is 'loss'"), - train_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), - output_model: cpn.json_model_output(roles=[GUEST, HOST, ARBITER]), - warm_start_model: cpn.json_model_input(roles=[GUEST, HOST, ARBITER], optional=True)): + train_output_data: cpn.dataframe_output(roles=[GUEST]), + output_model: cpn.json_model_output(roles=[GUEST, HOST]), + warm_start_model: cpn.json_model_input(roles=[GUEST, HOST], optional=True)): logger.info(f"enter sshe lr train") - # temp code start init_param = init_param.dict() + ctx.mpc.init() train_model( ctx, @@ -95,8 +95,9 @@ def predict( # threshold: cpn.parameter(type=params.confloat(ge=0.0, le=1.0), default=0.5), test_data: cpn.dataframe_input(roles=[GUEST, HOST]), input_model: cpn.json_model_input(roles=[GUEST, HOST]), - test_output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + test_output_data: cpn.dataframe_output(roles=[GUEST]), ): + ctx.mpc.init() predict_from_model(ctx, role, input_model, test_data, test_output_data) @@ -139,6 +140,7 @@ def cross_validation( cv_output_datas: cpn.dataframe_outputs(roles=[GUEST, HOST], optional=True), ): init_param = init_param.dict() + ctx.mpc.init() from fate.arch.dataframe import KFold kf = KFold(ctx, role=role, n_splits=cv_param.n_splits, shuffle=cv_param.shuffle, random_state=cv_param.random_state) @@ -234,10 +236,11 @@ def train_model( if validate_data is not None: sub_ctx = ctx.sub_ctx("validate_predict") predict_df = module.predict(sub_ctx, validate_data) - if ctx.is_guest: + if role.is_guest: validate_predict_result = tools.add_dataset_type(predict_df, consts.VALIDATE_SET) predict_result = DataFrame.vstack([predict_result, validate_predict_result]) - train_output_data.write(predict_result) + if role.is_guest: + 
train_output_data.write(predict_result) def predict_from_model(ctx, role, input_model, test_data, test_output_data): diff --git a/python/fate/ml/glm/hetero/sshe/__init__.py b/python/fate/ml/glm/hetero/sshe/__init__.py new file mode 100644 index 0000000000..ca4546420e --- /dev/null +++ b/python/fate/ml/glm/hetero/sshe/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .sshe_linr import SSHELinearRegression +from .sshe_lr import SSHELogisticRegression diff --git a/python/fate/ml/glm/hetero/sshe/sshe_linr.py b/python/fate/ml/glm/hetero/sshe/sshe_linr.py new file mode 100644 index 0000000000..5303bf0d8f --- /dev/null +++ b/python/fate/ml/glm/hetero/sshe/sshe_linr.py @@ -0,0 +1,284 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
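A word on how the module defined in this file is meant to be driven. The sketch below is illustrative only and is not part of the patch: `ctx` (a prepared `fate.arch.Context` with MPC initialized) and `train_df` (a FATE `DataFrame`) are assumed to exist, and the hyperparameter values are arbitrary.

    from fate.ml.glm.hetero.sshe import SSHELinearRegression

    # both parties run the same driver; guest/host behavior is resolved
    # through the context roles inside fit()/predict()
    module = SSHELinearRegression(
        epochs=10, batch_size=None, tol=1e-4, early_stop="diff",
        learning_rate=0.05,
        init_param={"method": "zeros", "fit_intercept": True},
        reveal_every_epoch=False, reveal_loss_freq=1,
    )
    module.fit(ctx, train_df)                  # secret-shared training loop
    model_dict = module.get_model()            # plain dict: {"data": ..., "meta": ...}
    warm = SSHELinearRegression.from_model(model_dict)  # warm start or predict later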
+
+
+import logging
+
+import torch
+
+from fate.arch import Context, dataframe
+from fate.arch.dataframe import DataFrame
+from fate.arch.protocol.mpc.nn.sshe.linr_layer import (
+    SSHELinearRegressionLayer,
+    SSHELinearRegressionLossLayer,
+    SSHEOptimizerSGD,
+)
+from fate.ml.abc.module import Module, HeteroModule
+from fate.ml.utils import predict_tools
+from fate.ml.utils._convergence import converge_func_factory
+from fate.ml.utils._model_param import get_initialize_func, serialize_param, deserialize_param
+
+logger = logging.getLogger(__name__)
+
+
+class SSHELinearRegression(Module):
+    def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_param,
+                 reveal_every_epoch=False, reveal_loss_freq=1, threshold=0.5):
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.tol = tol
+        self.early_stop = early_stop
+        self.learning_rate = learning_rate
+        self.init_param = init_param
+        self.threshold = threshold
+        self.reveal_every_epoch = reveal_every_epoch
+        self.reveal_loss_freq = reveal_loss_freq
+
+        self.estimator = None
+
+    def set_batch_size(self, batch_size):
+        self.batch_size = batch_size
+        self.estimator.batch_size = batch_size
+
+    def set_epochs(self, epochs):
+        self.epochs = epochs
+        self.estimator.epochs = epochs
+
+    def fit(self, ctx: Context, train_data: DataFrame, validate_data=None):
+        if ctx.is_on_host:
+            self.init_param["fit_intercept"] = False
+        if self.estimator is None:
+            single_estimator = SSHELREstimator(
+                epochs=self.epochs,
+                batch_size=self.batch_size,
+                learning_rate=self.learning_rate,
+                init_param=self.init_param,
+                reveal_every_epoch=self.reveal_every_epoch,
+                reveal_loss_freq=self.reveal_loss_freq,
+                early_stop=self.early_stop,
+                tol=self.tol
+            )
+        else:
+            logger.info("estimator is not none, will train with warm start")
+            single_estimator = self.estimator
+            single_estimator.epochs = self.epochs
+            single_estimator.batch_size = self.batch_size
+        train_data_fit = train_data.copy()
+        validate_data_fit = validate_data
+        if validate_data:
+            validate_data_fit = validate_data.copy()
+        single_estimator.fit_single_model(ctx, train_data_fit, validate_data_fit)
+        self.estimator = single_estimator
+
+    def get_model(self):
+        estimator = self.estimator.get_model()
+        return {
+            "data": {"estimator": estimator},
+            "meta": {
+                "epochs": self.epochs,
+                "batch_size": self.batch_size,
+                "learning_rate": self.learning_rate,
+                "init_param": self.init_param,
+                "early_stop": self.early_stop,
+                # "optimizer_param": self.optimizer_param,
+                "reveal_every_epoch": self.reveal_every_epoch,
+                "reveal_loss_freq": self.reveal_loss_freq,
+                "tol": self.tol
+            },
+        }
+
+    @classmethod
+    def from_model(cls, model):
+        linr = SSHELinearRegression(
+            epochs=model["meta"]["epochs"],
+            batch_size=model["meta"]["batch_size"],
+            learning_rate=model["meta"]["learning_rate"],
+            init_param=model["meta"]["init_param"],
+            reveal_every_epoch=model["meta"]["reveal_every_epoch"],
+            reveal_loss_freq=model["meta"]["reveal_loss_freq"],
+            tol=model["meta"]["tol"],
+            early_stop=model["meta"]["early_stop"]
+        )
+        estimator = SSHELREstimator(
+            epochs=model["meta"]["epochs"],
+            batch_size=model["meta"]["batch_size"],
+            init_param=model["meta"]["init_param"],
+            reveal_every_epoch=model["meta"]["reveal_every_epoch"],
+            reveal_loss_freq=model["meta"]["reveal_loss_freq"],
+            tol=model["meta"]["tol"],
+            early_stop=model["meta"]["early_stop"]
+        )
+        estimator.restore(model["data"]["estimator"])
+        linr.estimator = estimator
+
+        
return linr + + def predict(self, ctx: Context, test_data: DataFrame) -> DataFrame: + prob = self.estimator.predict(ctx, test_data) + return prob + + +class SSHELREstimator(HeteroModule): + def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate=None, init_param=None, + reveal_every_epoch=True, reveal_loss_freq=3, early_stop=None, tol=None): + self.epochs = epochs + self.batch_size = batch_size + self.optimizer = optimizer + self.lr = learning_rate + self.init_param = init_param + self.reveal_every_epoch = reveal_every_epoch + self.reveal_loss_freq = reveal_loss_freq + self.early_stop = early_stop + self.tol = tol + + self.w = None + self.start_epoch = 0 + self.end_epoch = -1 + self.is_converged = False + self.header = None + self.converge_func = None + if early_stop is not None: + self.converge_func = converge_func_factory(self.early_stop, self.tol) + + def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: DataFrame) -> None: + rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank + + if self.w is None: + initialize_func = get_initialize_func(**self.init_param) + else: + initialize_func = lambda x: self.w + if self.init_param.get("fit_intercept"): + train_data["intercept"] = 1.0 + layer = SSHELinearRegressionLayer( + ctx, + in_features_a=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_a), + in_features_b=ctx.mpc.option_call(lambda: train_data.shape[1], dst=rank_b), + out_features=1, + rank_a=rank_a, + rank_b=rank_b, + wa_init_fn=initialize_func, + wb_init_fn=initialize_func, + ) + loss_fn = SSHELinearRegressionLossLayer(ctx, rank_a=rank_a, rank_b=rank_b) + optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr) + wa = layer.wa + wb = layer.wb + if ctx.is_on_guest: + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=False) + else: + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") + if self.early_stop == "weight_diff": + if ctx.is_on_guest: + self.converge_func.set_pre_weight(wb.get_plain_text(dst=rank_b)) + else: + self.converge_func.set_pre_weight(wa.get_plain_text(dst=rank_a)) + for i, epoch_ctx in ctx.on_iterations.ctxs_range(self.epochs): + epoch_loss = None + logger.info(f"self.optimizer set epoch {i}") + for batch_ctx, batch_data in epoch_ctx.on_batches.ctxs_zip(batch_loader): + h = batch_data.x + y = batch_ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b) + z = layer(h) + loss = loss_fn(z, y) + if i % self.reveal_loss_freq == 0: + if epoch_loss is None: + epoch_loss = loss.get() + else: + epoch_loss += loss.get() + loss.backward() + optimizer.step() + if epoch_loss is not None: + epoch_ctx.metrics.log_loss("linr_loss", epoch_loss.tolist()) + if self.reveal_every_epoch: + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + if ctx.is_on_guest: + if self.early_stop == "weight_diff": + if self.reveal_every_epoch: + wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight) + w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0] + self.converge_func.set_pre_weight(wb_p) + if w_diff < self.tol: + self.is_converged = True + else: + raise ValueError(f"early stop {self.early_stop} is not supported when " + f"reveal_every_epoch is False") + else: + if i % self.reveal_loss_freq == 0: + self.is_converged = self.converge_func.is_converge(epoch_loss) + epoch_ctx.hosts.put("converge_flag", self.is_converged) + 
else: + if self.early_stop == "weight_diff": + if self.reveal_every_epoch: + wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight) + epoch_ctx.guest.put("wa_p_delta", wa_p_delta) + self.converge_func.set_pre_weight(wa_p) + else: + raise ValueError(f"early stop {self.early_stop} is not supported when " + f"reveal_every_epoch is False") + self.is_converged = epoch_ctx.guest.get("converge_flag") + if self.is_converged: + self.end_epoch = i + break + if not self.is_converged: + self.end_epoch = self.epochs + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + if ctx.is_on_host: + self.w = wa_p + else: + self.w = wb_p + + def predict(self, ctx, test_data): + pred_df = test_data.create_frame(with_label=True, with_weight=False) + if ctx.is_on_guest: + if self.init_param.get("fit_intercept"): + test_data["intercept"] = 1.0 + X = test_data.values.as_tensor() + pred = torch.matmul(X, self.w) + for h_pred in ctx.hosts.get("h_pred"): + pred += h_pred + pred_df[predict_tools.PREDICT_SCORE] = pred + predict_result = predict_tools.compute_predict_details(pred_df, task_type=predict_tools.REGRESSION) + return predict_result + else: + X = test_data.values.as_tensor() + output = torch.matmul(X, self.w) + ctx.guest.put("h_pred", output) + + def get_model(self): + param = serialize_param(self.w, self.init_param.get("fit_intercept")) + return { + "param": param, + # "optimizer": self.optimizer.state_dict(), + "end_epoch": self.end_epoch, + "is_converged": self.is_converged, + "fit_intercept": self.init_param.get("fit_intercept"), + "header": self.header, + "lr": self.lr + } + + def restore(self, model): + self.w = deserialize_param(model["param"], model["fit_intercept"]) + self.end_epoch = model["end_epoch"] + self.is_converged = model["is_converged"] + self.header = model["header"] + self.init_param["fit_intercept"] = model["fit_intercept"] + self.lr = model["lr"] + # self.optimizer.load_state_dict(model["optimizer"]) diff --git a/python/fate/ml/mpc/sshe_lr.py b/python/fate/ml/glm/hetero/sshe/sshe_lr.py similarity index 74% rename from python/fate/ml/mpc/sshe_lr.py rename to python/fate/ml/glm/hetero/sshe/sshe_lr.py index 47bdd16591..330da786a7 100644 --- a/python/fate/ml/mpc/sshe_lr.py +++ b/python/fate/ml/glm/hetero/sshe/sshe_lr.py @@ -1,3 +1,18 @@ +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
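The convergence rework in this file (and in its linear-regression twin above) splits the `weight_diff` criterion across parties: each side reveals only its own weight share, takes the L2 norm of its epoch-over-epoch delta, and the guest sums the two norms against `tol`. A minimal sketch of the guest-side arithmetic, with plain tensors standing in for the revealed shares and the federation transfer elided:

    import torch

    def guest_weight_diff_converged(wb_now, wb_prev, host_delta_norm, tol):
        # norm of the guest's local weight movement this epoch
        guest_delta_norm = torch.linalg.norm(wb_now - wb_prev, 2)
        # the host computes its own norm and ships it over via
        # epoch_ctx.guest.put("wa_p_delta", ...); total movement is the sum
        return (guest_delta_norm + host_delta_norm) < tol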
+ import logging import torch @@ -9,11 +24,11 @@ SSHELogisticRegressionLossLayer, SSHEOptimizerSGD, ) +from fate.ml.abc.module import Module, HeteroModule from fate.ml.utils import predict_tools from fate.ml.utils._convergence import converge_func_factory from fate.ml.utils._model_param import get_initialize_func from fate.ml.utils._model_param import serialize_param, deserialize_param -from ..abc.module import Module, HeteroModule logger = logging.getLogger(__name__) @@ -36,6 +51,14 @@ def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_para self.ovr = False self.labels = None + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.estimator.batch_size = batch_size + + def set_epochs(self, epochs): + self.epochs = epochs + self.estimator.epochs = epochs + def fit(self, ctx: Context, train_data: DataFrame, validate_data=None): if ctx.is_on_guest: train_data_binarized_label = train_data.label.get_dummies() @@ -45,6 +68,7 @@ def fit(self, ctx: Context, train_data: DataFrame, validate_data=None): if self.labels is None: self.labels = sorted(labels) else: + self.init_param["fit_intercept"] = False label_count = ctx.guest.get("label_count") if label_count > 2 or self.ovr: logger.info(f"OVR data provided, will train OVR models.") @@ -121,6 +145,7 @@ def get_model(self): "batch_size": self.batch_size, "learning_rate": self.learning_rate, "init_param": self.init_param, + "early_stop": self.early_stop, # "optimizer_param": self.optimizer_param, "labels": self.labels, "ovr": self.ovr, @@ -158,7 +183,7 @@ def from_model(cls, model): reveal_every_epoch=model["meta"]["reveal_every_epoch"], reveal_loss_freq=model["meta"]["reveal_loss_freq"], tol=model["meta"]["tol"], - early_stop=model["meta"]["early_stop"] + early_stop=model["meta"]["early_stop"], ) estimator.restore(d) lr.estimator[int(label)] = estimator @@ -170,7 +195,7 @@ def from_model(cls, model): reveal_every_epoch=model["meta"]["reveal_every_epoch"], reveal_loss_freq=model["meta"]["reveal_loss_freq"], tol=model["meta"]["tol"], - early_stop=model["meta"]["early_stop"] + early_stop=model["meta"]["early_stop"], ) estimator.restore(all_estimator) lr.estimator = estimator @@ -178,25 +203,34 @@ def from_model(cls, model): return lr def predict(self, ctx, test_data) -> DataFrame: - pred_df = test_data.create_frame(with_label=True, with_weight=False) - if self.ovr: - pred_score = test_data.create_frame(with_label=False, with_weight=False) - for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): - estimator = self.estimator[i] - pred = estimator.predict(class_ctx, test_data) - pred_score[str(self.labels[i])] = pred - pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)]) - predict_result = predict_tools.compute_predict_details( - pred_df, task_type=predict_tools.MULTI, classes=self.labels - ) + if ctx.is_on_guest: + pred_df = test_data.create_frame(with_label=True, with_weight=False) + if self.ovr: + pred_score = test_data.create_frame(with_label=False, with_weight=False) + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): + estimator = self.estimator[i] + pred = estimator.predict(class_ctx, test_data) + pred_score[str(self.labels[i])] = pred + pred_df[predict_tools.PREDICT_SCORE] = pred_score.apply_row(lambda v: [list(v)]) + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.MULTI, classes=self.labels + ) + else: + predict_score = self.estimator.predict(ctx, test_data) + 
pred_df[predict_tools.PREDICT_SCORE] = predict_score + predict_result = predict_tools.compute_predict_details( + pred_df, task_type=predict_tools.BINARY, classes=self.labels, threshold=self.threshold + ) + return predict_result else: - predict_score = self.estimator.predict(ctx, test_data) - pred_df[predict_tools.PREDICT_SCORE] = predict_score - predict_result = predict_tools.compute_predict_details( - pred_df, task_type=predict_tools.BINARY, classes=self.labels, threshold=self.threshold - ) + if self.ovr: + for i, class_ctx in ctx.sub_ctx("class").ctxs_range(len(self.labels)): + estimator = self.estimator[i] + estimator.predict(class_ctx, test_data) + else: + self.estimator.predict(ctx, test_data) + - return predict_result class SSHELREstimator(HeteroModule): @@ -223,7 +257,12 @@ def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate=N def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: DataFrame) -> None: rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank - initialize_func = get_initialize_func(**self.init_param) + if ctx.is_on_host: + self.init_param["fit_intercept"] = False + if self.w is None: + initialize_func = get_initialize_func(**self.init_param) + else: + initialize_func = lambda x: self.w if self.init_param.get("fit_intercept"): train_data["intercept"] = 1.0 layer = SSHELogisticRegressionLayer( @@ -240,28 +279,43 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr) wa = layer.wa wb = layer.wb - batch_loader = dataframe.DataLoader( - train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=False) + if ctx.is_on_guest: + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="guest", sync_arbiter=False) + else: + batch_loader = dataframe.DataLoader( + train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host") + if self.early_stop == "weight_diff": + if ctx.is_on_guest: + self.converge_func.set_pre_weight(wb.get_plain_text(dst=rank_b)) + else: + self.converge_func.set_pre_weight(wa.get_plain_text(dst=rank_a)) for i, epoch_ctx in ctx.on_iterations.ctxs_range(self.epochs): - epoch_loss = 0 + epoch_loss = None logger.info(f"self.optimizer set epoch {i}") for batch_ctx, batch_data in epoch_ctx.on_batches.ctxs_zip(batch_loader): h = batch_data.x - y = ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b) + y = batch_ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b) z = layer(h) loss = loss_fn(z, y) if i % self.reveal_loss_freq == 0: - epoch_loss += loss.get() + if epoch_loss is None: + epoch_loss = loss.get() + else: + epoch_loss += loss.get() loss.backward() optimizer.step() + if epoch_loss is not None: + epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist()) if self.reveal_every_epoch: wa_p = wa.get_plain_text(dst=rank_a) wb_p = wb.get_plain_text(dst=rank_b) if ctx.is_on_guest: if self.early_stop == "weight_diff": if self.reveal_every_epoch: - wa_p_delta = self.converge_func.compute_weight_diff(wa_p) - w_diff = ctx.guest.put("wa_p_delta", wa_p_delta) + wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight) + w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0] + self.converge_func.set_pre_weight(wb_p) if w_diff < self.tol: self.is_converged = True else: @@ -269,17 +323,14 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data 
f"reveal_every_epoch is False")
             else:
                 if i % self.reveal_loss_freq == 0:
-                    if epoch_loss is not None:
-                        print(f"epoch {i} loss: {epoch_loss.tolist()}")
-                        epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist())
-                    if self.early_stop != "weight_diff":
-                        self.is_converged = self.converge_func.is_converge(epoch_loss)
+                    self.is_converged = self.converge_func.is_converge(epoch_loss)
                 epoch_ctx.hosts.put("converge_flag", self.is_converged)
         else:
             if self.early_stop == "weight_diff":
                 if self.reveal_every_epoch:
-                    wb_p_delta = self.converge_func.compute_weight_diff(wb_p)
-                    ctx.guest.put("wb_p_delta", wb_p_delta)
+                    wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight)
+                    epoch_ctx.guest.put("wa_p_delta", wa_p_delta)
+                    self.converge_func.set_pre_weight(wa_p)
             self.is_converged = epoch_ctx.guest.get("converge_flag")
             if self.is_converged:
                 self.end_epoch = i
@@ -318,7 +369,8 @@ def get_model(self):
             "end_epoch": self.end_epoch,
             "is_converged": self.is_converged,
             "fit_intercept": self.init_param.get("fit_intercept"),
-            "header": self.header
+            "header": self.header,
+            "lr": self.lr
         }
 
     def restore(self, model):
@@ -327,4 +379,5 @@ def restore(self, model):
         self.is_converged = model["is_converged"]
         self.header = model["header"]
         self.init_param["fit_intercept"] = model["fit_intercept"]
+        self.lr = model["lr"]
         # self.optimizer.load_state_dict(model["optimizer"])
diff --git a/python/fate/ml/mpc/sshe_linr.py b/python/fate/ml/mpc/sshe_linr.py
deleted file mode 100644
index 671f4cab53..0000000000
--- a/python/fate/ml/mpc/sshe_linr.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import logging
-
-
-from fate.arch import Context
-from ..abc.module import Module
-from fate.arch.dataframe import DataFrame
-from fate.arch.protocol.mpc.nn.sshe.linr_layer import (
-    SSHELinearRegressionLayer,
-    SSHELinearRegressionLossLayer,
-    SSHEOptimizerSGD,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class SSHELinearRegression(Module):
-    def __init__(self, lr=0.05):
-        self.lr = lr
-
-    def fit(self, ctx: Context, input_data: DataFrame) -> None:
-        rank_a, rank_b = ctx.hosts[0].rank, ctx.guest.rank
-        y = ctx.mpc.cond_call(lambda: input_data.label.as_tensor(), lambda: None, dst=rank_b)
-        h = input_data.as_tensor()
-        # generator = torch.Generator().manual_seed(0)
-        layer = SSHELinearRegressionLayer(
-            ctx,
-            in_features_a=ctx.mpc.option_call(lambda: h.shape[1], dst=rank_a),
-            in_features_b=ctx.mpc.option_call(lambda: h.shape[1], dst=rank_b),
-            out_features=1,
-            rank_a=rank_a,
-            rank_b=rank_b,
-            # generator=generator,
-        )
-        loss_fn = SSHELinearRegressionLossLayer(ctx, rank_a=rank_a, rank_b=rank_b)
-        optimizer = SSHEOptimizerSGD(ctx, layer.parameters(), lr=self.lr)
-
-        for i in range(20):
-            z = layer(h)
-            loss = loss_fn(z, y)
-            if i % 3 == 0:
-                logger.info(f"loss: {loss.get()}")
-            loss.backward()
-            optimizer.step()
diff --git a/python/fate/ml/utils/_convergence.py b/python/fate/ml/utils/_convergence.py
index 05498ecf52..a2b4658644 100644
--- a/python/fate/ml/utils/_convergence.py
+++ b/python/fate/ml/utils/_convergence.py
@@ -75,8 +75,16 @@ def __init__(self, eps):
         super().__init__(eps=eps)
         self.pre_weight = None
 
-    def is_converge(self, delta_weight, weight=None):
+    def set_pre_weight(self, weight):
+        self.pre_weight = weight
+
+    @staticmethod
+    def compute_weight_diff(delta_weight):
         weight_diff = torch.linalg.norm(delta_weight, 2)
+        return weight_diff
+
+    def is_converge(self, delta_weight, weight=None):
+        weight_diff = self.compute_weight_diff(delta_weight)
         if weight is None:
             # avoid tensor[bool]
             if weight_diff < 
self.eps: diff --git a/python/fate/ml/utils/_model_param.py b/python/fate/ml/utils/_model_param.py index 4aa7f42084..ff1e063e89 100644 --- a/python/fate/ml/utils/_model_param.py +++ b/python/fate/ml/utils/_model_param.py @@ -16,6 +16,28 @@ import torch +def get_initialize_func(**kwargs): + method = kwargs["method"] + random_state = kwargs.get("random_state", None) + + if method == 'zeros': + return lambda shape: torch.zeros(shape) + elif method == 'ones': + return lambda shape: torch.ones(shape) + elif method == 'random': + if random_state is not None: + generator = torch.Generator().manual_seed(random_state) + return lambda shape: torch.randn(shape, generator=generator) + return lambda shape: torch.randn(shape) + elif method == 'random_uniform': + if random_state is not None: + generator = torch.Generator().manual_seed(random_state) + return lambda shape: torch.rand(shape, generator=generator) + return lambda shape: torch.rand(shape) + else: + raise NotImplementedError(f"Unknown initialization method: {method}") + + def initialize_param(coef_len, **kwargs): param_len = coef_len method = kwargs["method"] From 8650f3da0c91a6b1145a3a81bd6a970f69d46019 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Thu, 7 Dec 2023 16:34:25 +0800 Subject: [PATCH 14/42] add sshe lr & sshe linr (#5227) Signed-off-by: Yu Wu --- .../arch/protocol/mpc/primitives/arithmetic.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/fate/arch/protocol/mpc/primitives/arithmetic.py b/python/fate/arch/protocol/mpc/primitives/arithmetic.py index b9f1df95c9..389baed0ba 100644 --- a/python/fate/arch/protocol/mpc/primitives/arithmetic.py +++ b/python/fate/arch/protocol/mpc/primitives/arithmetic.py @@ -5,6 +5,21 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# +# Copyright 2023 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
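The initializer factory added to `_model_param.py` above returns shape-taking closures, so a weight tensor can be materialized lazily once a layer knows its dimensions, and `random_state` makes the random branches reproducible. A quick self-contained check of that contract (pure PyTorch, no FATE runtime needed; the shapes are arbitrary):

    import torch
    from fate.ml.utils._model_param import get_initialize_func

    init_fn = get_initialize_func(method="random_uniform", random_state=42)
    w0 = init_fn((4, 1))                     # values drawn from U[0, 1)
    assert w0.shape == (4, 1) and torch.all((w0 >= 0) & (w0 < 1))
    # a fresh factory seeded identically reproduces the same tensor
    assert torch.equal(w0, get_initialize_func(method="random_uniform",
                                               random_state=42)((4, 1)))
    assert torch.count_nonzero(get_initialize_func(method="zeros")((4, 1))) == 0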
+ import logging # dependencies: From 74b6edb739391b49c4958aadbb2f17dd152b8cd9 Mon Sep 17 00:00:00 2001 From: sagewe Date: Thu, 7 Dec 2023 17:33:00 +0800 Subject: [PATCH 15/42] add serdes Signed-off-by: sagewe --- python/fate/arch/computing/serdes/__init__.py | 31 +++++++ .../arch/computing/serdes/_integer_serdes.py | 13 +++ .../serdes/_restricted_caught_miss_serdes.py | 44 ++++++++++ .../computing/serdes/_restricted_serdes.py | 82 +++++++++++++++++++ .../computing/serdes/_unrestricted_serdes.py | 19 +++++ python/fate/arch/computing/table.py | 2 +- python/fate/arch/config/_config.py | 5 ++ python/fate/arch/context/_context.py | 1 + python/fate/arch/unify/serdes.py | 40 --------- 9 files changed, 196 insertions(+), 41 deletions(-) create mode 100644 python/fate/arch/computing/serdes/__init__.py create mode 100644 python/fate/arch/computing/serdes/_integer_serdes.py create mode 100644 python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py create mode 100644 python/fate/arch/computing/serdes/_restricted_serdes.py create mode 100644 python/fate/arch/computing/serdes/_unrestricted_serdes.py delete mode 100644 python/fate/arch/unify/serdes.py diff --git a/python/fate/arch/computing/serdes/__init__.py b/python/fate/arch/computing/serdes/__init__.py new file mode 100644 index 0000000000..978a69a307 --- /dev/null +++ b/python/fate/arch/computing/serdes/__init__.py @@ -0,0 +1,31 @@ +from fate.arch.config import cfg + + +def get_serdes_by_type(serdes_type: int): + if serdes_type == 0: + if cfg.safety.serdes.restricted_type == "unrestricted": + from ._unrestricted_serdes import get_unrestricted_serdes + + return get_unrestricted_serdes() + elif cfg.safety.serdes.restricted_type == "restricted": + from ._restricted_serdes import get_restricted_serdes + + return get_restricted_serdes() + elif cfg.safety.serdes.restricted_type == "restricted_catch_miss": + from ._restricted_caught_miss_serdes import get_restricted_catch_miss_serdes + + return get_restricted_catch_miss_serdes() + else: + raise ValueError(f"restricted type `{cfg.safety.serdes.restricted_type}` not supported") + elif serdes_type == 1: + from ._integer_serdes import get_integer_serdes + + return get_integer_serdes() + else: + raise ValueError(f"serdes type `{serdes_type}` not supported") + + +def dump_miss(path): + from ._restricted_caught_miss_serdes import dump_miss + + dump_miss(path) diff --git a/python/fate/arch/computing/serdes/_integer_serdes.py b/python/fate/arch/computing/serdes/_integer_serdes.py new file mode 100644 index 0000000000..0df5880012 --- /dev/null +++ b/python/fate/arch/computing/serdes/_integer_serdes.py @@ -0,0 +1,13 @@ +def get_integer_serdes(): + return IntegerSerdes() + + +class IntegerSerdes: + def __init__(self): + ... 
+
+    def serialize(self, obj) -> bytes:
+        return obj.to_bytes(8, "big")
+
+    def deserialize(self, bytes) -> object:
+        return int.from_bytes(bytes, "big")
diff --git a/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py b/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
new file mode 100644
index 0000000000..c3604399f4
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_restricted_caught_miss_serdes.py
@@ -0,0 +1,44 @@
+import io
+import pickle
+
+from ruamel import yaml
+
+from ._restricted_serdes import RestrictedUnpickler
+
+
+def get_restricted_catch_miss_serdes():
+    return WhitelistCatchRestrictedSerdes
+
+
+class WhitelistCatchRestrictedSerdes:
+    @classmethod
+    def serialize(cls, obj) -> bytes:
+        return pickle.dumps(obj)
+
+    @classmethod
+    def deserialize(cls, bytes) -> object:
+        return RestrictedCatchUnpickler(io.BytesIO(bytes)).load()
+
+
+class RestrictedCatchUnpickler(RestrictedUnpickler):
+    caught_miss = {}
+
+    def find_class(self, module, name):
+        try:
+            return super().find_class(module, name)
+        except pickle.UnpicklingError:
+            # record each missing (module, name) pair once, then load anyway
+            if name not in self.caught_miss.get(module, set()):
+                self.caught_miss.setdefault(module, set()).add(name)
+            return self._load(module, name)
+
+    @classmethod
+    def dump_miss(cls, path):
+        with open(path, "w") as f:
+            yaml.dump({module: list(names) for module, names in cls.caught_miss.items()}, f)
+
+
+def dump_miss(path):
+    RestrictedCatchUnpickler.dump_miss(path)
+
diff --git a/python/fate/arch/computing/serdes/_restricted_serdes.py b/python/fate/arch/computing/serdes/_restricted_serdes.py
new file mode 100644
index 0000000000..0c1838fc40
--- /dev/null
+++ b/python/fate/arch/computing/serdes/_restricted_serdes.py
@@ -0,0 +1,82 @@
+import importlib
+import io
+import pickle
+
+from ruamel import yaml
+
+
+def get_restricted_serdes():
+    return WhitelistRestrictedSerdes
+
+
+class WhitelistRestrictedSerdes:
+    @classmethod
+    def serialize(cls, obj) -> bytes:
+        return pickle.dumps(obj)
+
+    @classmethod
+    def deserialize(cls, bytes) -> object:
+        return RestrictedUnpickler(io.BytesIO(bytes)).load()
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    def _load(self, module, name):
+        try:
+            return super().find_class(module, name)
+        except Exception:
+            return getattr(importlib.import_module(module), name)
+
+    def find_class(self, module, name):
+        if name in Whitelist.get_whitelist().get(module, set()):
+            return self._load(module, name)
+        else:
+            for m in Whitelist.get_whitelist_glob():
+                if module.startswith(m):
+                    return self._load(module, name)
+            raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")
+
+
+class Whitelist:
+    loaded = False
+    deserialize_whitelist = {}
+    deserialize_glob_whitelist = set()
+
+    @classmethod
+    def get_whitelist_glob(cls):
+        if not cls.loaded:
+            cls.load_deserialize_whitelist()
+        return cls.deserialize_glob_whitelist
+
+    @classmethod
+    def get_whitelist(cls):
+        if not cls.loaded:
+            cls.load_deserialize_whitelist()
+        return cls.deserialize_whitelist
+
+    @classmethod
+    def get_whitelist_path(cls):
+        import os.path
+
+        return os.path.abspath(
+            os.path.join(
+                __file__,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                os.path.pardir,
+                "configs",
+                "whitelist.yaml",
+            )
+        )
+
+    @classmethod
+    def load_deserialize_whitelist(cls):
+        with open(cls.get_whitelist_path()) as f:
+            for k, v in yaml.load(f, Loader=yaml.SafeLoader).items():
+                if k.endswith("*"):
+                    cls.deserialize_glob_whitelist.add(k[:-1])
+                else:
+                    
cls.deserialize_whitelist[k] = set(v) + cls.loaded = True diff --git a/python/fate/arch/computing/serdes/_unrestricted_serdes.py b/python/fate/arch/computing/serdes/_unrestricted_serdes.py new file mode 100644 index 0000000000..595e57cebb --- /dev/null +++ b/python/fate/arch/computing/serdes/_unrestricted_serdes.py @@ -0,0 +1,19 @@ +import os +import pickle + + +def get_unrestricted_serdes(): + if True or os.environ.get("SERDES_DEBUG_MODE") == "1": + return UnrestrictedSerdes + else: + raise PermissionError("UnsafeSerdes is not allowed in production mode") + + +class UnrestrictedSerdes: + @staticmethod + def serialize(obj) -> bytes: + return pickle.dumps(obj) + + @staticmethod + def deserialize(bytes) -> object: + return pickle.loads(bytes) diff --git a/python/fate/arch/computing/table.py b/python/fate/arch/computing/table.py index ad8b124c19..36e8f68628 100644 --- a/python/fate/arch/computing/table.py +++ b/python/fate/arch/computing/table.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Tuple, Iterable, Generic, TypeVar, Optional from fate.arch.unify.partitioner import get_partitioner_by_type -from fate.arch.unify.serdes import get_serdes_by_type +from fate.arch.computing.serdes import get_serdes_by_type from fate.arch.utils.trace import auto_trace from ..unify import URI import functools diff --git a/python/fate/arch/config/_config.py b/python/fate/arch/config/_config.py index 4381ba8fa2..aa5aa63502 100644 --- a/python/fate/arch/config/_config.py +++ b/python/fate/arch/config/_config.py @@ -66,3 +66,8 @@ def get_option(self, options, key, default=...): raise ValueError(f"{key} not in {options} or {self.config}") else: return default + + @property + def safety(self): + return self.config.safety + diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 5e5c16b191..644dc78073 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -23,6 +23,7 @@ from ._namespace import NS, default_ns from ..unify import device from fate.arch.utils.trace import auto_trace +from fate.arch.config import cfg logger = logging.getLogger(__name__) diff --git a/python/fate/arch/unify/serdes.py b/python/fate/arch/unify/serdes.py deleted file mode 100644 index 9833946ab4..0000000000 --- a/python/fate/arch/unify/serdes.py +++ /dev/null @@ -1,40 +0,0 @@ -import pickle -import os - - -class UnsafeSerdes: - def __init__(self): - ... - - def serialize(self, obj) -> bytes: - return pickle.dumps(obj) - - def deserialize(self, bytes) -> object: - return pickle.loads(bytes) - - -class IntegerSerdes: - def __init__(self): - ... 
- - def serialize(self, obj) -> bytes: - return obj.to_bytes(8, "big") - - def deserialize(self, bytes) -> object: - return int.from_bytes(bytes, "big") - - -def get_unsafe_serdes(): - if True or os.environ.get("SERDES_DEBUG_MODE") == "1": - return UnsafeSerdes() - else: - raise PermissionError("UnsafeSerdes is not allowed in production mode") - - -def get_serdes_by_type(serdes_type: int): - if serdes_type == 0: - return get_unsafe_serdes() - elif serdes_type == 1: - return IntegerSerdes() - else: - raise ValueError(f"serdes type `{serdes_type}` not supported") From 35c091db90845280cc05beaa1716a427557aa0be Mon Sep 17 00:00:00 2001 From: sagewe Date: Thu, 7 Dec 2023 21:31:18 +0800 Subject: [PATCH 16/42] add serdes Signed-off-by: sagewe --- configs/default.yaml | 5 ++ configs/whitelist.yaml | 1 + .../arch/computing/serdes/_safe_serdes.py | 83 +++++++++++++++++++ python/fate/arch/context/_federation.py | 7 ++ 4 files changed, 96 insertions(+) create mode 100644 configs/whitelist.yaml create mode 100644 python/fate/arch/computing/serdes/_safe_serdes.py diff --git a/configs/default.yaml b/configs/default.yaml index 3f8b37f881..438828fef2 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -45,3 +45,8 @@ nn: protocol: "layer_estimation" skip_loss_forward: True cache_pred_size: True + +safety: + serdes: + # supported types: unrestricted, restricted, restricted_catch_miss + restricted_type: "unrestricted" \ No newline at end of file diff --git a/configs/whitelist.yaml b/configs/whitelist.yaml new file mode 100644 index 0000000000..e5873408d4 --- /dev/null +++ b/configs/whitelist.yaml @@ -0,0 +1 @@ +fate: "*" \ No newline at end of file diff --git a/python/fate/arch/computing/serdes/_safe_serdes.py b/python/fate/arch/computing/serdes/_safe_serdes.py new file mode 100644 index 0000000000..fd1059e9c9 --- /dev/null +++ b/python/fate/arch/computing/serdes/_safe_serdes.py @@ -0,0 +1,83 @@ +import enum +import struct +from functools import singledispatch + + +class SerdeObjectTypes(enum.IntEnum): + INT = 0 + FLOAT = 1 + STRING = 2 + BYTES = 3 + LIST = 4 + DICT = 5 + TUPLE = 6 + + +_deserializer_registry = {} + + +def _register_deserializer(obj_type_enum): + def _register(deserializer_func): + _deserializer_registry[obj_type_enum] = deserializer_func + return deserializer_func + + return _register + + +def _dispatch_deserializer(obj_type_enum): + return _deserializer_registry[obj_type_enum] + + +class SafeSerdes(object): + @staticmethod + def serialize(obj): + obj_type, obj_bytes = serialize_obj(obj) + return struct.pack("!h", obj_type) + obj_bytes + + @staticmethod + def deserialize(raw_bytes): + (obj_type,) = struct.unpack("!h", raw_bytes[:2]) + return _dispatch_deserializer(obj_type)(raw_bytes[2:]) + + +@singledispatch +def serialize_obj(obj): + raise NotImplementedError("Unsupported type: {}".format(type(obj))) + + +@serialize_obj.register(int) +def _(obj): + return SerdeObjectTypes.INT, struct.pack("!q", obj) + + +@_register_deserializer(SerdeObjectTypes.INT) +def _(raw_bytes): + return struct.unpack("!q", raw_bytes)[0] + + +@serialize_obj.register(float) +def _(obj): + return SerdeObjectTypes.FLOAT, struct.pack("!d", obj) + + +@_register_deserializer(SerdeObjectTypes.FLOAT) +def _(raw_bytes): + return struct.unpack("!d", raw_bytes)[0] + + +@serialize_obj.register(str) +def _(obj): + utf8_str = obj.encode("utf-8") + return SerdeObjectTypes.STRING, struct.pack("!I", len(utf8_str)) + utf8_str + + +@_register_deserializer(SerdeObjectTypes.STRING) +def _(raw_bytes): + length = 
struct.unpack("!I", raw_bytes[:4])[0] + return raw_bytes[4 : 4 + length].decode("utf-8") + + +if __name__ == "__main__": + print(SafeSerdes.deserialize(SafeSerdes.serialize(1))) + print(SafeSerdes.deserialize(SafeSerdes.serialize(1.0))) + print(SafeSerdes.deserialize(SafeSerdes.serialize("hello"))) diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index d4cd2706be..c2e3189859 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -14,6 +14,7 @@ # limitations under the License. import io import pickle +import logging import struct import typing from typing import Any, List, Tuple, TypeVar, Union @@ -23,6 +24,7 @@ from ..computing import is_table from ..federation._gc import IterationGC +logger = logging.getLogger(__name__) T = TypeVar("T") if typing.TYPE_CHECKING: @@ -308,6 +310,11 @@ def persistent_load(self, pid: Any) -> Any: if isinstance(pid, _ContextPersistentId): return self._ctx + # def load(self): + # out = super().load() + # logger.error(f"unpickled: {out.__class__.__module__}.{out.__class__.__name__}") + # return out + @classmethod def pull( cls, From 71e5aded76481e06c400491f987a7da0ddd7b163 Mon Sep 17 00:00:00 2001 From: sagewe Date: Fri, 8 Dec 2023 18:17:32 +0800 Subject: [PATCH 17/42] add unresolved data type Signed-off-by: sagewe --- python/fate/components/components/reader.py | 38 +++---------------- python/fate/components/core/_cpn_reexport.py | 4 ++ .../core/component_desc/__init__.py | 4 ++ .../core/component_desc/artifacts/__init__.py | 4 ++ .../component_desc/artifacts/data/__init__.py | 11 +++++- .../artifacts/data/_unresolved.py | 38 +++++++++++++++++++ .../components/core/essential/__init__.py | 1 + .../core/essential/_artifact_type.py | 6 +++ 8 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 python/fate/components/core/component_desc/artifacts/data/_unresolved.py diff --git a/python/fate/components/components/reader.py b/python/fate/components/components/reader.py index c35e377ffd..d686f8f198 100644 --- a/python/fate/components/components/reader.py +++ b/python/fate/components/components/reader.py @@ -12,41 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
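For reference, the `SafeSerdes` wire format introduced in the previous patch is easy to verify by hand: every frame is a 2-byte big-endian type tag followed by a type-specific payload (a fixed 8 bytes for ints, 8 bytes for floats, and a 4-byte length prefix plus UTF-8 bytes for strings). A toy round-trip assuming only the enum values shown in that file:

    import struct

    payload = "hello".encode("utf-8")
    frame = struct.pack("!h", 2) + struct.pack("!I", len(payload)) + payload
    # b"\x00\x02" tag (STRING == 2), 4-byte length, then the UTF-8 bytes
    assert frame[:2] == b"\x00\x02" and len(frame) == 2 + 4 + len(payload)
    (tag,) = struct.unpack("!h", frame[:2])
    (length,) = struct.unpack("!I", frame[2:6])
    assert tag == 2 and frame[6 : 6 + length].decode("utf-8") == "hello"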
-from fate.components.core import GUEST, HOST, Role, cpn +from fate.components.core import GUEST, HOST, ARBITER, Role, cpn -@cpn.component(roles=[GUEST, HOST]) +@cpn.component(roles=[GUEST, HOST, ARBITER]) def reader( ctx, role: Role, - path: cpn.parameter(type=str, default=None, optional=False), - format: cpn.parameter(type=str, default="csv", optional=False), - sample_id_name: cpn.parameter(type=str, default=None, optional=True), - match_id_name: cpn.parameter(type=str, default=None, optional=True), - delimiter: cpn.parameter(type=str, default=",", optional=True), - label_name: cpn.parameter(type=str, default=None, optional=True), - label_type: cpn.parameter(type=str, default="float32", optional=True), - dtype: cpn.parameter(type=str, default="float32", optional=True), - output_data: cpn.dataframe_output(roles=[GUEST, HOST]), + name: cpn.parameter(type=str, default=None, optional=False), + namespace: cpn.parameter(type=str, default=None, optional=False), + data_output: cpn.data_unresolved_output(), ): - if format == "csv": - data_meta = DataframeArtifact( - uri=path, - name="data", - metadata=dict( - format=format, - sample_id_name=sample_id_name, - match_id_name=match_id_name, - delimiter=delimiter, - label_name=label_name, - label_type=label_type, - dtype=dtype, - ), - ) - elif format == "raw_table": - data_meta = DataframeArtifact(uri=path, name="data", metadata=dict(format=format)) - else: - raise ValueError(f"Reader does not support format={format}") - - data = ctx.reader(data_meta).read_dataframe() - ctx.writer(output_data).write_dataframe(data) + data_output.write_metadata({}, name=name, namespace=namespace) diff --git a/python/fate/components/core/_cpn_reexport.py b/python/fate/components/core/_cpn_reexport.py index 6ed3882141..e40495246f 100644 --- a/python/fate/components/core/_cpn_reexport.py +++ b/python/fate/components/core/_cpn_reexport.py @@ -25,6 +25,8 @@ dataframe_inputs, dataframe_output, dataframe_outputs, + data_unresolved_output, + data_unresolved_outputs, json_model_input, json_model_inputs, json_model_output, @@ -63,6 +65,8 @@ def wrapper(roles: Optional[List[Role]] = None, desc="", optional=False) -> "Typ "data_directory_output", "data_directory_outputs", "data_directory_inputs", + "data_unresolved_output", + "data_unresolved_outputs", "json_model_output", "json_model_outputs", "json_model_input", diff --git a/python/fate/components/core/component_desc/__init__.py b/python/fate/components/core/component_desc/__init__.py index c1cec6a28c..f56d90b5dd 100644 --- a/python/fate/components/core/component_desc/__init__.py +++ b/python/fate/components/core/component_desc/__init__.py @@ -22,6 +22,8 @@ data_directory_inputs, data_directory_output, data_directory_outputs, + data_unresolved_output, + data_unresolved_outputs, dataframe_input, dataframe_inputs, dataframe_output, @@ -55,6 +57,8 @@ "data_directory_output", "data_directory_outputs", "data_directory_inputs", + "data_unresolved_output", + "data_unresolved_outputs", "json_model_output", "json_model_outputs", "json_model_input", diff --git a/python/fate/components/core/component_desc/artifacts/__init__.py b/python/fate/components/core/component_desc/artifacts/__init__.py index 0e3ff1f9c6..0d50344dc0 100644 --- a/python/fate/components/core/component_desc/artifacts/__init__.py +++ b/python/fate/components/core/component_desc/artifacts/__init__.py @@ -16,6 +16,8 @@ dataframe_outputs, table_input, table_inputs, + data_unresolved_output, + data_unresolved_outputs, ) from .metric import json_metric_output, 
json_metric_outputs from .model import ( @@ -50,6 +52,8 @@ "data_directory_inputs", "data_directory_output", "data_directory_outputs", + "data_unresolved_output", + "data_unresolved_outputs", "json_metric_output", "json_metric_outputs", ] diff --git a/python/fate/components/core/component_desc/artifacts/data/__init__.py b/python/fate/components/core/component_desc/artifacts/data/__init__.py index 9cf11ce444..713aa130c0 100644 --- a/python/fate/components/core/component_desc/artifacts/data/__init__.py +++ b/python/fate/components/core/component_desc/artifacts/data/__init__.py @@ -1,6 +1,5 @@ from typing import Iterator, List, Optional, Type -from .._base_type import Role, _create_artifact_annotation from ._dataframe import DataframeArtifactDescribe, DataframeReader, DataframeWriter from ._directory import ( DataDirectoryArtifactDescribe, @@ -8,6 +7,8 @@ DataDirectoryWriter, ) from ._table import TableArtifactDescribe, TableReader, TableWriter +from ._unresolved import DataUnresolvedArtifactDescribe, DataUnresolvedReader, DataUnresolvedWriter +from .._base_type import Role, _create_artifact_annotation def dataframe_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataframeReader]: @@ -60,3 +61,11 @@ def data_directory_outputs( roles: Optional[List[Role]] = None, desc="", optional=False ) -> Type[Iterator[DataDirectoryWriter]]: return _create_artifact_annotation(False, True, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) + +def data_unresolved_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataUnresolvedWriter]: + return _create_artifact_annotation(False, False, DataUnresolvedArtifactDescribe, "data")(roles, desc, optional) + +def data_unresolved_outputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[Iterator[DataUnresolvedWriter]]: + return _create_artifact_annotation(False, True, DataUnresolvedArtifactDescribe, "data")(roles, desc, optional) diff --git a/python/fate/components/core/component_desc/artifacts/data/_unresolved.py b/python/fate/components/core/component_desc/artifacts/data/_unresolved.py new file mode 100644 index 0000000000..23ca24f9b0 --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/data/_unresolved.py @@ -0,0 +1,38 @@ +from pathlib import Path + +from fate.components.core.essential import DataUnresolvedArtifactType +from .._base_type import ( + URI, + ArtifactDescribe, + DataOutputMetadata, + Metadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + + +class DataUnresolvedWriter(_ArtifactTypeWriter[DataUnresolvedArtifactType]): + def write_metadata(self, metadata: dict, name=None, namespace=None): + self.artifact.metadata.metadata.update(metadata) + if name is not None: + self.artifact.metadata.name = name + if namespace is not None: + self.artifact.metadata.namespace = namespace + + +class DataUnresolvedReader(_ArtifactTypeReader): + def get_metadata(self): + return self.artifact.metadata.metadata + + +class DataUnresolvedArtifactDescribe(ArtifactDescribe[DataUnresolvedArtifactType, DataOutputMetadata]): + @classmethod + def get_type(cls): + return DataUnresolvedArtifactType + + def get_writer(self, config, ctx, uri: URI, type_name: str) -> DataUnresolvedWriter: + return DataUnresolvedWriter(ctx, _ArtifactType(uri=uri, metadata=DataOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx, uri: "URI", metadata: "Metadata", type_name: str) -> DataUnresolvedReader: + return DataUnresolvedReader(ctx, _ArtifactType(uri=uri, 
metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/essential/__init__.py b/python/fate/components/core/essential/__init__.py index 5918878659..3ae90da3b1 100644 --- a/python/fate/components/core/essential/__init__.py +++ b/python/fate/components/core/essential/__init__.py @@ -1,6 +1,7 @@ from ._artifact_type import ( ArtifactType, DataDirectoryArtifactType, + DataUnresolvedArtifactType, DataframeArtifactType, JsonMetricArtifactType, JsonModelArtifactType, diff --git a/python/fate/components/core/essential/_artifact_type.py b/python/fate/components/core/essential/_artifact_type.py index 488032d476..28cada731a 100644 --- a/python/fate/components/core/essential/_artifact_type.py +++ b/python/fate/components/core/essential/_artifact_type.py @@ -25,6 +25,12 @@ class DataDirectoryArtifactType(ArtifactType): uri_types = ["file"] +class DataUnresolvedArtifactType(ArtifactType): + type_name = "data_unresolved" + path_type = "unresolved" + uri_types = ["unresolved"] + + class ModelDirectoryArtifactType(ArtifactType): type_name = "model_directory" path_type = "directory" From 1e47e79aef1af629a51f0039d1571920d201079f Mon Sep 17 00:00:00 2001 From: sagewe Date: Fri, 8 Dec 2023 22:43:54 +0800 Subject: [PATCH 18/42] add unresolved model type Signed-off-by: sagewe --- python/fate/components/core/_cpn_reexport.py | 4 +++ .../fate/components/core/_load_federation.py | 8 +++-- .../core/component_desc/__init__.py | 4 +++ .../core/component_desc/_component.py | 4 --- .../core/component_desc/_component_io.py | 1 - .../core/component_desc/artifacts/__init__.py | 4 +++ .../component_desc/artifacts/data/__init__.py | 2 ++ .../artifacts/model/__init__.py | 15 +++++++- .../artifacts/model/_unresolved.py | 36 +++++++++++++++++++ .../components/core/essential/__init__.py | 1 + .../core/essential/_artifact_type.py | 6 ++++ python/fate/components/core/spec/model.py | 3 -- 12 files changed, 77 insertions(+), 11 deletions(-) create mode 100644 python/fate/components/core/component_desc/artifacts/model/_unresolved.py diff --git a/python/fate/components/core/_cpn_reexport.py b/python/fate/components/core/_cpn_reexport.py index e40495246f..0cebe00d7c 100644 --- a/python/fate/components/core/_cpn_reexport.py +++ b/python/fate/components/core/_cpn_reexport.py @@ -38,6 +38,8 @@ parameter, table_input, table_inputs, + model_unresolved_output, + model_unresolved_outputs, ) from .essential import Role @@ -75,4 +77,6 @@ def wrapper(roles: Optional[List[Role]] = None, desc="", optional=False) -> "Typ "model_directory_outputs", "model_directory_output", "model_directory_input", + "model_unresolved_output", + "model_unresolved_outputs", ] diff --git a/python/fate/components/core/_load_federation.py b/python/fate/components/core/_load_federation.py index f98fabe959..06d7a2439e 100644 --- a/python/fate/components/core/_load_federation.py +++ b/python/fate/components/core/_load_federation.py @@ -36,14 +36,18 @@ def load_federation(federation, computing): if isinstance(federation, (OSXFederationSpec, RollSiteFederationSpec)): if isinstance(federation, OSXFederationSpec): mode = FederationMode.from_str(federation.metadata.osx_config.mode) + host = federation.metadata.osx_config.host + port = federation.metadata.osx_config.port options = dict(max_message_size=federation.metadata.osx_config.max_message_size) else: mode = FederationMode.STREAM + host = federation.metadata.rollsite_config.host + port = federation.metadata.rollsite_config.port options = {} return builder.build_osx( computing_session=computing, - 
host=federation.metadata.osx_config.host, - port=federation.metadata.osx_config.port, + host=host, + port=port, mode=mode, options=options, ) diff --git a/python/fate/components/core/component_desc/__init__.py b/python/fate/components/core/component_desc/__init__.py index f56d90b5dd..a6bc15d9f2 100644 --- a/python/fate/components/core/component_desc/__init__.py +++ b/python/fate/components/core/component_desc/__init__.py @@ -40,6 +40,8 @@ model_directory_outputs, table_input, table_inputs, + model_unresolved_output, + model_unresolved_outputs, ) __all__ = [ @@ -69,4 +71,6 @@ "model_directory_input", "json_metric_output", "json_metric_outputs", + "model_unresolved_output", + "model_unresolved_outputs", ] diff --git a/python/fate/components/core/component_desc/_component.py b/python/fate/components/core/component_desc/_component.py index 0651339757..0048cad749 100644 --- a/python/fate/components/core/component_desc/_component.py +++ b/python/fate/components/core/component_desc/_component.py @@ -232,7 +232,6 @@ def dump_yaml(self, stream=None): def predict( self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None ): - if roles is None: roles = [] @@ -241,7 +240,6 @@ def predict( def train( self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None ): - if roles is None: roles = [] @@ -250,7 +248,6 @@ def train( def cross_validation( self, roles: List = None, provider: Optional[str] = None, version: Optional[str] = None, description=None ): - if roles is None: roles = [] @@ -326,7 +323,6 @@ def component( def _component(name, roles, provider, version, description, is_subcomponent): def decorator(f): - cpn_name = name or f.__name__.lower() if isinstance(f, Component): raise TypeError("Attempted to convert a callback into a component_desc twice.") diff --git a/python/fate/components/core/component_desc/_component_io.py b/python/fate/components/core/component_desc/_component_io.py index 4c614fe34c..3f62e4d5ca 100644 --- a/python/fate/components/core/component_desc/_component_io.py +++ b/python/fate/components/core/component_desc/_component_io.py @@ -134,7 +134,6 @@ def _handle_output(self, ctx, component, arg, stage, role, config): (self.output_model, component.artifacts.model_outputs), (self.output_metric, component.artifacts.metric_outputs), ]: - if allowed_artifacts := artifacts.get(arg): if allowed_artifacts.is_active_for(stage, role): apply_spec: ArtifactOutputApplySpec = config.output_artifacts.get(arg) diff --git a/python/fate/components/core/component_desc/artifacts/__init__.py b/python/fate/components/core/component_desc/artifacts/__init__.py index 0d50344dc0..81f87b39fd 100644 --- a/python/fate/components/core/component_desc/artifacts/__init__.py +++ b/python/fate/components/core/component_desc/artifacts/__init__.py @@ -29,6 +29,8 @@ model_directory_inputs, model_directory_output, model_directory_outputs, + model_unresolved_output, + model_unresolved_outputs, ) __all__ = [ @@ -56,4 +58,6 @@ "data_unresolved_outputs", "json_metric_output", "json_metric_outputs", + "model_unresolved_output", + "model_unresolved_outputs", ] diff --git a/python/fate/components/core/component_desc/artifacts/data/__init__.py b/python/fate/components/core/component_desc/artifacts/data/__init__.py index 713aa130c0..80c08408ca 100644 --- a/python/fate/components/core/component_desc/artifacts/data/__init__.py +++ b/python/fate/components/core/component_desc/artifacts/data/__init__.py @@ -62,9 +62,11 @@ def 
data_directory_outputs( ) -> Type[Iterator[DataDirectoryWriter]]: return _create_artifact_annotation(False, True, DataDirectoryArtifactDescribe, "data")(roles, desc, optional) + def data_unresolved_output(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[DataUnresolvedWriter]: return _create_artifact_annotation(False, False, DataUnresolvedArtifactDescribe, "data")(roles, desc, optional) + def data_unresolved_outputs( roles: Optional[List[Role]] = None, desc="", optional=False ) -> Type[Iterator[DataUnresolvedWriter]]: diff --git a/python/fate/components/core/component_desc/artifacts/model/__init__.py b/python/fate/components/core/component_desc/artifacts/model/__init__.py index f7ac839f39..da92a23ffe 100644 --- a/python/fate/components/core/component_desc/artifacts/model/__init__.py +++ b/python/fate/components/core/component_desc/artifacts/model/__init__.py @@ -1,12 +1,13 @@ from typing import Iterator, List, Optional, Type -from .._base_type import Role, _create_artifact_annotation from ._directory import ( ModelDirectoryArtifactDescribe, ModelDirectoryReader, ModelDirectoryWriter, ) from ._json import JsonModelArtifactDescribe, JsonModelReader, JsonModelWriter +from ._unresolved import ModelUnresolvedArtifactDescribe, ModelUnresolvedReader, ModelUnresolvedWriter +from .._base_type import Role, _create_artifact_annotation def json_model_input(roles: Optional[List[Role]] = None, desc="", optional=False) -> Type[JsonModelReader]: @@ -43,3 +44,15 @@ def model_directory_outputs( roles: Optional[List[Role]] = None, desc="", optional=False ) -> Type[Iterator[ModelDirectoryWriter]]: return _create_artifact_annotation(False, True, ModelDirectoryArtifactDescribe, "model")(roles, desc, optional) + + +def model_unresolved_output( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[ModelUnresolvedWriter]: + return _create_artifact_annotation(False, False, ModelUnresolvedArtifactDescribe, "model")(roles, desc, optional) + + +def model_unresolved_outputs( + roles: Optional[List[Role]] = None, desc="", optional=False +) -> Type[Iterator[ModelUnresolvedWriter]]: + return _create_artifact_annotation(False, True, ModelUnresolvedArtifactDescribe, "model")(roles, desc, optional) diff --git a/python/fate/components/core/component_desc/artifacts/model/_unresolved.py b/python/fate/components/core/component_desc/artifacts/model/_unresolved.py new file mode 100644 index 0000000000..bc9bfea9ec --- /dev/null +++ b/python/fate/components/core/component_desc/artifacts/model/_unresolved.py @@ -0,0 +1,36 @@ +from fate.components.core.essential import ModelUnresolvedArtifactType +from .._base_type import ( + URI, + ArtifactDescribe, + ModelOutputMetadata, + Metadata, + _ArtifactType, + _ArtifactTypeReader, + _ArtifactTypeWriter, +) + + +class ModelUnresolvedWriter(_ArtifactTypeWriter[ModelUnresolvedArtifactType]): + def write_metadata(self, metadata: dict, name=None, namespace=None): + self.artifact.metadata.metadata.update(metadata) + if name is not None: + self.artifact.metadata.name = name + if namespace is not None: + self.artifact.metadata.namespace = namespace + + +class ModelUnresolvedReader(_ArtifactTypeReader): + def get_metadata(self): + return self.artifact.metadata.metadata + + +class ModelUnresolvedArtifactDescribe(ArtifactDescribe[ModelUnresolvedArtifactType, ModelOutputMetadata]): + @classmethod + def get_type(cls): + return ModelUnresolvedArtifactType + + def get_writer(self, config, ctx, uri: URI, type_name: str) -> ModelUnresolvedWriter: + return 
ModelUnresolvedWriter(ctx, _ArtifactType(uri=uri, metadata=ModelOutputMetadata(), type_name=type_name)) + + def get_reader(self, ctx, uri: "URI", metadata: "Metadata", type_name: str) -> ModelUnresolvedReader: + return ModelUnresolvedReader(ctx, _ArtifactType(uri=uri, metadata=metadata, type_name=type_name)) diff --git a/python/fate/components/core/essential/__init__.py b/python/fate/components/core/essential/__init__.py index 3ae90da3b1..68b30ca087 100644 --- a/python/fate/components/core/essential/__init__.py +++ b/python/fate/components/core/essential/__init__.py @@ -6,6 +6,7 @@ JsonMetricArtifactType, JsonModelArtifactType, ModelDirectoryArtifactType, + ModelUnresolvedArtifactType, TableArtifactType, ) from ._label import Label diff --git a/python/fate/components/core/essential/_artifact_type.py b/python/fate/components/core/essential/_artifact_type.py index 28cada731a..eea7cbb108 100644 --- a/python/fate/components/core/essential/_artifact_type.py +++ b/python/fate/components/core/essential/_artifact_type.py @@ -47,3 +47,9 @@ class JsonMetricArtifactType(ArtifactType): type_name = "json_metric" path_type = "file" uri_types = ["file"] + + +class ModelUnresolvedArtifactType(ArtifactType): + type_name = "model_unresolved" + path_type = "unresolved" + uri_types = ["unresolved"] diff --git a/python/fate/components/core/spec/model.py b/python/fate/components/core/spec/model.py index e60cc28265..c74dc6e569 100644 --- a/python/fate/components/core/spec/model.py +++ b/python/fate/components/core/spec/model.py @@ -32,7 +32,6 @@ class MLModelPartiesSpec(pydantic.BaseModel): class MLModelFederatedSpec(pydantic.BaseModel): - task_id: str parties: MLModelPartiesSpec component: MLModelComponentSpec @@ -46,7 +45,6 @@ class MLModelModelSpec(pydantic.BaseModel): class MLModelPartySpec(pydantic.BaseModel): - party_task_id: str role: str partyid: str @@ -54,6 +52,5 @@ class MLModelPartySpec(pydantic.BaseModel): class MLModelSpec(pydantic.BaseModel): - federated: MLModelFederatedSpec party: MLModelPartySpec From 5db3b3ae2aa87318b1101ca30ed6cc5ce8646662 Mon Sep 17 00:00:00 2001 From: sagewe Date: Fri, 8 Dec 2023 23:03:41 +0800 Subject: [PATCH 19/42] add conditional federation init for deepspeed mode Signed-off-by: sagewe --- python/fate/components/core/__init__.py | 4 ++++ python/fate/components/core/_cpn_task_mode.py | 17 ++++++++++++++++ .../entrypoint/cli/component/cleanup_cli.py | 20 ++++++++++--------- .../entrypoint/cli/component/execute_cli.py | 7 ++++++- 4 files changed, 38 insertions(+), 10 deletions(-) create mode 100644 python/fate/components/core/_cpn_task_mode.py diff --git a/python/fate/components/core/__init__.py b/python/fate/components/core/__init__.py index d94395e763..5acf26ddd4 100644 --- a/python/fate/components/core/__init__.py +++ b/python/fate/components/core/__init__.py @@ -6,6 +6,7 @@ from ._load_metric_handler import load_metric_handler from .component_desc import Component, ComponentExecutionIO from .essential import ARBITER, GUEST, HOST, LOCAL, Label, Role, Stage +from ._cpn_task_mode import is_root_worker, is_deepspeed_mode, TaskMode __all__ = [ "Component", @@ -24,4 +25,7 @@ "HOST", "LOCAL", "Label", + "is_root_worker", + "is_deepspeed_mode", + "TaskMode", ] diff --git a/python/fate/components/core/_cpn_task_mode.py b/python/fate/components/core/_cpn_task_mode.py new file mode 100644 index 0000000000..b9856c50a0 --- /dev/null +++ b/python/fate/components/core/_cpn_task_mode.py @@ -0,0 +1,17 @@ +import enum +import os + + +class TaskMode(enum.StrEnum): + SIMPLE = "SIMPLE" + 
DEEPSPEED = "DEEPSPEED" + + +def is_deepspeed_mode(): + return os.getenv("FATE_TASK_TYPE", "").upper() == TaskMode.DEEPSPEED + + +def is_root_worker(): + if is_deepspeed_mode(): + return os.getenv("RANK", "0") == "0" + return True diff --git a/python/fate/components/entrypoint/cli/component/cleanup_cli.py b/python/fate/components/entrypoint/cli/component/cleanup_cli.py index 0a15bd890c..6c4e5792b3 100644 --- a/python/fate/components/entrypoint/cli/component/cleanup_cli.py +++ b/python/fate/components/entrypoint/cli/component/cleanup_cli.py @@ -16,6 +16,7 @@ def cleanup(process_tag, config, env_name): load_config_from_env, load_config_from_file, ) + from fate.components.core import is_root_worker configs = {} configs = load_config_from_env(configs, env_name) @@ -23,15 +24,16 @@ def cleanup(process_tag, config, env_name): config = TaskCleanupConfigSpec.parse_obj(configs) try: - print("start cleanup") - computing = load_computing(config.computing) - federation = load_federation(config.federation, computing) - ctx = Context( - computing=computing, - federation=federation, - ) - ctx.destroy() - print("cleanup done") + if is_root_worker(): + print("start cleanup") + computing = load_computing(config.computing) + federation = load_federation(config.federation, computing) + ctx = Context( + computing=computing, + federation=federation, + ) + ctx.destroy() + print("cleanup done") except Exception as e: traceback.print_exc() raise e diff --git a/python/fate/components/entrypoint/cli/component/execute_cli.py b/python/fate/components/entrypoint/cli/component/execute_cli.py index fa7e9a37b2..bbeb619977 100644 --- a/python/fate/components/entrypoint/cli/component/execute_cli.py +++ b/python/fate/components/entrypoint/cli/component/execute_cli.py @@ -98,6 +98,7 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path): load_device, load_federation, load_metric_handler, + is_root_worker, ) logger = logging.getLogger(__name__) @@ -106,7 +107,11 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path): party_task_id = config.party_task_id device = load_device(config.conf.device) computing = load_computing(config.conf.computing, config.conf.logger.config) - federation = load_federation(config.conf.federation, computing) + if is_root_worker(): + federation = load_federation(config.conf.federation, computing) + else: + federation = None + logger.info("skip federation initialization for non-root worker") cipher = CipherKit(device=device) ctx = Context( device=device, From a338350ac2463bd2ecd3e5785fa67b570312236a Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 11 Dec 2023 15:30:53 +0800 Subject: [PATCH 20/42] Fix bugs Signed-off-by: weijingchen Signed-off-by: cwj --- .../homo_lr/test_homo_lr_multi_ovr.py | 49 +++++++++++++++++++ examples/pipeline/homo_nn/test_nn_binary.py | 8 +-- fate_client | 2 +- fate_flow | 2 +- fate_test | 2 +- python/fate/arch/dataframe/ops/_sort.py | 2 +- python/fate/components/core/_cpn_task_mode.py | 5 +- python/fate/ml/nn/test/test_homo_nn_binary.py | 15 ++++-- python/fate/ml/nn/trainer/trainer_base.py | 3 +- 9 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py diff --git a/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py b/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py new file mode 100644 index 0000000000..eafa50c65b --- /dev/null +++ b/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py @@ -0,0 +1,49 @@ +import argparse +from fate_client.pipeline.components.fate 
+from fate_client.pipeline import FateFlowPipeline
+from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.utils import test_utils
+
+
+def main(config="../config.yaml", namespace=""):
+    if isinstance(config, str):
+        config = test_utils.load_job_config(config)
+    parties = config.parties
+    guest = parties.guest[0]
+    host = parties.host[0]
+    arbiter = parties.arbiter[0]
+    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
+
+    homo_lr_0 = HomoLR(
+        "homo_lr_0",
+        epochs=10,
+        batch_size=16,
+        ovr=True,
+        label_num=4
+    )
+
+    homo_lr_0.guest.task_setting(train_data=DataWarehouseChannel(name="vehicle_scale_homo_guest", namespace="experiment"))
+    homo_lr_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name="vehicle_scale_homo_host", namespace="experiment"))
+    evaluation_0 = Evaluation(
+        'eval_0',
+        default_eval_setting='multi',
+        input_data=[homo_lr_0.outputs['train_output_data']]
+    )
+
+
+    pipeline.add_task(homo_lr_0)
+    pipeline.add_task(evaluation_0)
+    pipeline.compile()
+    pipeline.fit()
+    print(pipeline.get_task_info("homo_lr_0").get_output_data())
+    print(pipeline.get_task_info("homo_lr_0").get_output_model())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("PIPELINE DEMO")
+    parser.add_argument("--config", type=str, default="../config.yaml",
+                        help="config file")
+    parser.add_argument("--namespace", type=str, default="",
+                        help="namespace for data stored in FATE")
+    args = parser.parse_args()
+    main(config=args.config, namespace=args.namespace)
\ No newline at end of file
diff --git a/examples/pipeline/homo_nn/test_nn_binary.py b/examples/pipeline/homo_nn/test_nn_binary.py
index cfa6c59a5a..fd3e7bca40 100644
--- a/examples/pipeline/homo_nn/test_nn_binary.py
+++ b/examples/pipeline/homo_nn/test_nn_binary.py
@@ -34,7 +34,7 @@ def main(config="../../config.yaml", namespace=""):
     host = parties.host[0]
     arbiter = parties.arbiter[0]

-    epochs = 10
+    epochs = 5
     batch_size = 64
     in_feat = 30
     out_feat = 16
@@ -54,7 +54,7 @@
         ),
         loss=nn.BCELoss(),
         optimizer=optim.Adam(lr=lr),
-        training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size, seed=114514, logging_strategy='steps'),
+        training_args=TrainingArguments(num_train_epochs=epochs, per_device_train_batch_size=batch_size),
         fed_args=FedAVGArguments(),
         task_type='binary'
     )
@@ -73,8 +73,8 @@
         predict_model_input=homo_nn_0.outputs['train_model_output']
     )

-    homo_nn_1.guest.component_setting(test_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
-    homo_nn_1.hosts[0].component_setting(test_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))
+    homo_nn_1.guest.task_setting(test_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
+    homo_nn_1.hosts[0].task_setting(test_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))

     evaluation_0 = Evaluation(
         'eval_0',
diff --git a/fate_client b/fate_client
index 79001231f9..943969ab27 160000
--- a/fate_client
+++ b/fate_client
@@ -1 +1 @@
-Subproject commit 79001231f9c6485898cafdf7845b305673cb68af
+Subproject commit 943969ab27dfe0188ab6e6837e1ea99df8215d1e
diff --git a/fate_flow b/fate_flow
index b8cf2dad36..a8c93ebb43 160000
--- a/fate_flow
+++ b/fate_flow
@@ -1 +1 @@
-Subproject commit b8cf2dad361a8428648bd60c980ad3ae369ef5e2
+Subproject commit a8c93ebb4309643b72b58df12b85659ba828a845
diff --git a/fate_test b/fate_test
index bc4a603b4b..6b0ef252da 160000
--- a/fate_test
+++ b/fate_test
@@ -1 +1 @@
-Subproject commit bc4a603b4bbb4dce6c528b5116acbc30fcaf335a
+Subproject commit 6b0ef252dafe60d4bc04ad90d352108af883c684
diff --git a/python/fate/arch/dataframe/ops/_sort.py b/python/fate/arch/dataframe/ops/_sort.py
index ab77ac9edf..e51c3f2428 100644
--- a/python/fate/arch/dataframe/ops/_sort.py
+++ b/python/fate/arch/dataframe/ops/_sort.py
@@ -118,7 +118,7 @@ def _extract_columns(r_id):
     block_table = df._ctx.computing.parallelize(
         blocks_with_id,
         include_key=True,
-        partition=df.block_table.partitions
+        partition=df.block_table.num_partitions
     )

     partition_order_mappings = get_partition_order_mappings_by_block_table(block_table, block_row_size=block_row_size)
diff --git a/python/fate/components/core/_cpn_task_mode.py b/python/fate/components/core/_cpn_task_mode.py
index b9856c50a0..8c007e0e06 100644
--- a/python/fate/components/core/_cpn_task_mode.py
+++ b/python/fate/components/core/_cpn_task_mode.py
@@ -1,11 +1,12 @@
 import enum
 import os

-
-class TaskMode(enum.StrEnum):
+class TaskMode(enum.Enum):
     SIMPLE = "SIMPLE"
     DEEPSPEED = "DEEPSPEED"

+    def __str__(self):
+        return self.value

 def is_deepspeed_mode():
-    return os.getenv("FATE_TASK_TYPE", "").upper() == TaskMode.DEEPSPEED
+    return os.getenv("FATE_TASK_TYPE", "").upper() == TaskMode.DEEPSPEED.value
diff --git a/python/fate/ml/nn/test/test_homo_nn_binary.py b/python/fate/ml/nn/test/test_homo_nn_binary.py
index 42cff7ee1a..94385c3b8b 100644
--- a/python/fate/ml/nn/test/test_homo_nn_binary.py
+++ b/python/fate/ml/nn/test/test_homo_nn_binary.py
@@ -17,16 +17,16 @@ def create_ctx(local):
     import logging

     logger = logging.getLogger()
-    logger.setLevel(logging.DEBUG)
+    logger.setLevel(logging.INFO)
     console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.DEBUG)
+    console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
     console_handler.setFormatter(formatter)
     logger.addHandler(console_handler)

-    computing = CSession()
+    computing = CSession(data_dir='./session_dir')
     return Context(
         computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter])
     )
@@ -37,7 +37,10 @@ def create_ctx(local):
     model = t.nn.Sequential(t.nn.Linear(30, 1), t.nn.Sigmoid())

     ds = TableDataset(return_dict=False, to_tensor=True)
-    ds.load('./../../../../../examples/data/breast_homo_guest.py')
+    ds.load('./../../../../../examples/data/breast_homo_guest.csv')
+
+    ds_val = TableDataset(return_dict=False, to_tensor=True)
+    ds_val.load('./../../../../../examples/data/breast_homo_test.csv')

     if sys.argv[1] == "guest":
         ctx = create_ctx(guest)
@@ -52,7 +55,8 @@
             training_args=args,
             loss_fn=t.nn.BCELoss(),
             optimizer=t.optim.SGD(model.parameters(), lr=0.01),
-            train_set=ds
+            train_set=ds,
+            val_set=ds_val
         )
         trainer.train()

@@ -68,6 +72,7 @@
             loss_fn=t.nn.BCELoss(),
             optimizer=t.optim.SGD(model.parameters(), lr=0.01),
             train_set=ds,
+            val_set=ds_val
         )
         trainer.train()

diff --git a/python/fate/ml/nn/trainer/trainer_base.py b/python/fate/ml/nn/trainer/trainer_base.py
index 3241f9a51a..9ef4867d2e 100644
--- a/python/fate/ml/nn/trainer/trainer_base.py
+++ b/python/fate/ml/nn/trainer/trainer_base.py
@@ -637,12 +637,13 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
         self.wrapped_trainer.aggregator = self.wrapped_trainer.init_aggregator(self.ctx,
self.fed_arg) def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.wrapped_trainer.local_mode: return # aggregate loss if self.fed_arg.aggregate_strategy == AggregateStrategy.EPOCH.value: if self.wrapped_trainer.aggregation_checker.can_aggregate_loss: - if 'train_loss' in state.log_history[-1]: # final log is ignored + if 'loss' not in state.log_history[-1]: # only process train loss return loss = state.log_history[-1]["loss"] agg_round = self.wrapped_trainer.aggregation_checker.loss_aggregation_count From 27dfe97202b03dea30a3eda477715c08da5d6205 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Mon, 11 Dec 2023 16:48:44 +0800 Subject: [PATCH 21/42] use unique artifact name(#4668) Signed-off-by: Yu Wu --- python/fate/components/components/union.py | 3 ++- python/fate/components/core/spec/artifact.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/fate/components/components/union.py b/python/fate/components/components/union.py index 25a0d5361f..334cd6457a 100644 --- a/python/fate/components/components/union.py +++ b/python/fate/components/components/union.py @@ -28,7 +28,8 @@ def union( data_list = [] data_len_dict = {} for data in input_data_list: - data_name = f"{data.artifact.metadata.source.task_name}.{data.artifact.metadata.source.output_artifact_key}" + data_name = f"{data.artifact.metadata.source.task_name}.{data.artifact.metadata.source.unique_key()}" + # logger.debug(f"data_name: {data_name}") data = data.read() data_list.append(data) data_len_dict[data_name] = data.shape[0] diff --git a/python/fate/components/core/spec/artifact.py b/python/fate/components/core/spec/artifact.py index 426b2c0643..44ec5d32c8 100644 --- a/python/fate/components/core/spec/artifact.py +++ b/python/fate/components/core/spec/artifact.py @@ -12,18 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import datetime import re from typing import Dict, List, Optional, Union import pydantic from .model import ( - MLModelComponentSpec, - MLModelFederatedSpec, - MLModelModelSpec, - MLModelPartiesSpec, - MLModelPartySpec, MLModelSpec, ) @@ -49,6 +43,12 @@ class ArtifactSource(pydantic.BaseModel): output_artifact_key: str output_index: Optional[int] = None + def unique_key(self): + key = f"{self.task_id}_{self.task_name}_{self.output_artifact_key}" + if self.output_index is not None: + key = f"{key}_index_{self.output_index}" + return key + class Metadata(pydantic.BaseModel): metadata: dict = pydantic.Field(default_factory=dict) From 183f60b30c06cc6dc9ea89750eee4581d09b2c4f Mon Sep 17 00:00:00 2001 From: sagewe Date: Sat, 9 Dec 2023 17:16:31 +0800 Subject: [PATCH 22/42] refactor federation Signed-off-by: sagewe --- python/fate/arch/context/_context.py | 4 +- python/fate/arch/context/_federation.py | 21 +- python/fate/arch/federation/__init__.py | 4 +- python/fate/arch/federation/_builder.py | 10 +- python/fate/arch/federation/api/__init__.py | 2 + .../{federation.py => api/_federation.py} | 0 python/fate/arch/federation/{ => api}/_gc.py | 10 +- .../fate/arch/federation/{ => api}/_type.py | 0 .../fate/arch/federation/backends/__init__.py | 0 .../{ => backends}/eggroll/__init__.py | 0 .../{ => backends}/eggroll/_federation.py | 5 +- .../federation/{ => backends}/osx/__init__.py | 0 .../{ => backends}/osx/_federation.py | 406 +++++++++--------- .../{ => backends}/osx/_mq_channel.py | 5 +- .../federation/{ => backends}/osx/osx_pb2.py | 0 .../{ => backends}/osx/osx_pb2_grpc.py | 0 .../{ => backends}/pulsar/__init__.py | 0 .../{ => backends}/pulsar/_federation.py | 5 +- .../{ => backends}/pulsar/_mq_channel.py | 2 +- .../{ => backends}/pulsar/_pulsar_manager.py | 0 .../{ => backends}/rabbitmq/__init__.py | 0 .../{ => backends}/rabbitmq/_federation.py | 6 +- .../{ => backends}/rabbitmq/_mq_channel.py | 2 +- .../rabbitmq/_rabbit_manager.py | 0 .../{ => backends}/standalone/__init__.py | 0 .../{ => backends}/standalone/_federation.py | 7 +- .../arch/federation/message_queue/__init__.py | 4 + .../{ => message_queue}/_datastream.py | 0 .../{ => message_queue}/_federation.py | 106 ++--- .../federation/{ => message_queue}/_nretry.py | 0 .../{ => message_queue}/_parties.py | 0 python/fate/arch/utils/trace.py | 2 +- 32 files changed, 267 insertions(+), 334 deletions(-) create mode 100644 python/fate/arch/federation/api/__init__.py rename python/fate/arch/federation/{federation.py => api/_federation.py} (100%) rename python/fate/arch/federation/{ => api}/_gc.py (91%) rename python/fate/arch/federation/{ => api}/_type.py (100%) create mode 100644 python/fate/arch/federation/backends/__init__.py rename python/fate/arch/federation/{ => backends}/eggroll/__init__.py (100%) rename python/fate/arch/federation/{ => backends}/eggroll/_federation.py (97%) rename python/fate/arch/federation/{ => backends}/osx/__init__.py (100%) rename python/fate/arch/federation/{ => backends}/osx/_federation.py (95%) rename python/fate/arch/federation/{ => backends}/osx/_mq_channel.py (98%) rename python/fate/arch/federation/{ => backends}/osx/osx_pb2.py (100%) rename python/fate/arch/federation/{ => backends}/osx/osx_pb2_grpc.py (100%) rename python/fate/arch/federation/{ => backends}/pulsar/__init__.py (100%) rename python/fate/arch/federation/{ => backends}/pulsar/_federation.py (99%) rename python/fate/arch/federation/{ => backends}/pulsar/_mq_channel.py (99%) rename python/fate/arch/federation/{ => 
backends}/pulsar/_pulsar_manager.py (100%) rename python/fate/arch/federation/{ => backends}/rabbitmq/__init__.py (100%) rename python/fate/arch/federation/{ => backends}/rabbitmq/_federation.py (98%) rename python/fate/arch/federation/{ => backends}/rabbitmq/_mq_channel.py (98%) rename python/fate/arch/federation/{ => backends}/rabbitmq/_rabbit_manager.py (100%) rename python/fate/arch/federation/{ => backends}/standalone/__init__.py (100%) rename python/fate/arch/federation/{ => backends}/standalone/_federation.py (93%) create mode 100644 python/fate/arch/federation/message_queue/__init__.py rename python/fate/arch/federation/{ => message_queue}/_datastream.py (100%) rename python/fate/arch/federation/{ => message_queue}/_federation.py (86%) rename python/fate/arch/federation/{ => message_queue}/_nretry.py (100%) rename python/fate/arch/federation/{ => message_queue}/_parties.py (100%) diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 644dc78073..911616242a 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -30,8 +30,8 @@ T = TypeVar("T") if typing.TYPE_CHECKING: - from ..federation.federation import Federation - from ..computing.table import KVTableContext + from fate.arch.federation.api import Federation + from fate.arch.computing.table import KVTableContext class Context: diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index c2e3189859..6fb89511c1 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -import pickle import logging +import pickle import struct import typing from typing import Any, List, Tuple, TypeVar, Union @@ -22,30 +22,13 @@ from fate.arch.abc import PartyMeta from ._namespace import NS from ..computing import is_table -from ..federation._gc import IterationGC logger = logging.getLogger(__name__) T = TypeVar("T") if typing.TYPE_CHECKING: from fate.arch.context import Context - from fate.arch.federation.federation import Federation - - -class GC: - def __init__(self) -> None: - self._push_gc_dict = {} - self._pull_gc_dict = {} - - def get_or_set_push_gc(self, key): - if key not in self._push_gc_dict: - self._push_gc_dict[key] = IterationGC() - return self._push_gc_dict[key] - - def get_or_set_pull_gc(self, key): - if key not in self._pull_gc_dict: - self._pull_gc_dict[key] = IterationGC() - return self._pull_gc_dict[key] + from fate.arch.federation.api import Federation class _KeyedParty: diff --git a/python/fate/arch/federation/__init__.py b/python/fate/arch/federation/__init__.py index 4aa278575d..c4f611d98b 100644 --- a/python/fate/arch/federation/__init__.py +++ b/python/fate/arch/federation/__init__.py @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from ._builder import FederationBuilder, FederationMode -from ._type import FederationDataType +from .api import Federation, FederationDataType -__all__ = ["FederationDataType", "FederationBuilder", "FederationMode"] +__all__ = ["Federation", "FederationDataType", "FederationBuilder", "FederationMode"] diff --git a/python/fate/arch/federation/_builder.py b/python/fate/arch/federation/_builder.py index 9afdce68d3..ab14f0a2eb 100644 --- a/python/fate/arch/federation/_builder.py +++ b/python/fate/arch/federation/_builder.py @@ -84,7 +84,7 @@ def build(self, computing_session, t: FederationType, conf: dict): raise ValueError(f"{t} not in {FederationType}") def build_standalone(self, computing_session): - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.federation.backends.standalone import StandaloneFederation return StandaloneFederation( standalone_session=computing_session, @@ -99,7 +99,7 @@ def build_osx( if options is None: options = {} if mode == FederationMode.MESSAGE_QUEUE: - from fate.arch.federation.osx import OSXFederation + from fate.arch.federation.backends.osx import OSXFederation return OSXFederation.from_conf( federation_session_id=self._federation_id, @@ -112,7 +112,7 @@ def build_osx( ) else: from fate.arch.computing.eggroll import CSession - from fate.arch.federation.eggroll import EggrollFederation + from fate.arch.federation.backends.eggroll import EggrollFederation if not isinstance(computing_session, CSession): raise RuntimeError( @@ -128,7 +128,7 @@ def build_osx( ) def build_rabbitmq(self, computing_session, host: str, port: int, options: dict): - from fate.arch.federation.rabbitmq import RabbitmqFederation + from fate.arch.federation.backends.rabbitmq import RabbitmqFederation return RabbitmqFederation.from_conf( federation_session_id=self._federation_id, @@ -148,7 +148,7 @@ def build_rabbitmq(self, computing_session, host: str, port: int, options: dict) ) def build_pulsar(self, computing_session, host: str, port: int, options: dict): - from fate.arch.federation.pulsar import PulsarFederation + from fate.arch.federation.backends.pulsar import PulsarFederation return PulsarFederation.from_conf( federation_session_id=self._federation_id, diff --git a/python/fate/arch/federation/api/__init__.py b/python/fate/arch/federation/api/__init__.py new file mode 100644 index 0000000000..12191de06d --- /dev/null +++ b/python/fate/arch/federation/api/__init__.py @@ -0,0 +1,2 @@ +from ._federation import Federation +from ._type import FederationDataType diff --git a/python/fate/arch/federation/federation.py b/python/fate/arch/federation/api/_federation.py similarity index 100% rename from python/fate/arch/federation/federation.py rename to python/fate/arch/federation/api/_federation.py diff --git a/python/fate/arch/federation/_gc.py b/python/fate/arch/federation/api/_gc.py similarity index 91% rename from python/fate/arch/federation/_gc.py rename to python/fate/arch/federation/api/_gc.py index 847be7ae2e..e5bdc7f662 100644 --- a/python/fate/arch/federation/_gc.py +++ b/python/fate/arch/federation/api/_gc.py @@ -19,7 +19,7 @@ from collections import deque -LOGGER = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class GarbageCollector: @@ -48,10 +48,10 @@ def clean(self, name: str, tag: str): @classmethod def _safe_gc_call(cls, obj, method: str, kwargs: dict): try: - LOGGER.debug(f"[CLEAN]deleting {obj}, {method}, {kwargs}") + logger.debug(f"[CLEAN]deleting {obj}, {method}, {kwargs}") getattr(obj, method)(**kwargs) except Exception as 
e: - LOGGER.debug(f"[CLEAN]this could be ignore {e}") + logger.debug(f"[CLEAN]this could be ignore {e}") class IterationGC: @@ -89,7 +89,7 @@ def clean(self): def _safe_gc_call(actions: typing.List[typing.Tuple[typing.Any, str, dict]]): for obj, method, args_dict in actions: try: - LOGGER.debug(f"[CLEAN]deleting {obj}, {method}, {args_dict}") + logger.debug(f"[CLEAN]deleting {obj}, {method}, {args_dict}") getattr(obj, method)(**args_dict) except Exception as e: - LOGGER.debug(f"[CLEAN]this could be ignore {e}") + logger.debug(f"[CLEAN]this could be ignore {e}") diff --git a/python/fate/arch/federation/_type.py b/python/fate/arch/federation/api/_type.py similarity index 100% rename from python/fate/arch/federation/_type.py rename to python/fate/arch/federation/api/_type.py diff --git a/python/fate/arch/federation/backends/__init__.py b/python/fate/arch/federation/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/arch/federation/eggroll/__init__.py b/python/fate/arch/federation/backends/eggroll/__init__.py similarity index 100% rename from python/fate/arch/federation/eggroll/__init__.py rename to python/fate/arch/federation/backends/eggroll/__init__.py diff --git a/python/fate/arch/federation/eggroll/_federation.py b/python/fate/arch/federation/backends/eggroll/_federation.py similarity index 97% rename from python/fate/arch/federation/eggroll/_federation.py rename to python/fate/arch/federation/backends/eggroll/_federation.py index f7eabe0098..155db04de5 100644 --- a/python/fate/arch/federation/eggroll/_federation.py +++ b/python/fate/arch/federation/backends/eggroll/_federation.py @@ -23,8 +23,9 @@ from eggroll.computing import RollPair from eggroll.federation import RollSiteContext -from fate.arch.federation.federation import Federation, PartyMeta -from ...computing.eggroll import Table +from fate.arch.federation.api import Federation +from fate.arch.abc import PartyMeta +from fate.arch.computing.eggroll import Table logger = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/osx/__init__.py b/python/fate/arch/federation/backends/osx/__init__.py similarity index 100% rename from python/fate/arch/federation/osx/__init__.py rename to python/fate/arch/federation/backends/osx/__init__.py diff --git a/python/fate/arch/federation/osx/_federation.py b/python/fate/arch/federation/backends/osx/_federation.py similarity index 95% rename from python/fate/arch/federation/osx/_federation.py rename to python/fate/arch/federation/backends/osx/_federation.py index 5ba0bbf6c2..72f270112a 100644 --- a/python/fate/arch/federation/osx/_federation.py +++ b/python/fate/arch/federation/backends/osx/_federation.py @@ -1,203 +1,203 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import typing -from logging import getLogger - -from fate.arch.abc import PartyMeta -from fate.arch.federation.osx import osx_pb2 -from ._mq_channel import MQChannel -from .._federation import FederationBase - -LOGGER = getLogger(__name__) -# default message max size in bytes = 1MB -DEFAULT_MESSAGE_MAX_SIZE = 1048576 - - -class MQ(object): - def __init__(self, host, port): - self.host = host - self.port = port - - def __str__(self): - return f"MQ(host={self.host}, port={self.port}, type=osx)" - - def __repr__(self): - return self.__str__() - - -class _TopicPair(object): - def __init__(self, namespace, send, receive): - self.namespace = namespace - self.send = send - self.receive = receive - - def __str__(self) -> str: - return f"<_TopicPair namespace={self.namespace}, send={self.send}, receive={self.receive}>" - - -class OSXFederation(FederationBase): - @staticmethod - def from_conf( - federation_session_id: str, - computing_session, - party: PartyMeta, - parties: typing.List[PartyMeta], - host: str, - port: int, - max_message_size: typing.Optional[int] = None, - ): - if max_message_size is None: - max_message_size = DEFAULT_MESSAGE_MAX_SIZE - mq = MQ(host, port) - - return OSXFederation( - federation_session_id=federation_session_id, - computing_session=computing_session, - party=party, - parties=parties, - max_message_size=max_message_size, - mq=mq, - ) - - def __init__( - self, - federation_session_id, - computing_session, - party: PartyMeta, - parties: typing.List[PartyMeta], - max_message_size, - mq, - ): - super().__init__( - session_id=federation_session_id, - computing_session=computing_session, - party=party, - parties=parties, - max_message_size=max_message_size, - mq=mq, - ) - - def __getstate__(self): - pass - - def destroy(self): - LOGGER.debug("start to cleanup...") - - channel = MQChannel( - host=self._mq.host, - port=self._mq.port, - namespace=self._session_id, - send_topic=None, - receive_topic=None, - src_party_id=None, - src_role=None, - dst_party_id=None, - dst_role=None, - ) - - channel.cleanup() - channel.close() - - def _maybe_create_topic_and_replication(self, party, topic_suffix): - LOGGER.debug(f"_maybe_create_topic_and_replication, party={party}, topic_suffix={topic_suffix}") - send_topic_name = f"{self._session_id}-{self._party.role}-{self._party.party_id}-{party.role}-{party.party_id}-{topic_suffix}" - receive_topic_name = f"{self._session_id}-{party.role}-{party.party_id}-{self._party.role}-{self._party.party_id}-{topic_suffix}" - - # topic_pair is a pair of topic for sending and receiving message respectively - topic_pair = _TopicPair( - namespace=self._session_id, - send=send_topic_name, - receive=receive_topic_name, - ) - return topic_pair - - def _get_channel( - self, topic_pair: _TopicPair, src_party_id, src_role, dst_party_id, dst_role, mq: MQ, conf: dict = None - ): - LOGGER.debug( - f"_get_channel, topic_pari={topic_pair}, src_party_id={src_party_id}, src_role={src_role}, dst_party_id={dst_party_id}, dst_role={dst_role}" - ) - return MQChannel( - host=mq.host, - port=mq.port, - namespace=topic_pair.namespace, - send_topic=topic_pair.send, - receive_topic=topic_pair.receive, - src_party_id=src_party_id, - src_role=src_role, - dst_party_id=dst_party_id, - dst_role=dst_role, - ) - - _topic_ip_map = {} - - # @nretry - # def _query_receive_topic(self, channel_info): - # # LOGGER.debug(f"_query_receive_topic, channel_info={channel_info}") - # # topic = channel_info._receive_topic - # # if topic not in self._topic_ip_map: - # # 
LOGGER.info(f"query topic {topic} miss cache ") - # # response = channel_info.query() - # # if response.code == "0": - # # topic_info = osx_pb2.TopicInfo() - # # topic_info.ParseFromString(response.payload) - # # self._topic_ip_map[topic] = (topic_info.ip, topic_info.port) - # # LOGGER.info(f"query result {topic} {topic_info}") - # # else: - # # raise LookupError(f"{response}") - # # host, port = self._topic_ip_map[topic] - # # - # # new_channel_info = channel_info - # # if channel_info._host != host or channel_info._port != port: - # # LOGGER.info( - # # f"channel info missmatch, host: {channel_info._host} vs {host} and port: {channel_info._port} vs {port}" - # # ) - # # new_channel_info = MQChannel( - # # host=host, - # # port=port, - # # namespace=channel_info._namespace, - # # send_topic=channel_info._send_topic, - # # receive_topic=channel_info._receive_topic, - # # src_party_id=channel_info._src_party_id, - # # src_role=channel_info._src_role, - # # dst_party_id=channel_info._dst_party_id, - # # dst_role=channel_info._dst_role, - # # ) - # # return new_channel_info - # return channel_info; - - def _get_consume_message(self, channel_info): - LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}") - while True: - response = channel_info.consume() - # LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}, response={response}") - if response.code == "E0000000601": - raise LookupError(f"{response}") - message = osx_pb2.Message() - message.ParseFromString(response.payload) - # offset = response.metadata["MessageOffSet"] - head_str = str(message.head, encoding="utf-8") - # LOGGER.debug(f"head str {head_str}") - properties = json.loads(head_str) - # LOGGER.debug(f"osx response properties {properties}") - body = message.body - yield 0, properties, body - - def _consume_ack(self, channel_info, id): - return - # LOGGER.debug(f"_comsume_ack, channel_info={channel_info}, id={id}") - # channel_info.ack(offset=id) +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import typing +from logging import getLogger + +from fate.arch.abc import PartyMeta +from fate.arch.federation.backends.osx import osx_pb2 +from fate.arch.federation.message_queue import MessageQueueBasedFederation +from ._mq_channel import MQChannel + +LOGGER = getLogger(__name__) +# default message max size in bytes = 1MB +DEFAULT_MESSAGE_MAX_SIZE = 1048576 + + +class MQ(object): + def __init__(self, host, port): + self.host = host + self.port = port + + def __str__(self): + return f"MQ(host={self.host}, port={self.port}, type=osx)" + + def __repr__(self): + return self.__str__() + + +class _TopicPair(object): + def __init__(self, namespace, send, receive): + self.namespace = namespace + self.send = send + self.receive = receive + + def __str__(self) -> str: + return f"<_TopicPair namespace={self.namespace}, send={self.send}, receive={self.receive}>" + + +class OSXFederation(MessageQueueBasedFederation): + @staticmethod + def from_conf( + federation_session_id: str, + computing_session, + party: PartyMeta, + parties: typing.List[PartyMeta], + host: str, + port: int, + max_message_size: typing.Optional[int] = None, + ): + if max_message_size is None: + max_message_size = DEFAULT_MESSAGE_MAX_SIZE + mq = MQ(host, port) + + return OSXFederation( + federation_session_id=federation_session_id, + computing_session=computing_session, + party=party, + parties=parties, + max_message_size=max_message_size, + mq=mq, + ) + + def __init__( + self, + federation_session_id, + computing_session, + party: PartyMeta, + parties: typing.List[PartyMeta], + max_message_size, + mq, + ): + super().__init__( + session_id=federation_session_id, + computing_session=computing_session, + party=party, + parties=parties, + max_message_size=max_message_size, + mq=mq, + ) + + def __getstate__(self): + pass + + def destroy(self): + LOGGER.debug("start to cleanup...") + + channel = MQChannel( + host=self._mq.host, + port=self._mq.port, + namespace=self._session_id, + send_topic=None, + receive_topic=None, + src_party_id=None, + src_role=None, + dst_party_id=None, + dst_role=None, + ) + + channel.cleanup() + channel.close() + + def _maybe_create_topic_and_replication(self, party, topic_suffix): + LOGGER.debug(f"_maybe_create_topic_and_replication, party={party}, topic_suffix={topic_suffix}") + send_topic_name = f"{self._session_id}-{self._party.role}-{self._party.party_id}-{party.role}-{party.party_id}-{topic_suffix}" + receive_topic_name = f"{self._session_id}-{party.role}-{party.party_id}-{self._party.role}-{self._party.party_id}-{topic_suffix}" + + # topic_pair is a pair of topic for sending and receiving message respectively + topic_pair = _TopicPair( + namespace=self._session_id, + send=send_topic_name, + receive=receive_topic_name, + ) + return topic_pair + + def _get_channel( + self, topic_pair: _TopicPair, src_party_id, src_role, dst_party_id, dst_role, mq: MQ, conf: dict = None + ): + LOGGER.debug( + f"_get_channel, topic_pari={topic_pair}, src_party_id={src_party_id}, src_role={src_role}, dst_party_id={dst_party_id}, dst_role={dst_role}" + ) + return MQChannel( + host=mq.host, + port=mq.port, + namespace=topic_pair.namespace, + send_topic=topic_pair.send, + receive_topic=topic_pair.receive, + src_party_id=src_party_id, + src_role=src_role, + dst_party_id=dst_party_id, + dst_role=dst_role, + ) + + _topic_ip_map = {} + + # @nretry + # def _query_receive_topic(self, channel_info): + # # LOGGER.debug(f"_query_receive_topic, channel_info={channel_info}") + # # topic = channel_info._receive_topic + # # 
if topic not in self._topic_ip_map: + # # LOGGER.info(f"query topic {topic} miss cache ") + # # response = channel_info.query() + # # if response.code == "0": + # # topic_info = osx_pb2.TopicInfo() + # # topic_info.ParseFromString(response.payload) + # # self._topic_ip_map[topic] = (topic_info.ip, topic_info.port) + # # LOGGER.info(f"query result {topic} {topic_info}") + # # else: + # # raise LookupError(f"{response}") + # # host, port = self._topic_ip_map[topic] + # # + # # new_channel_info = channel_info + # # if channel_info._host != host or channel_info._port != port: + # # LOGGER.info( + # # f"channel info missmatch, host: {channel_info._host} vs {host} and port: {channel_info._port} vs {port}" + # # ) + # # new_channel_info = MQChannel( + # # host=host, + # # port=port, + # # namespace=channel_info._namespace, + # # send_topic=channel_info._send_topic, + # # receive_topic=channel_info._receive_topic, + # # src_party_id=channel_info._src_party_id, + # # src_role=channel_info._src_role, + # # dst_party_id=channel_info._dst_party_id, + # # dst_role=channel_info._dst_role, + # # ) + # # return new_channel_info + # return channel_info; + + def _get_consume_message(self, channel_info): + LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}") + while True: + response = channel_info.consume() + # LOGGER.debug(f"_get_comsume_message, channel_info={channel_info}, response={response}") + if response.code == "E0000000601": + raise LookupError(f"{response}") + message = osx_pb2.Message() + message.ParseFromString(response.payload) + # offset = response.metadata["MessageOffSet"] + head_str = str(message.head, encoding="utf-8") + # LOGGER.debug(f"head str {head_str}") + properties = json.loads(head_str) + # LOGGER.debug(f"osx response properties {properties}") + body = message.body + yield 0, properties, body + + def _consume_ack(self, channel_info, id): + return + # LOGGER.debug(f"_comsume_ack, channel_info={channel_info}, id={id}") + # channel_info.ack(offset=id) diff --git a/python/fate/arch/federation/osx/_mq_channel.py b/python/fate/arch/federation/backends/osx/_mq_channel.py similarity index 98% rename from python/fate/arch/federation/osx/_mq_channel.py rename to python/fate/arch/federation/backends/osx/_mq_channel.py index 0c359d6215..0ecaf6b80d 100644 --- a/python/fate/arch/federation/osx/_mq_channel.py +++ b/python/fate/arch/federation/backends/osx/_mq_channel.py @@ -20,10 +20,9 @@ from typing import Dict, List, Any import time import grpc -from fate.arch.federation.osx import osx_pb2 -from fate.arch.federation.osx.osx_pb2_grpc import PrivateTransferTransportStub +from fate.arch.federation.backends.osx import osx_pb2 +from fate.arch.federation.backends.osx.osx_pb2_grpc import PrivateTransferTransportStub import numpy as np -# from .._nretry import nretry LOGGER = getLogger(__name__) diff --git a/python/fate/arch/federation/osx/osx_pb2.py b/python/fate/arch/federation/backends/osx/osx_pb2.py similarity index 100% rename from python/fate/arch/federation/osx/osx_pb2.py rename to python/fate/arch/federation/backends/osx/osx_pb2.py diff --git a/python/fate/arch/federation/osx/osx_pb2_grpc.py b/python/fate/arch/federation/backends/osx/osx_pb2_grpc.py similarity index 100% rename from python/fate/arch/federation/osx/osx_pb2_grpc.py rename to python/fate/arch/federation/backends/osx/osx_pb2_grpc.py diff --git a/python/fate/arch/federation/pulsar/__init__.py b/python/fate/arch/federation/backends/pulsar/__init__.py similarity index 100% rename from 
python/fate/arch/federation/pulsar/__init__.py rename to python/fate/arch/federation/backends/pulsar/__init__.py diff --git a/python/fate/arch/federation/pulsar/_federation.py b/python/fate/arch/federation/backends/pulsar/_federation.py similarity index 99% rename from python/fate/arch/federation/pulsar/_federation.py rename to python/fate/arch/federation/backends/pulsar/_federation.py index e5c27b033f..b20a1a663a 100644 --- a/python/fate/arch/federation/pulsar/_federation.py +++ b/python/fate/arch/federation/backends/pulsar/_federation.py @@ -18,8 +18,7 @@ from typing import List, Optional from fate.arch.abc import PartyMeta - -from .._federation import FederationBase +from fate.arch.federation.message_queue import MessageQueueBasedFederation from ._mq_channel import ( DEFAULT_CLUSTER, DEFAULT_SUBSCRIPTION_NAME, @@ -54,7 +53,7 @@ def __init__(self, tenant, namespace, send, receive): self.receive = receive -class PulsarFederation(FederationBase): +class PulsarFederation(MessageQueueBasedFederation): @staticmethod def from_conf( federation_session_id: str, diff --git a/python/fate/arch/federation/pulsar/_mq_channel.py b/python/fate/arch/federation/backends/pulsar/_mq_channel.py similarity index 99% rename from python/fate/arch/federation/pulsar/_mq_channel.py rename to python/fate/arch/federation/backends/pulsar/_mq_channel.py index 6b8061b993..687eb08a52 100644 --- a/python/fate/arch/federation/pulsar/_mq_channel.py +++ b/python/fate/arch/federation/backends/pulsar/_mq_channel.py @@ -19,7 +19,7 @@ import pulsar -from .._nretry import nretry +from fate.arch.federation.message_queue import nretry LOGGER = logging.getLogger(__name__) CHANNEL_TYPE_PRODUCER = "producer" diff --git a/python/fate/arch/federation/pulsar/_pulsar_manager.py b/python/fate/arch/federation/backends/pulsar/_pulsar_manager.py similarity index 100% rename from python/fate/arch/federation/pulsar/_pulsar_manager.py rename to python/fate/arch/federation/backends/pulsar/_pulsar_manager.py diff --git a/python/fate/arch/federation/rabbitmq/__init__.py b/python/fate/arch/federation/backends/rabbitmq/__init__.py similarity index 100% rename from python/fate/arch/federation/rabbitmq/__init__.py rename to python/fate/arch/federation/backends/rabbitmq/__init__.py diff --git a/python/fate/arch/federation/rabbitmq/_federation.py b/python/fate/arch/federation/backends/rabbitmq/_federation.py similarity index 98% rename from python/fate/arch/federation/rabbitmq/_federation.py rename to python/fate/arch/federation/backends/rabbitmq/_federation.py index 5222b8ac56..219b31ab18 100644 --- a/python/fate/arch/federation/rabbitmq/_federation.py +++ b/python/fate/arch/federation/backends/rabbitmq/_federation.py @@ -19,9 +19,7 @@ from typing import List, Optional from fate.arch.abc import PartyMeta - -from .._federation import FederationBase -from .._parties import Party +from fate.arch.federation.message_queue import MessageQueueBasedFederation, Party from ._mq_channel import MQChannel from ._rabbit_manager import RabbitManager @@ -58,7 +56,7 @@ def __init__(self, tenant=None, namespace=None, vhost=None, send=None, receive=N self.receive = receive -class RabbitmqFederation(FederationBase): +class RabbitmqFederation(MessageQueueBasedFederation): @staticmethod def from_conf( federation_session_id: str, diff --git a/python/fate/arch/federation/rabbitmq/_mq_channel.py b/python/fate/arch/federation/backends/rabbitmq/_mq_channel.py similarity index 98% rename from python/fate/arch/federation/rabbitmq/_mq_channel.py rename to 
python/fate/arch/federation/backends/rabbitmq/_mq_channel.py index ac159e8b28..95daaceea1 100644 --- a/python/fate/arch/federation/rabbitmq/_mq_channel.py +++ b/python/fate/arch/federation/backends/rabbitmq/_mq_channel.py @@ -19,7 +19,7 @@ import pika -from .._nretry import nretry +from fate.arch.federation.message_queue import nretry LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/rabbitmq/_rabbit_manager.py b/python/fate/arch/federation/backends/rabbitmq/_rabbit_manager.py similarity index 100% rename from python/fate/arch/federation/rabbitmq/_rabbit_manager.py rename to python/fate/arch/federation/backends/rabbitmq/_rabbit_manager.py diff --git a/python/fate/arch/federation/standalone/__init__.py b/python/fate/arch/federation/backends/standalone/__init__.py similarity index 100% rename from python/fate/arch/federation/standalone/__init__.py rename to python/fate/arch/federation/backends/standalone/__init__.py diff --git a/python/fate/arch/federation/standalone/_federation.py b/python/fate/arch/federation/backends/standalone/_federation.py similarity index 93% rename from python/fate/arch/federation/standalone/_federation.py rename to python/fate/arch/federation/backends/standalone/_federation.py index 9919f699e3..456c4305f0 100644 --- a/python/fate/arch/federation/standalone/_federation.py +++ b/python/fate/arch/federation/backends/standalone/_federation.py @@ -16,10 +16,9 @@ from typing import List from fate.arch.abc import PartyMeta - -from ... import _standalone as standalone -from ...computing.standalone import Table, CSession -from ..federation import Federation +from fate.arch.computing.standalone import Table, CSession +from fate.arch.federation.api import Federation +from .... import _standalone as standalone LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/message_queue/__init__.py b/python/fate/arch/federation/message_queue/__init__.py new file mode 100644 index 0000000000..5150a759dc --- /dev/null +++ b/python/fate/arch/federation/message_queue/__init__.py @@ -0,0 +1,4 @@ + +from ._federation import MessageQueueBasedFederation +from ._nretry import nretry +from ._parties import Party \ No newline at end of file diff --git a/python/fate/arch/federation/_datastream.py b/python/fate/arch/federation/message_queue/_datastream.py similarity index 100% rename from python/fate/arch/federation/_datastream.py rename to python/fate/arch/federation/message_queue/_datastream.py diff --git a/python/fate/arch/federation/_federation.py b/python/fate/arch/federation/message_queue/_federation.py similarity index 86% rename from python/fate/arch/federation/_federation.py rename to python/fate/arch/federation/message_queue/_federation.py index 024d8f9e12..8056359380 100644 --- a/python/fate/arch/federation/_federation.py +++ b/python/fate/arch/federation/message_queue/_federation.py @@ -24,11 +24,9 @@ from typing import List from fate.arch.abc import CTableABC, PartyMeta - -from ..federation import FederationDataType -from ..federation._datastream import Datastream +from fate.arch.federation.api import Federation, FederationDataType +from ._datastream import Datastream from ._parties import Party -from .federation import Federation LOGGER = logging.getLogger(__name__) @@ -48,7 +46,7 @@ def _get_splits(obj, max_message_size): return kv, num_slice -class FederationBase(Federation): +class MessageQueueBasedFederation(Federation): def __init__( self, session_id, @@ -119,6 +117,9 @@ def _pull_table(self, name: str, tag: str, parties: 
typing.List[PartyMeta]) -> t rtn = [] dtype = rtn_dtype.get("dtype", None) partitions = rtn_dtype.get("partitions", None) + partitioner_type = rtn_dtype.get("partitioner_type", None) + key_serdes_type = rtn_dtype.get("key_serdes_type", None) + value_serdes_type = rtn_dtype.get("value_serdes_type", None) if dtype == FederationDataType.TABLE: party_topic_infos = self._get_party_topic_infos(_parties, name, partitions=partitions) @@ -138,81 +139,23 @@ def _pull_table(self, name: str, tag: str, parties: typing.List[PartyMeta]) -> t mq=self._mq, conf=self._conf, ) - - table = self.computing_session.parallelize(range(partitions), include_key=False, partition=partitions) - table = table.mapPartitionsWithIndex(receive_func) + table = self.computing_session.parallelize( + range(partitions), + include_key=False, + partition=partitions, + partitioner_type=partitioner_type, + key_serdes_type=key_serdes_type, + value_serdes_type=value_serdes_type, + ) + table = table.mapPartitionsWithIndex( + receive_func, + output_key_serdes_type=key_serdes_type, + output_value_serdes_type=value_serdes_type, + output_partitioner_type=partitioner_type, + ) rtn.append(table) return rtn - def pull(self, name: str, tag: str, parties: typing.List[PartyMeta]) -> typing.List: - # wrap as party - _parties = [Party(role=p[0], party_id=p[1]) for p in parties] - log_str = f"[federation.get](name={name}, tag={tag}, parties={parties})" - LOGGER.debug(f"[{log_str}]start to get") - - _name_dtype_keys = [_SPLIT_.join([party.role, party.party_id, name, tag, "get"]) for party in _parties] - - if _name_dtype_keys[0] not in self._name_dtype_map: - party_topic_infos = self._get_party_topic_infos(_parties, dtype=NAME_DTYPE_TAG) - channel_infos = self._get_channels(party_topic_infos=party_topic_infos) - rtn_dtype = [] - for i, info in enumerate(channel_infos): - obj = self._receive_obj(info, name, tag=_SPLIT_.join([tag, NAME_DTYPE_TAG])) - rtn_dtype.append(obj) - LOGGER.debug(f"[federation.get] _name_dtype_keys: {_name_dtype_keys}, dtype: {obj}") - - for k in _name_dtype_keys: - if k not in self._name_dtype_map: - self._name_dtype_map[k] = rtn_dtype[0] - - rtn_dtype = self._name_dtype_map[_name_dtype_keys[0]] - - rtn = [] - dtype = rtn_dtype.get("dtype", None) - partitions = rtn_dtype.get("partitions", None) - - if dtype == FederationDataType.TABLE or dtype == FederationDataType.SPLIT_OBJECT: - party_topic_infos = self._get_party_topic_infos(_parties, name, partitions=partitions) - for i in range(len(party_topic_infos)): - party = _parties[i] - role = party.role - party_id = party.party_id - topic_infos = party_topic_infos[i] - receive_func = self._get_partition_receive_func( - name=name, - tag=tag, - src_party_id=self.local_party[1], - src_role=self.local_party[0], - dst_party_id=party_id, - dst_role=role, - topic_infos=topic_infos, - mq=self._mq, - conf=self._conf, - ) - - table = self.computing_session.parallelize(range(partitions), partitions, include_key=False) - table = table.mapPartitionsWithIndex(receive_func) - - # add gc - self.get_gc.register_clean_action(name, tag, table, "__del__", {}) - - LOGGER.debug(f"[{log_str}]received table({i + 1}/{len(parties)}), party: {parties[i]} ") - if dtype == FederationDataType.TABLE: - rtn.append(table) - else: - obj_bytes = b"".join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0]))) - obj = p_loads(obj_bytes) - rtn.append(obj) - else: - party_topic_infos = self._get_party_topic_infos(_parties, name) - channel_infos = self._get_channels(party_topic_infos=party_topic_infos) - for i, 
info in enumerate(channel_infos): - obj = self._receive_obj(info, name, tag) - LOGGER.debug(f"[{log_str}]received obj({i + 1}/{len(parties)}), party: {parties[i]} ") - rtn.append(obj) - - LOGGER.debug(f"[{log_str}]finish to get") - return rtn def _push_bytes( self, @@ -260,7 +203,13 @@ def _push_table(self, table, name: str, tag: str, parties: typing.List[PartyMeta if _name_dtype_keys[0] not in self._name_dtype_map: party_topic_infos = self._get_party_topic_infos(_parties, dtype=NAME_DTYPE_TAG) channel_infos = self._get_channels(party_topic_infos=party_topic_infos) - body = {"dtype": FederationDataType.TABLE, "partitions": table.num_partitions} + body = { + "dtype": FederationDataType.TABLE, + "partitions": table.num_partitions, + "partitioner_type": table.partitioner_type, + "key_serdes_type": table.key_serdes_type, + "value_serdes_type": table.value_serdes_type, + } self._send_obj( name=name, tag=_SPLIT_.join([tag, NAME_DTYPE_TAG]), @@ -685,7 +634,6 @@ def _partition_receive( mq, conf: dict, ): - topic_pair = topic_infos[index][1] channel_info = self._get_channel( topic_pair=topic_pair, diff --git a/python/fate/arch/federation/_nretry.py b/python/fate/arch/federation/message_queue/_nretry.py similarity index 100% rename from python/fate/arch/federation/_nretry.py rename to python/fate/arch/federation/message_queue/_nretry.py diff --git a/python/fate/arch/federation/_parties.py b/python/fate/arch/federation/message_queue/_parties.py similarity index 100% rename from python/fate/arch/federation/_parties.py rename to python/fate/arch/federation/message_queue/_parties.py diff --git a/python/fate/arch/utils/trace.py b/python/fate/arch/utils/trace.py index 149483ce11..703d01eac9 100644 --- a/python/fate/arch/utils/trace.py +++ b/python/fate/arch/utils/trace.py @@ -8,7 +8,7 @@ from opentelemetry import trace, context if typing.TYPE_CHECKING: - from fate.arch.federation.federation import PartyMeta + from fate.arch.abc import PartyMeta from fate.arch.computing.table import KVTable logger = logging.getLogger(__name__) From a1d55ba3a1bc573a783f68ba58a194182336bc0e Mon Sep 17 00:00:00 2001 From: sagewe Date: Sat, 9 Dec 2023 19:03:35 +0800 Subject: [PATCH 23/42] refactor computing code directory Signed-off-by: sagewe --- python/fate/arch/abc/__init__.py | 2 - python/fate/arch/abc/_party.py | 17 - python/fate/arch/abc/_table.py | 633 ------------------ python/fate/arch/computing/__init__.py | 14 +- python/fate/arch/computing/_builder.py | 81 +++ python/fate/arch/computing/api/__init__.py | 4 + .../fate/arch/computing/{ => api}/_profile.py | 2 +- .../computing/{table.py => api/_table.py} | 10 +- python/fate/arch/computing/{ => api}/_type.py | 0 python/fate/arch/computing/api/_uuid.py | 10 + .../fate/arch/computing/backends/__init__.py | 0 .../{ => backends}/eggroll/__init__.py | 0 .../{ => backends}/eggroll/_csession.py | 6 +- .../{ => backends}/eggroll/_table.py | 5 +- .../computing/{ => backends}/eggroll/_type.py | 0 .../{ => backends}/spark/__init__.py | 0 .../{ => backends}/spark/_csession.py | 24 +- .../{ => backends}/spark/_materialize.py | 0 .../computing/{ => backends}/spark/_table.py | 19 +- .../{ => backends}/standalone/__init__.py | 0 .../{ => backends}/standalone/_csession.py | 8 +- .../{ => backends}/standalone/_table.py | 10 +- .../{ => backends}/standalone/_type.py | 0 .../arch/computing/partitioners/__init__.py | 30 + .../partitioners/_integer_partitioner.py | 2 + .../_java_string_like_partitioner.py | 12 + .../partitioners/_mmh3_partitioner.py | 5 + python/fate/arch/context/_context.py 
| 2 +- python/fate/arch/context/_federation.py | 4 +- python/fate/arch/dataframe/ops/_arithmetic.py | 2 +- .../arch/dataframe/ops/utils/operators.py | 2 +- python/fate/arch/federation/_builder.py | 4 +- python/fate/arch/federation/api/__init__.py | 2 +- .../fate/arch/federation/api/_federation.py | 4 +- python/fate/arch/federation/api/_type.py | 4 + .../backends/eggroll/_federation.py | 7 +- .../federation/backends/osx/_federation.py | 2 +- .../federation/backends/pulsar/_federation.py | 2 +- .../backends/rabbitmq/_federation.py | 4 +- .../backends/standalone/_federation.py | 4 +- .../federation/message_queue/_federation.py | 11 +- .../arch/federation/message_queue/_parties.py | 2 +- .../arch/histogram/_histogram_distributed.py | 4 +- python/fate/arch/launchers/context_helper.py | 6 +- .../arch/launchers/multiprocess_launcher.py | 2 +- .../fate/arch/tensor/distributed/_tensor.py | 6 +- python/fate/arch/unify/__init__.py | 3 +- python/fate/arch/unify/_uuid.py | 6 - python/fate/arch/unify/partitioner.py | 45 -- python/fate/arch/utils/trace.py | 4 +- .../fate/components/core/_load_computing.py | 19 +- .../entrypoint/cli/component/execute_cli.py | 2 +- 52 files changed, 240 insertions(+), 807 deletions(-) delete mode 100644 python/fate/arch/abc/__init__.py delete mode 100644 python/fate/arch/abc/_party.py delete mode 100644 python/fate/arch/abc/_table.py create mode 100644 python/fate/arch/computing/_builder.py create mode 100644 python/fate/arch/computing/api/__init__.py rename python/fate/arch/computing/{ => api}/_profile.py (99%) rename python/fate/arch/computing/{table.py => api/_table.py} (99%) rename python/fate/arch/computing/{ => api}/_type.py (100%) create mode 100644 python/fate/arch/computing/api/_uuid.py create mode 100644 python/fate/arch/computing/backends/__init__.py rename python/fate/arch/computing/{ => backends}/eggroll/__init__.py (100%) rename python/fate/arch/computing/{ => backends}/eggroll/_csession.py (96%) rename python/fate/arch/computing/{ => backends}/eggroll/_table.py (98%) rename python/fate/arch/computing/{ => backends}/eggroll/_type.py (100%) rename python/fate/arch/computing/{ => backends}/spark/__init__.py (100%) rename python/fate/arch/computing/{ => backends}/spark/_csession.py (84%) rename python/fate/arch/computing/{ => backends}/spark/_materialize.py (100%) rename python/fate/arch/computing/{ => backends}/spark/_table.py (96%) rename python/fate/arch/computing/{ => backends}/standalone/__init__.py (100%) rename python/fate/arch/computing/{ => backends}/standalone/_csession.py (96%) rename python/fate/arch/computing/{ => backends}/standalone/_table.py (95%) rename python/fate/arch/computing/{ => backends}/standalone/_type.py (100%) create mode 100644 python/fate/arch/computing/partitioners/__init__.py create mode 100644 python/fate/arch/computing/partitioners/_integer_partitioner.py create mode 100644 python/fate/arch/computing/partitioners/_java_string_like_partitioner.py create mode 100644 python/fate/arch/computing/partitioners/_mmh3_partitioner.py delete mode 100644 python/fate/arch/unify/partitioner.py diff --git a/python/fate/arch/abc/__init__.py b/python/fate/arch/abc/__init__.py deleted file mode 100644 index 5838f25ab6..0000000000 --- a/python/fate/arch/abc/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ._party import PartyMeta -from ._table import CSessionABC, CTableABC diff --git a/python/fate/arch/abc/_party.py b/python/fate/arch/abc/_party.py deleted file mode 100644 index 02ab615046..0000000000 --- a/python/fate/arch/abc/_party.py +++ 
/dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Literal, Tuple - -PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str] diff --git a/python/fate/arch/abc/_table.py b/python/fate/arch/abc/_table.py deleted file mode 100644 index dee3d993cd..0000000000 --- a/python/fate/arch/abc/_table.py +++ /dev/null @@ -1,633 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -distributed computing -""" - -import abc -import typing -from abc import ABCMeta -from typing import Iterable - -__all__ = ["CTableABC", "CSessionABC"] - -K = typing.TypeVar("K") -V = typing.TypeVar("V") -K_OUT = typing.TypeVar("K_OUT") -V_OUT = typing.TypeVar("V_OUT") - - -# noinspection PyPep8Naming -class CTableABC(typing.Generic[K, V], metaclass=ABCMeta): - """ - a table of pair-like data supporting distributed processing - """ - - @property - @abc.abstractmethod - def engine(self): - """ - get the engine name of table - - Returns - ------- - str - engine name of table - """ - ... - - @property - @abc.abstractmethod - def partitions(self) -> int: - """ - get the partitions of table - - Returns - ------- - int - number of partitions - """ - ... - - @abc.abstractmethod - def copy(self): - ... - - @abc.abstractmethod - def save(self, uri, schema: dict, options: dict = None): - """ - save table - - Parameters - ---------- - uri: URI - uri to save table to - schema: dict - table schema - options: dict - options for saving - """ - ... - - @abc.abstractmethod - def collect(self, **kwargs) -> typing.Generator: - """ - collect data from table - - Returns - ------- - generator - generator of data - - Notes - ------ - no order guarantee - """ - ... - - @abc.abstractmethod - def take(self, n=1, **kwargs) -> typing.List[V]: - """ - take ``n`` data from table - - Parameters - ---------- - n: int - number of data to take - - Returns - ------- - list - a list of ``n`` data - - Notes - ------ - no order guarantee - """ - ... - - @abc.abstractmethod - def first(self, **kwargs) -> V: - """ - take one data from table - - Returns - ------- - object - a data from table - - - Notes - ------- - no order guarantee - """ - ... - - @abc.abstractmethod - def count(self) -> int: - """ - number of data in table - - Returns - ------- - int - number of data - """ - ...
- - @abc.abstractmethod - def map(self, func: typing.Callable[[K, V], typing.Tuple[K_OUT, V_OUT]]) -> "CTableABC[K_OUT, V_OUT]": - """ - apply `func` to each data - - Parameters - ---------- - func: ``typing.Callable[[object, object], typing.Tuple[object, object]]`` - function mapping (k1, v1) to (k2, v2) - - Returns - ------- - CTableABC - A new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([('k1', 1), ('k2', 2), ('k3', 3)], include_key=True, partition=2) - >>> b = a.map(lambda k, v: (k, v**2)) - >>> list(b.collect()) - [("k1", 1), ("k2", 4), ("k3", 9)] - """ - ... - - @abc.abstractmethod - def mapValues(self, func: typing.Callable[[V], V_OUT]) -> "CTableABC[K, V_OUT]": - """ - apply `func` to each value of data - - Parameters - ---------- - func: ``typing.Callable[[object], object]`` - map v1 to v2 - - Returns - ------- - CTableABC - A new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([('a', ['apple', 'banana', 'lemon']), ('b', ['grapes'])], include_key=True, partition=2) - >>> b = a.mapValues(lambda x: len(x)) - >>> list(b.collect()) - [('a', 3), ('b', 1)] - """ - ... - - @abc.abstractmethod - def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False): - """ - apply ``func`` to each partition of table - - Parameters - ---------- - func: ``typing.Callable[[iter], list]`` - accepts an iterator of pairs, returns a list of pairs - use_previous_behavior: bool - this parameter is provided for compatibility reasons; if set True, calling this func will call ``applyPartitions`` instead - preserves_partitioning: bool - flag indicating whether ``func`` preserves partitioning - - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([1, 2, 3, 4, 5], include_key=False, partition=2) - >>> def f(iterator): - ... s = 0 - ... for k, v in iterator: - ... s += v - ... return [(s, s)] - ... - >>> b = a.mapPartitions(f) - >>> list(b.collect()) - [(6, 6), (9, 9)] - """ - ... - - @abc.abstractmethod - def mapReducePartitions(self, mapper, reducer, **kwargs): - """ - apply ``mapper`` to each partition of table and then perform reduce by key operation with `reducer` - - Parameters - ---------- - mapper: ``typing.Callable[[iter], list]`` - accepts an iterator of pairs, returns a list of pairs - reducer: ``typing.Callable[[object, object], object]`` - reduce v1, v2 to v3 - - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> table = computing_session.parallelize([(1, 2), (2, 3), (3, 4), (4, 5)], include_key=True, partition=2) - >>> def _mapper(it): - ... r = [] - ... for k, v in it: - ... r.append((k % 3, v**2)) - ... r.append((k % 2, v ** 3)) - ... return r - >>> def _reducer(a, b): - ... return a + b - >>> output = table.mapReducePartitions(_mapper, _reducer) - >>> collected = dict(output.collect()) - >>> assert collected[0] == 3 ** 3 + 5 ** 3 + 4 ** 2 - >>> assert collected[1] == 2 ** 3 + 4 ** 3 + 2 ** 2 + 5 ** 2 - >>> assert collected[2] == 3 ** 2 - """ - - ...
- - def applyPartitions(self, func): - """ - apply ``func`` to each partition as a single object - - Parameters - ---------- - func: ``typing.Callable[[iter], object]`` - accept an iterator, return an object - - Returns - ------- - CTableABC - a new table, with each partition containing a single key-value pair - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([1, 2, 3], partition=3, include_key=False) - >>> def f(it): - ... r = [] - ... for k, v in it: - ... r.append((v, v**2, v**3)) - ... return r - >>> output = a.applyPartitions(f) - >>> assert (2, 2**2, 2**3) in [v[0] for _, v in output.collect()] - """ - ... - - @abc.abstractmethod - def mapPartitionsWithIndex(self, func, preserves_partitioning=False): - ... - - @abc.abstractmethod - def flatMap(self, func): - """ - apply a flat ``func`` to each data of table - - Parameters - ---------- - func: ``typing.Callable[[object, object], typing.List[object, object]]`` - a flat function accepting two parameters and returning a list of pairs - - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([(1, 1), (2, 2)], include_key=True, partition=2) - >>> b = a.flatMap(lambda x, y: [(x, y), (x + 10, y ** 2)]) - >>> c = list(b.collect()) - >>> assert len(c) == 4 - >>> assert ((1, 1) in c) and ((2, 2) in c) and ((11, 1) in c) and ((12, 4) in c) - """ - ... - - @abc.abstractmethod - def reduce(self, func: typing.Callable[[V, V], V]) -> V: - """ - reduces all values in the table by a binary function `func` - - Parameters - ---------- - func: typing.Callable[[object, object], object] - binary function reducing two values into one - - Returns - ------- - object - a single object - - - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize(range(100), include_key=False, partition=4) - >>> assert a.reduce(lambda x, y: x + y) == sum(range(100)) - - Notes - ------ - `func` should be associative - """ - ... - - @abc.abstractmethod - def glom(self): - """ - coalesces all data within each partition into a list - - Returns - ------- - list - list containing each coalesced partition and its elements. - First element of each tuple is chosen from key of last element of each partition. - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize(range(5), include_key=False, partition=3).glom().collect() - >>> list(a) - [(2, [(2, 2)]), (3, [(0, 0), (3, 3)]), (4, [(1, 1), (4, 4)])] - """ - ... - - @abc.abstractmethod - def sample( - self, - *, - fraction: typing.Optional[float] = None, - num: typing.Optional[int] = None, - seed=None, - ): - """ - return a sampled subset of this Table. - Parameters - ---------- - fraction: float - Expected size of the sample as a fraction of this table's size - without replacement: probability that each element is chosen. - Fraction must be [0, 1] with replacement: expected number of times each element is chosen. - num: int - Exact number of the sample from this table's size - seed: int - Seed of the random number generator. Use current timestamp when `None` is passed.
- - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> x = computing_session.parallelize(range(100), include_key=False, partition=4) - >>> 6 <= x.sample(fraction=0.1, seed=81).count() <= 14 - True - - Notes - ------- - use one of ``fraction`` and ``num``, not both - - """ - ... - - @abc.abstractmethod - def filter(self, func): - """ - returns a new table containing only those keys which satisfy a predicate passed in via ``func``. - - Parameters - ---------- - func: typing.Callable[[object, object], bool] - Predicate function returning a boolean. - - Returns - ------- - CTableABC - A new table containing results. - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([0, 1, 2], include_key=False, partition=2) - >>> b = a.filter(lambda k, v : k % 2 == 0) - >>> list(b.collect()) - [(0, 0), (2, 2)] - >>> c = a.filter(lambda k, v : v % 2 != 0) - >>> list(c.collect()) - [(1, 1)] - """ - ... - - @abc.abstractmethod - def join(self, other, func: typing.Callable[[typing.Any, typing.Any], typing.Any]) -> "CTableABC": - """ - returns intersection of this table and the other table. - - function ``func`` will be applied to values of keys that exist in both table. - - Parameters - ---------- - other: CTableABC - another table to be operated with. - func: ``typing.Callable[[object, object], object]`` - the function applying to values whose key exists in both tables. - default using left table's value. - - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([1, 2, 3], include_key=False, partition=2) # [(0, 1), (1, 2), (2, 3)] - >>> b = computing_session.parallelize([(1, 1), (2, 2), (3, 3)], include_key=True, partition=2) - >>> c = a.join(b, lambda v1, v2 : v1 + v2) - >>> list(c.collect()) - [(1, 3), (2, 5)] - """ - ... - - @abc.abstractmethod - def union(self, other, func=lambda v1, v2: v1): - """ - returns union of this table and the other table. - - function ``func`` will be applied to values of keys that exist in both table. - - Parameters - ---------- - other: CTableABC - another table to be operated with. - func: ``typing.Callable[[object, object], object]`` - The function applying to values whose key exists in both tables. - default using left table's value. - - Returns - ------- - CTableABC - a new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize([1, 2, 3], include_key=False, partition=2) # [(0, 1), (1, 2), (2, 3)] - >>> b = computing_session.parallelize([(1, 1), (2, 2), (3, 3)], include_key=True, partition=2) - >>> c = a.union(b, lambda v1, v2 : v1 + v2) - >>> list(c.collect()) - [(0, 1), (1, 3), (2, 5), (3, 3)] - """ - ... - - @abc.abstractmethod - def subtractByKey(self, other): - """ - returns a new table containing elements only in this table but not in the other table. - - Parameters - ---------- - other: CTableABC - Another table to be subtractbykey with. - - Returns - ------- - CTableABC - A new table - - Examples - -------- - >>> from fate.arch.session import computing_session - >>> a = computing_session.parallelize(range(10), include_key=False, partition=2) - >>> b = computing_session.parallelize(range(5), include_key=False, partition=2) - >>> c = a.subtractByKey(b) - >>> list(c.collect()) - [(5, 5), (6, 6), (7, 7), (8, 8), (9, 9)] - """ - ... 
- - @property - def schema(self): - if not hasattr(self, "_schema"): - setattr(self, "_schema", {}) - return getattr(self, "_schema") - - @schema.setter - def schema(self, value): - setattr(self, "_schema", value) - - -class CSessionABC(metaclass=ABCMeta): - """ - computing session to load/create/clean tables - """ - - @abc.abstractmethod - def load(self, uri, schema: dict, options: dict = None) -> CTableABC: - """ - load a table from given uri - - Parameters - ---------- - uri: URI - uri to load table from - schema: dict - schema associated with this table - options: dict - options associated with this table load - - Returns - ------- - CTableABC - a table in memory - - """ - ... - - @abc.abstractmethod - def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs) -> CTableABC: - """ - create table from iterable data - - Parameters - ---------- - data: Iterable - data to create table from - partition: int - number of partitions of created table - include_key: bool - ``True`` for create table directly from data, ``False`` for create table with generated keys starting from 0 - - Returns - ------- - CTableABC - a table created from data - - """ - - @abc.abstractmethod - def cleanup(self, name, namespace): - """ - delete table(s) - - Parameters - ---------- - name: str - table name or wildcard character - namespace: str - namespace - """ - - @abc.abstractmethod - def destroy(self): - pass - - @abc.abstractmethod - def stop(self): - pass - - @abc.abstractmethod - def kill(self): - pass - - @property - @abc.abstractmethod - def session_id(self) -> str: - """ - get computing session id - - Returns - ------- - str - computing session id - """ - ... diff --git a/python/fate/arch/computing/__init__.py b/python/fate/arch/computing/__init__.py index 5d00ca5d87..4424dae7e2 100644 --- a/python/fate/arch/computing/__init__.py +++ b/python/fate/arch/computing/__init__.py @@ -13,16 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from ._profile import enable_profile_remote, profile_ends, profile_start -from ._type import ComputingEngine - - -def is_table(v): - from fate.arch.abc import CTableABC - from fate.arch.computing.table import KVTable - - return isinstance(v, CTableABC) or isinstance(v, KVTable) - - -__all__ = ["is_table", "ComputingEngine", "profile_start", "profile_ends"] +from ._builder import ComputingBuilder diff --git a/python/fate/arch/computing/_builder.py b/python/fate/arch/computing/_builder.py new file mode 100644 index 0000000000..b6ad567141 --- /dev/null +++ b/python/fate/arch/computing/_builder.py @@ -0,0 +1,81 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
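For orientation: the `ComputingBuilder` added in this new file (its body follows) replaces the scattered per-engine `CSession` constructions. A minimal usage sketch, assuming the standalone backend is importable; the data dir is a placeholder, and `build()` would instead pull these values from a conf dict via `cfg.get_option`:

    from fate.arch.computing import ComputingBuilder
    from fate.arch.computing.api import generate_computing_uuid

    # derive a session id, then build a standalone session directly
    session_id = generate_computing_uuid("demo")  # "demo_computing_<hex>"
    builder = ComputingBuilder(computing_session_id=session_id)
    session = builder.build_standalone(data_dir="/tmp/fate_data", options=None)  # placeholder dir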
+import typing + +from fate.arch.computing.api import ComputingEngine +from fate.arch.config import cfg + + +class ComputingBuilder: + def __init__( + self, + computing_session_id: str, + ): + self._computing_session_id = computing_session_id + + def build(self, t: ComputingEngine, conf: dict): + if t == ComputingEngine.STANDALONE: + data_dir = cfg.get_option(conf, "computing.standalone.data_dir") + options = cfg.get_option(conf, "computing.standalone.options") + return self.build_standalone(data_dir=data_dir, options=options) + elif t == ComputingEngine.EGGROLL: + host = cfg.get_option(conf, "computing.eggroll.host") + port = cfg.get_option(conf, "computing.eggroll.port") + options = cfg.get_option(conf, "computing.eggroll.options") + config = cfg.get_option(conf, "computing.eggroll.config") + config_options = cfg.get_option(conf, "computing.eggroll.config_options") + config_properties_file = cfg.get_option(conf, "computing.eggroll.config_properties_file") + return self.build_eggroll( + host=host, + port=port, + options=options, + config=config, + config_options=config_options, + config_properties_file=config_properties_file, + ) + elif t == ComputingEngine.SPARK: + return self.build_spark() + else: + raise ValueError(f"computing engine={t} not supported") + + def build_standalone(self, data_dir: typing.Optional[str], options=None, logger_config=None): + from fate.arch.computing.backends.standalone import CSession + + return CSession( + session_id=self._computing_session_id, + data_dir=data_dir, + logger_config=logger_config, + options=options, + ) + + def build_eggroll( + self, host: str, port: int, options: dict, config=None, config_options=None, config_properties_file=None + ): + from fate.arch.computing.backends.eggroll import CSession + + return CSession( + session_id=self._computing_session_id, + host=host, + port=port, + options=options, + config=config, + config_options=config_options, + config_properties_file=config_properties_file, + ) + + def build_spark(self): + from fate.arch.computing.backends.spark import CSession + + return CSession(self._computing_session_id) diff --git a/python/fate/arch/computing/api/__init__.py b/python/fate/arch/computing/api/__init__.py new file mode 100644 index 0000000000..ed8c4d935e --- /dev/null +++ b/python/fate/arch/computing/api/__init__.py @@ -0,0 +1,4 @@ +from ._profile import enable_profile_remote, profile_ends, profile_start +from ._table import KVTable, KVTableContext, K, V, is_table +from ._type import ComputingEngine +from ._uuid import generate_computing_uuid diff --git a/python/fate/arch/computing/_profile.py b/python/fate/arch/computing/api/_profile.py similarity index 99% rename from python/fate/arch/computing/_profile.py rename to python/fate/arch/computing/api/_profile.py index 3cafe07df5..bb8ac6b584 100644 --- a/python/fate/arch/computing/_profile.py +++ b/python/fate/arch/computing/api/_profile.py @@ -316,7 +316,7 @@ def profile_ends(): def _pretty_table_str(v): - from ..computing import is_table + from ._table import is_table if is_table(v): return f"Table(partition={v.num_partitions})" diff --git a/python/fate/arch/computing/table.py b/python/fate/arch/computing/api/_table.py similarity index 99% rename from python/fate/arch/computing/table.py rename to python/fate/arch/computing/api/_table.py index 36e8f68628..cf669e4f62 100644 --- a/python/fate/arch/computing/table.py +++ b/python/fate/arch/computing/api/_table.py @@ -1,14 +1,12 @@ import abc import logging import random -import traceback from typing import Any, Callable,
Tuple, Iterable, Generic, TypeVar, Optional -from fate.arch.unify.partitioner import get_partitioner_by_type +from fate.arch.computing.partitioners import get_partitioner_by_type from fate.arch.computing.serdes import get_serdes_by_type +from fate.arch.unify import URI from fate.arch.utils.trace import auto_trace -from ..unify import URI -import functools from ._profile import computing_profile as _compute_info logger = logging.getLogger(__name__) @@ -897,3 +895,7 @@ def _subtract_by_key(iter1, iter2): if item1 is not None: yield item1 yield from iter1 + + +def is_table(v): + return isinstance(v, KVTable) diff --git a/python/fate/arch/computing/_type.py b/python/fate/arch/computing/api/_type.py similarity index 100% rename from python/fate/arch/computing/_type.py rename to python/fate/arch/computing/api/_type.py diff --git a/python/fate/arch/computing/api/_uuid.py b/python/fate/arch/computing/api/_uuid.py new file mode 100644 index 0000000000..8cbc18536c --- /dev/null +++ b/python/fate/arch/computing/api/_uuid.py @@ -0,0 +1,10 @@ +from typing import Optional + +from fate.arch.unify import uuid + + +def generate_computing_uuid(session_id: Optional[str] = None): + if session_id is None: + return f"computing_{uuid()}" + else: + return f"{session_id}_computing_{uuid()}" diff --git a/python/fate/arch/computing/backends/__init__.py b/python/fate/arch/computing/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/fate/arch/computing/eggroll/__init__.py b/python/fate/arch/computing/backends/eggroll/__init__.py similarity index 100% rename from python/fate/arch/computing/eggroll/__init__.py rename to python/fate/arch/computing/backends/eggroll/__init__.py diff --git a/python/fate/arch/computing/eggroll/_csession.py b/python/fate/arch/computing/backends/eggroll/_csession.py similarity index 96% rename from python/fate/arch/computing/eggroll/_csession.py rename to python/fate/arch/computing/backends/eggroll/_csession.py index 0e28b39f50..ddc0bb56fd 100644 --- a/python/fate/arch/computing/eggroll/_csession.py +++ b/python/fate/arch/computing/backends/eggroll/_csession.py @@ -16,10 +16,9 @@ import logging -from fate.arch.computing.table import KVTableContext +from fate.arch.computing.api import KVTableContext +from fate.arch.unify import URI, uuid from ._table import Table -from .._profile import computing_profile -from ...unify import URI, uuid try: from eggroll.session import session_init @@ -62,7 +61,6 @@ def get_rpc(self): def session_id(self): return self._session_id - @computing_profile def _load(self, uri: URI, schema: dict, options: dict) -> Table: from ._type import EggRollStoreType diff --git a/python/fate/arch/computing/eggroll/_table.py b/python/fate/arch/computing/backends/eggroll/_table.py similarity index 98% rename from python/fate/arch/computing/eggroll/_table.py rename to python/fate/arch/computing/backends/eggroll/_table.py index 25ef8fd1e7..362cba286d 100644 --- a/python/fate/arch/computing/eggroll/_table.py +++ b/python/fate/arch/computing/backends/eggroll/_table.py @@ -18,10 +18,9 @@ import logging from typing import Callable, Iterable, Any -from ...unify import URI -from .._type import ComputingEngine -from ..table import KVTable from eggroll.computing import RollPair +from fate.arch.computing.api import ComputingEngine, KVTable +from fate.arch.unify import URI LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/computing/eggroll/_type.py b/python/fate/arch/computing/backends/eggroll/_type.py similarity index 100% rename from 
python/fate/arch/computing/eggroll/_type.py rename to python/fate/arch/computing/backends/eggroll/_type.py diff --git a/python/fate/arch/computing/spark/__init__.py b/python/fate/arch/computing/backends/spark/__init__.py similarity index 100% rename from python/fate/arch/computing/spark/__init__.py rename to python/fate/arch/computing/backends/spark/__init__.py diff --git a/python/fate/arch/computing/spark/_csession.py b/python/fate/arch/computing/backends/spark/_csession.py similarity index 84% rename from python/fate/arch/computing/spark/_csession.py rename to python/fate/arch/computing/backends/spark/_csession.py index 581aeacd1b..7387062a7b 100644 --- a/python/fate/arch/computing/spark/_csession.py +++ b/python/fate/arch/computing/backends/spark/_csession.py @@ -17,9 +17,8 @@ import typing from typing import Iterable -from fate.arch.abc import CSessionABC - -from ...unify import URI +from fate.arch.computing.api import KVTableContext +from fate.arch.unify import URI from ._table import from_hdfs, from_hive, from_localfs, from_rdd if typing.TYPE_CHECKING: @@ -27,7 +26,7 @@ LOGGER = logging.getLogger(__name__) -class CSession(CSessionABC): +class CSession(KVTableContext): """ manage RDDTable """ @@ -35,7 +34,7 @@ class CSession(CSessionABC): def __init__(self, session_id): self._session_id = session_id - def load(self, uri: URI, schema, options: dict = None) -> "Table": + def _load(self, uri: URI, schema, options: dict = None) -> "Table": if not options: options = {} partitions = options.get("partitions", None) @@ -80,12 +79,21 @@ def load(self, uri: URI, schema, options: dict = None) -> "Table": raise NotImplementedError(f"uri type {uri} not supported with spark backend") - def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs): + def _parallelize( + self, + data: Iterable, + total_partitions, + key_serdes, + key_serdes_type, + value_serdes, + value_serdes_type, + partitioner, + partitioner_type, + ): # noinspection PyPackageRequirements from pyspark import SparkContext - _iter = data if include_key else enumerate(data) - rdd = SparkContext.getOrCreate().parallelize(_iter, partition) + rdd = SparkContext.getOrCreate().parallelize(data, total_partitions) return from_rdd(rdd) @property diff --git a/python/fate/arch/computing/spark/_materialize.py b/python/fate/arch/computing/backends/spark/_materialize.py similarity index 100% rename from python/fate/arch/computing/spark/_materialize.py rename to python/fate/arch/computing/backends/spark/_materialize.py diff --git a/python/fate/arch/computing/spark/_table.py b/python/fate/arch/computing/backends/spark/_table.py similarity index 96% rename from python/fate/arch/computing/spark/_table.py rename to python/fate/arch/computing/backends/spark/_table.py index 8c831c6811..936f0cfdcf 100644 --- a/python/fate/arch/computing/spark/_table.py +++ b/python/fate/arch/computing/backends/spark/_table.py @@ -21,13 +21,12 @@ from itertools import chain import pyspark -from fate.arch.abc import CTableABC from pyspark.rddsampler import RDDSamplerBase from scipy.stats import hypergeom -from .._profile import computing_profile -from .._type import ComputingEngine +from fate.arch.computing.api import KVTable, ComputingEngine from ._materialize import materialize, unmaterialize +from .._profile import computing_profile LOGGER = logging.getLogger(__name__) @@ -53,12 +52,18 @@ def hive_to_row(k, v): return Row(key=k, value=pickle.dumps(v).hex()) -class Table(CTableABC): - def __init__(self, rdd): - self._rdd: pyspark.RDD = rdd +class 
Table(KVTable): + def __init__(self, rdd: pyspark.RDD, key_serdes_type, value_serdes_type, partitioner_type): + self._rdd = rdd self._engine = ComputingEngine.SPARK - self._count = None + super().__init__( + key_serdes_type=key_serdes_type, + value_serdes_type=value_serdes_type, + partitioner_type=partitioner_type, + num_partitions=rdd.getNumPartitions(), + ) + @property def engine(self): diff --git a/python/fate/arch/computing/standalone/__init__.py b/python/fate/arch/computing/backends/standalone/__init__.py similarity index 100% rename from python/fate/arch/computing/standalone/__init__.py rename to python/fate/arch/computing/backends/standalone/__init__.py diff --git a/python/fate/arch/computing/standalone/_csession.py b/python/fate/arch/computing/backends/standalone/_csession.py similarity index 96% rename from python/fate/arch/computing/standalone/_csession.py rename to python/fate/arch/computing/backends/standalone/_csession.py index de89e269e9..e170a79b4f 100644 --- a/python/fate/arch/computing/standalone/_csession.py +++ b/python/fate/arch/computing/backends/standalone/_csession.py @@ -16,12 +16,10 @@ import logging from typing import Optional -from ..table import KVTableContext - -from ..._standalone import Session -from ...unify import URI, generate_computing_uuid, uuid +from fate.arch.computing.api import KVTableContext, generate_computing_uuid +from fate.arch.unify import URI, uuid from ._table import Table -import os.path +from ...._standalone import Session logger = logging.getLogger(__name__) diff --git a/python/fate/arch/computing/standalone/_table.py b/python/fate/arch/computing/backends/standalone/_table.py similarity index 95% rename from python/fate/arch/computing/standalone/_table.py rename to python/fate/arch/computing/backends/standalone/_table.py index 437cabcb2b..5c5ea44637 100644 --- a/python/fate/arch/computing/standalone/_table.py +++ b/python/fate/arch/computing/backends/standalone/_table.py @@ -16,15 +16,14 @@ import logging from typing import Callable, Iterable, Any, Tuple -from ...unify import URI -from .._profile import computing_profile -from .._type import ComputingEngine -from ..table import KVTable, V, K -from ..._standalone import Table as StandaloneTable +from fate.arch.computing.api import ComputingEngine, KVTable, K, V +from fate.arch.unify import URI +from ...._standalone import Table as StandaloneTable LOGGER = logging.getLogger(__name__) + class Table(KVTable): def __init__(self, table: StandaloneTable): self._table = table @@ -121,7 +120,6 @@ def _count(self): def _reduce(self, func, **kwargs): return self._table.reduce(func) - @computing_profile def _save(self, uri: URI, schema, options: dict): if uri.scheme != "standalone": raise ValueError(f"uri scheme `{uri.scheme}` not supported with standalone backend") diff --git a/python/fate/arch/computing/standalone/_type.py b/python/fate/arch/computing/backends/standalone/_type.py similarity index 100% rename from python/fate/arch/computing/standalone/_type.py rename to python/fate/arch/computing/backends/standalone/_type.py diff --git a/python/fate/arch/computing/partitioners/__init__.py b/python/fate/arch/computing/partitioners/__init__.py new file mode 100644 index 0000000000..4b9c381e98 --- /dev/null +++ b/python/fate/arch/computing/partitioners/__init__.py @@ -0,0 +1,30 @@ +def partitioner(hash_func, total_partitions): + def partition(key): + return hash_func(key) % total_partitions + + return partition + + +def get_default_partitioner(): + from ._mmh3_partitioner import mmh3_partitioner + + 
return mmh3_partitioner + + +def get_partitioner_by_type(partitioner_type: int): + if partitioner_type == 0: + return get_default_partitioner() + elif partitioner_type == 1: + from ._integer_partitioner import integer_partitioner + + return integer_partitioner + elif partitioner_type == 2: + from ._mmh3_partitioner import mmh3_partitioner + + return mmh3_partitioner + elif partitioner_type == 3: + from ._java_string_like_partitioner import _java_string_like_partitioner + + return _java_string_like_partitioner + else: + raise ValueError(f"partitioner type `{partitioner_type}` not supported") diff --git a/python/fate/arch/computing/partitioners/_integer_partitioner.py b/python/fate/arch/computing/partitioners/_integer_partitioner.py new file mode 100644 index 0000000000..a66322049b --- /dev/null +++ b/python/fate/arch/computing/partitioners/_integer_partitioner.py @@ -0,0 +1,2 @@ +def integer_partitioner(key: bytes, total_partitions): + return int.from_bytes(key, "big") % total_partitions diff --git a/python/fate/arch/computing/partitioners/_java_string_like_partitioner.py b/python/fate/arch/computing/partitioners/_java_string_like_partitioner.py new file mode 100644 index 0000000000..98fc94c41c --- /dev/null +++ b/python/fate/arch/computing/partitioners/_java_string_like_partitioner.py @@ -0,0 +1,12 @@ +import hashlib + + +def _java_string_like_partitioner(key, total_partitions): + _key = hashlib.sha1(key).digest() + _key = int.from_bytes(_key, byteorder="little", signed=False) + b, j = -1, 0 + while j < total_partitions: + b = int(j) + _key = ((_key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF + j = float(b + 1) * (float(1 << 31) / float((_key >> 33) + 1)) + return int(b) diff --git a/python/fate/arch/computing/partitioners/_mmh3_partitioner.py b/python/fate/arch/computing/partitioners/_mmh3_partitioner.py new file mode 100644 index 0000000000..8a845c089b --- /dev/null +++ b/python/fate/arch/computing/partitioners/_mmh3_partitioner.py @@ -0,0 +1,5 @@ +def mmh3_partitioner(key: bytes, total_partitions): + import mmh3 + + return mmh3.hash(key) % total_partitions + diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 911616242a..db96eaca27 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -31,7 +31,7 @@ if typing.TYPE_CHECKING: from fate.arch.federation.api import Federation - from fate.arch.computing.table import KVTableContext + from fate.arch.computing.api import KVTableContext class Context: diff --git a/python/fate/arch/context/_federation.py b/python/fate/arch/context/_federation.py index 6fb89511c1..ff02cac545 100644 --- a/python/fate/arch/context/_federation.py +++ b/python/fate/arch/context/_federation.py @@ -19,9 +19,9 @@ import typing from typing import Any, List, Tuple, TypeVar, Union -from fate.arch.abc import PartyMeta +from fate.arch.federation.api import PartyMeta +from fate.arch.computing.api import is_table from ._namespace import NS -from ..computing import is_table logger = logging.getLogger(__name__) T = TypeVar("T") diff --git a/python/fate/arch/dataframe/ops/_arithmetic.py b/python/fate/arch/dataframe/ops/_arithmetic.py index 23c66e211f..82f36047f3 100644 --- a/python/fate/arch/dataframe/ops/_arithmetic.py +++ b/python/fate/arch/dataframe/ops/_arithmetic.py @@ -14,7 +14,7 @@ # limitations under the License. 
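The `_java_string_like_partitioner` above is recognizably Lamping and Veach's jump consistent hash (the 2862933555777941757 multiplier and the `(key >> 33) + 1` step) applied to a SHA-1 digest of the key. A self-contained sketch of the property that makes it attractive for repartitioning; `jump_hash` here is a local re-implementation for illustration, not a FATE API:

    import hashlib

    def jump_hash(key: int, buckets: int) -> int:
        # same loop as _java_string_like_partitioner, minus the SHA-1 prologue
        b, j = -1, 0
        while j < buckets:
            b = j
            key = ((key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF
            j = int(float(b + 1) * (float(1 << 31) / float((key >> 33) + 1)))
        return b

    # growing from n to n + 1 partitions either keeps a key in place or
    # moves it to the new partition n; keys never shuffle among old partitions
    for raw in (b"alice", b"bob", b"carol"):
        k = int.from_bytes(hashlib.sha1(raw).digest(), byteorder="little", signed=False)
        for n in range(1, 16):
            assert jump_hash(k, n + 1) in (jump_hash(k, n), n)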
import numpy as np import pandas as pd -from fate.arch.computing import is_table +from fate.arch.computing.api import is_table from .._dataframe import DataFrame from ._promote_types import promote_types from .utils.series_align import series_to_ndarray diff --git a/python/fate/arch/dataframe/ops/utils/operators.py b/python/fate/arch/dataframe/ops/utils/operators.py index 0fffd99268..627ea5c25a 100644 --- a/python/fate/arch/dataframe/ops/utils/operators.py +++ b/python/fate/arch/dataframe/ops/utils/operators.py @@ -1,5 +1,5 @@ import numpy as np -from fate.arch.computing import is_table +from fate.arch.computing.api import is_table def binary_operate(lhs, rhs, op, block_indexes, rhs_block_id=None): diff --git a/python/fate/arch/federation/_builder.py b/python/fate/arch/federation/_builder.py index ab14f0a2eb..55c8098a29 100644 --- a/python/fate/arch/federation/_builder.py +++ b/python/fate/arch/federation/_builder.py @@ -16,7 +16,7 @@ import typing from enum import Enum -from fate.arch.abc import PartyMeta +from fate.arch.federation.api import PartyMeta from fate.arch.config import cfg @@ -111,7 +111,7 @@ def build_osx( max_message_size=options.get("max_message_size"), ) else: - from fate.arch.computing.eggroll import CSession + from fate.arch.computing.backends.eggroll import CSession from fate.arch.federation.backends.eggroll import EggrollFederation if not isinstance(computing_session, CSession): diff --git a/python/fate/arch/federation/api/__init__.py b/python/fate/arch/federation/api/__init__.py index 12191de06d..bd260a8e3e 100644 --- a/python/fate/arch/federation/api/__init__.py +++ b/python/fate/arch/federation/api/__init__.py @@ -1,2 +1,2 @@ from ._federation import Federation -from ._type import FederationDataType +from ._type import FederationDataType, PartyMeta diff --git a/python/fate/arch/federation/api/_federation.py b/python/fate/arch/federation/api/_federation.py index b4f4b088bb..3355ae999c 100644 --- a/python/fate/arch/federation/api/_federation.py +++ b/python/fate/arch/federation/api/_federation.py @@ -17,7 +17,6 @@ import typing from typing import List -from fate.arch.abc import PartyMeta from fate.arch.utils.trace import ( federation_push_table_trace, federation_pull_table_trace, @@ -25,9 +24,10 @@ federation_pull_bytes_trace, ) from ._gc import GarbageCollector +from ._type import PartyMeta if typing.TYPE_CHECKING: - from fate.arch.computing.table import KVTable + from fate.arch.computing.api import KVTable logger = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/api/_type.py b/python/fate/arch/federation/api/_type.py index fec2f95f3f..2c71488100 100644 --- a/python/fate/arch/federation/api/_type.py +++ b/python/fate/arch/federation/api/_type.py @@ -14,6 +14,10 @@ # limitations under the License. 
# +from typing import Literal, Tuple + +PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str] + class FederationDataType(object): OBJECT = "obj" diff --git a/python/fate/arch/federation/backends/eggroll/_federation.py b/python/fate/arch/federation/backends/eggroll/_federation.py index 155db04de5..e5b7bce16c 100644 --- a/python/fate/arch/federation/backends/eggroll/_federation.py +++ b/python/fate/arch/federation/backends/eggroll/_federation.py @@ -23,14 +23,13 @@ from eggroll.computing import RollPair from eggroll.federation import RollSiteContext -from fate.arch.federation.api import Federation -from fate.arch.abc import PartyMeta -from fate.arch.computing.eggroll import Table +from fate.arch.computing.backends.eggroll import Table +from fate.arch.federation.api import Federation, PartyMeta logger = logging.getLogger(__name__) if typing.TYPE_CHECKING: - from fate.arch.computing.eggroll import CSession + from fate.arch.computing.backends.eggroll import CSession class EggrollFederation(Federation): diff --git a/python/fate/arch/federation/backends/osx/_federation.py b/python/fate/arch/federation/backends/osx/_federation.py index 72f270112a..f65fe2ae35 100644 --- a/python/fate/arch/federation/backends/osx/_federation.py +++ b/python/fate/arch/federation/backends/osx/_federation.py @@ -17,7 +17,7 @@ import typing from logging import getLogger -from fate.arch.abc import PartyMeta +from fate.arch.federation.api import PartyMeta from fate.arch.federation.backends.osx import osx_pb2 from fate.arch.federation.message_queue import MessageQueueBasedFederation from ._mq_channel import MQChannel diff --git a/python/fate/arch/federation/backends/pulsar/_federation.py b/python/fate/arch/federation/backends/pulsar/_federation.py index b20a1a663a..92149b65cd 100644 --- a/python/fate/arch/federation/backends/pulsar/_federation.py +++ b/python/fate/arch/federation/backends/pulsar/_federation.py @@ -17,7 +17,7 @@ import logging from typing import List, Optional -from fate.arch.abc import PartyMeta +from fate.arch.federation.api import PartyMeta from fate.arch.federation.message_queue import MessageQueueBasedFederation from ._mq_channel import ( DEFAULT_CLUSTER, diff --git a/python/fate/arch/federation/backends/rabbitmq/_federation.py b/python/fate/arch/federation/backends/rabbitmq/_federation.py index 219b31ab18..bf2074e2a9 100644 --- a/python/fate/arch/federation/backends/rabbitmq/_federation.py +++ b/python/fate/arch/federation/backends/rabbitmq/_federation.py @@ -18,7 +18,7 @@ from logging import getLogger from typing import List, Optional -from fate.arch.abc import PartyMeta +from fate.arch.federation.api import PartyMeta from fate.arch.federation.message_queue import MessageQueueBasedFederation, Party from ._mq_channel import MQChannel from ._rabbit_manager import RabbitManager @@ -138,7 +138,7 @@ def destroy(self): LOGGER.debug(f"[rabbitmq.cleanup]clean user {self._mq.union_name}.") self._rabbit_manager.delete_user(user=self._mq.union_name) - def _get_vhost(self, party): + def _get_vhost(self, party: PartyMeta): low, high = (self._party, party) if self._party < party else (party, self._party) vhost = f"{self._session_id}-{low.role}-{low.party_id}-{high.role}-{high.party_id}" return vhost diff --git a/python/fate/arch/federation/backends/standalone/_federation.py b/python/fate/arch/federation/backends/standalone/_federation.py index 456c4305f0..c529fd3d22 100644 --- a/python/fate/arch/federation/backends/standalone/_federation.py +++ 
b/python/fate/arch/federation/backends/standalone/_federation.py @@ -15,9 +15,9 @@ import logging from typing import List -from fate.arch.abc import PartyMeta -from fate.arch.computing.standalone import Table, CSession +from fate.arch.computing.backends.standalone import Table, CSession from fate.arch.federation.api import Federation +from fate.arch.federation.api import PartyMeta from .... import _standalone as standalone LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/federation/message_queue/_federation.py b/python/fate/arch/federation/message_queue/_federation.py index 8056359380..5ddb041857 100644 --- a/python/fate/arch/federation/message_queue/_federation.py +++ b/python/fate/arch/federation/message_queue/_federation.py @@ -23,8 +23,8 @@ from pickle import loads as p_loads from typing import List -from fate.arch.abc import CTableABC, PartyMeta -from fate.arch.federation.api import Federation, FederationDataType +from fate.arch.computing.api import is_table +from fate.arch.federation.api import Federation, FederationDataType, PartyMeta from ._datastream import Datastream from ._parties import Party @@ -68,7 +68,7 @@ def __init__( super().__init__(session_id, party, parties) - # temp + # TODO: remove this self._party = Party(party[0], party[1]) def _pull_bytes(self, name: str, tag: str, parties: typing.List[PartyMeta]) -> typing.List: @@ -156,7 +156,6 @@ def _pull_table(self, name: str, tag: str, parties: typing.List[PartyMeta]) -> t rtn.append(table) return rtn - def _push_bytes( self, v: bytes, @@ -246,7 +245,7 @@ def push(self, v, name: str, tag: str, parties: typing.List[PartyMeta]): party_topic_infos = self._get_party_topic_infos(_parties, dtype=NAME_DTYPE_TAG) channel_infos = self._get_channels(party_topic_infos=party_topic_infos) - if not isinstance(v, CTableABC): + if not is_table(v): v, num_slice = _get_splits(v, self._max_message_size) if num_slice > 1: v = self.computing_session.parallelize(data=v, partition=1, include_key=True) @@ -272,7 +271,7 @@ def push(self, v, name: str, tag: str, parties: typing.List[PartyMeta]): if k not in self._name_dtype_map: self._name_dtype_map[k] = body - if isinstance(v, CTableABC): + if is_table(v): total_size = v.count() partitions = v.partitions LOGGER.debug(f"[{log_str}]start to remote table, total_size={total_size}, partitions={partitions}") diff --git a/python/fate/arch/federation/message_queue/_parties.py b/python/fate/arch/federation/message_queue/_parties.py index 4cb453a9d6..affe9896c7 100644 --- a/python/fate/arch/federation/message_queue/_parties.py +++ b/python/fate/arch/federation/message_queue/_parties.py @@ -39,4 +39,4 @@ def __eq__(self, other): return self.party_id == other.party_id and self.role == other.role def as_tuple(self): - return (self.role, self.party_id) + return self.role, self.party_id diff --git a/python/fate/arch/histogram/_histogram_distributed.py b/python/fate/arch/histogram/_histogram_distributed.py index 46fccdd603..1f55403ef8 100644 --- a/python/fate/arch/histogram/_histogram_distributed.py +++ b/python/fate/arch/histogram/_histogram_distributed.py @@ -4,7 +4,7 @@ import torch -from fate.arch.abc import CTableABC +from fate.arch.computing.api import KVTable from ._histogram_local import Histogram from ._histogram_splits import HistogramSplits from .indexer import HistogramIndexer, Shuffler @@ -28,7 +28,7 @@ def _decrypt(split: HistogramSplits): class DistributedHistogram: def __init__( self, - splits: CTableABC[int, HistogramSplits], + splits: KVTable[int, HistogramSplits], k, node_size, 
node_data_size, diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index ae60b4ed1e..57544bd8d7 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -26,7 +26,7 @@ class LauncherClusterContextArgs: @dataclass class LauncherContextArguments: - context_type: str = field(default="standalone") + context_type: str = field(default="local") def init_context(computing_session_id: str, federation_session_id: str): @@ -41,7 +41,7 @@ def init_context(computing_session_id: str, federation_session_id: str): def init_local_context(computing_session_id: str, federation_session_id: str): from fate.arch.utils.paths import get_base_dir - from fate.arch.computing.standalone import CSession + from fate.arch.computing.backends.standalone import CSession from fate.arch.federation import FederationBuilder from fate.arch.context import Context @@ -67,7 +67,7 @@ def init_local_context(computing_session_id: str, federation_session_id: str): def init_cluster_context(computing_session_id: str, federation_session_id: str): from fate.arch.federation import FederationBuilder, FederationMode - from fate.arch.computing.eggroll import CSession + from fate.arch.computing.backends.eggroll import CSession from fate.arch.context import Context diff --git a/python/fate/arch/launchers/multiprocess_launcher.py b/python/fate/arch/launchers/multiprocess_launcher.py index aaac027869..b4aa98aa16 100644 --- a/python/fate/arch/launchers/multiprocess_launcher.py +++ b/python/fate/arch/launchers/multiprocess_launcher.py @@ -132,7 +132,7 @@ def _run_process( from fate.arch.utils.logger import set_up_logging from fate.arch.launchers.context_helper import init_context from fate.arch.utils.trace import setup_tracing - from fate.arch.computing._profile import profile_start, profile_ends + from fate.arch.computing.api import profile_start, profile_ends if args.rank >= len(args.parties): raise ValueError(f"rank {args.rank} is out of range {len(args.parties)}") diff --git a/python/fate/arch/tensor/distributed/_tensor.py b/python/fate/arch/tensor/distributed/_tensor.py index f9e99315ad..303e7e0978 100644 --- a/python/fate/arch/tensor/distributed/_tensor.py +++ b/python/fate/arch/tensor/distributed/_tensor.py @@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, TypeVar, cast import torch -from fate.arch.abc import CTableABC +from fate.arch.computing.api import KVTable from fate.arch.context import Context from fate.arch.tensor.phe import PHETensor from fate.arch.utils.trace import auto_trace @@ -187,7 +187,7 @@ def __repr__(self): @classmethod def from_sharding_table( cls, - data: CTableABC, + data: KVTable, shapes: Optional[List[torch.Size]], axis=0, dtype: Optional[torch.dtype] = None, @@ -289,7 +289,7 @@ def __getitem__(self, item): class Shardings: def __init__( self, - data: CTableABC[int, torch.Tensor], + data: KVTable[int, torch.Tensor], shapes: Optional[List[torch.Size]] = None, axis: int = 0, dtype: Optional[torch.dtype] = None, diff --git a/python/fate/arch/unify/__init__.py b/python/fate/arch/unify/__init__.py index 5d41ad0872..56e19cb27b 100644 --- a/python/fate/arch/unify/__init__.py +++ b/python/fate/arch/unify/__init__.py @@ -14,10 +14,9 @@ # limitations under the License. 
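The hunk in progress here trims `generate_computing_uuid` out of `fate.arch.unify`; the helper now lives in `fate.arch.computing.api` (see `api/_uuid.py` earlier in this patch) with unchanged behavior. A one-line usage sketch (hex suffix is random per run):

    from fate.arch.computing.api import generate_computing_uuid

    generate_computing_uuid("sid")  # "sid_computing_<hex>" when a session id is given
    generate_computing_uuid()       # "computing_<hex>" otherwise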
from ._infra_def import Backend, device from ._io import URI -from ._uuid import generate_computing_uuid, uuid +from ._uuid import uuid __all__ = [ - "generate_computing_uuid", "Backend", "device", "uuid", diff --git a/python/fate/arch/unify/_uuid.py b/python/fate/arch/unify/_uuid.py index 7b9c77fa98..07a1c8f9e7 100644 --- a/python/fate/arch/unify/_uuid.py +++ b/python/fate/arch/unify/_uuid.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from uuid import uuid1 @@ -20,8 +19,3 @@ def uuid(): return uuid1().hex -def generate_computing_uuid(session_id: Optional[str] = None): - if session_id is None: - return f"computing_{uuid()}" - else: - return f"{session_id}_computing_{uuid()}" diff --git a/python/fate/arch/unify/partitioner.py b/python/fate/arch/unify/partitioner.py deleted file mode 100644 index 87c8273622..0000000000 --- a/python/fate/arch/unify/partitioner.py +++ /dev/null @@ -1,45 +0,0 @@ -import hashlib - - -def partitioner(hash_func, total_partitions): - def partition(key): - return hash_func(key) % total_partitions - - return partition - - -def integer_partitioner(key: bytes, total_partitions): - return int.from_bytes(key, "big") % total_partitions - - -def mmh3_partitioner(key: bytes, total_partitions): - import mmh3 - - return mmh3.hash(key) % total_partitions - - -def _java_string_like_partitioner(key, total_partitions): - _key = hashlib.sha1(key).digest() - _key = int.from_bytes(_key, byteorder="little", signed=False) - b, j = -1, 0 - while j < total_partitions: - b = int(j) - _key = ((_key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF - j = float(b + 1) * (float(1 << 31) / float((_key >> 33) + 1)) - return int(b) - - -def get_default_partitioner(): - return mmh3_partitioner - # return _java_string_like_partitioner - - -def get_partitioner_by_type(partitioner_type: int): - if partitioner_type == 0: - return get_default_partitioner() - elif partitioner_type == 1: - return integer_partitioner - elif partitioner_type == 2: - return mmh3_partitioner - else: - raise ValueError(f"partitioner type `{partitioner_type}` not supported") diff --git a/python/fate/arch/utils/trace.py b/python/fate/arch/utils/trace.py index 703d01eac9..924d3e31e4 100644 --- a/python/fate/arch/utils/trace.py +++ b/python/fate/arch/utils/trace.py @@ -8,8 +8,8 @@ from opentelemetry import trace, context if typing.TYPE_CHECKING: - from fate.arch.abc import PartyMeta - from fate.arch.computing.table import KVTable + from fate.arch.federation.api import PartyMeta + from fate.arch.computing.api import KVTable logger = logging.getLogger(__name__) _ENABLE_TRACING = None diff --git a/python/fate/components/core/_load_computing.py b/python/fate/components/core/_load_computing.py index b9d929394a..7391ae3269 100644 --- a/python/fate/components/core/_load_computing.py +++ b/python/fate/components/core/_load_computing.py @@ -18,31 +18,26 @@ def load_computing(computing, logger_config=None): SparkComputingSpec, StandaloneComputingSpec, ) + from fate.arch.computing import ComputingBuilder - if isinstance(computing, StandaloneComputingSpec): - from fate.arch.computing.standalone import CSession + builder = ComputingBuilder(computing.metadata.computing_id) - return CSession( - session_id=computing.metadata.computing_id, + if isinstance(computing, StandaloneComputingSpec): + return builder.build_standalone( 
data_dir=computing.metadata.options.get("data_dir", None), logger_config=logger_config, options=computing.metadata.options, ) if isinstance(computing, EggrollComputingSpec): - from fate.arch.computing.eggroll import CSession - - return CSession( - computing.metadata.computing_id, + return builder.build_eggroll( host=computing.metadata.host, port=computing.metadata.port, + options=computing.metadata.options, config_options=computing.metadata.config_options, config_properties_file=computing.metadata.config_properties_file, - options=computing.metadata.options, ) if isinstance(computing, SparkComputingSpec): - from fate.arch.computing.spark import CSession - - return CSession(computing.metadata.computing_id) + return builder.build_spark() # TODO: load from plugin raise ValueError(f"conf.computing={computing} not support") diff --git a/python/fate/components/entrypoint/cli/component/execute_cli.py b/python/fate/components/entrypoint/cli/component/execute_cli.py index bbeb619977..44bf520cbf 100644 --- a/python/fate/components/entrypoint/cli/component/execute_cli.py +++ b/python/fate/components/entrypoint/cli/component/execute_cli.py @@ -88,7 +88,7 @@ def execute_component_from_config(config: "TaskConfigSpec", output_path): import traceback from fate.arch import CipherKit, Context - from fate.arch.computing import profile_ends, profile_start + from fate.arch.computing.api import profile_ends, profile_start from fate.components.core import ( ComponentExecutionIO, Role, From 3de18254c59153dfe8b81b77f2c6370e00f68d09 Mon Sep 17 00:00:00 2001 From: sagewe Date: Sat, 9 Dec 2023 19:07:23 +0800 Subject: [PATCH 24/42] fix computing and federation import for tests Signed-off-by: sagewe --- python/fate/ml/aggregator/test/test_aggregator.py | 4 ++-- python/fate/ml/aggregator/test/test_fate_utils.py | 4 ++-- python/fate/ml/ensemble/algo/secureboost/test/test_goss.py | 4 ++-- .../ensemble/algo/secureboost/test/test_hetero_sbt_binary.py | 4 ++-- .../secureboost/test/test_hetero_sbt_binary_multi_host.py | 4 ++-- .../algo/secureboost/test/test_hetero_sbt_binary_with_goss.py | 4 ++-- .../ensemble/algo/secureboost/test/test_hetero_sbt_multi.py | 4 ++-- .../algo/secureboost/test/test_hetero_sbt_regression.py | 4 ++-- .../ensemble/learner/decision_tree/test/test_decision_tree.py | 4 ++-- .../learner/decision_tree/test/test_local_decision_tree.py | 4 ++-- .../learner/decision_tree/tree_core/test/test_loss.py | 4 ++-- python/fate/ml/glm/homo/lr/test/test_fed_lr.py | 4 ++-- python/fate/ml/glm/homo/lr/test/test_local.py | 4 ++-- python/fate/ml/nn/test/test_agglayer.py | 4 ++-- python/fate/ml/nn/test/test_fedpass_alexnet.py | 2 +- python/fate/ml/nn/test/test_fedpass_lenet.py | 4 ++-- python/fate/ml/nn/test/test_fedpass_tabular.py | 4 ++-- python/fate/ml/nn/test/test_hetero_nn_algo.py | 4 ++-- python/fate/ml/nn/test/test_hetero_nn_algo_no_guest.py | 4 ++-- python/fate/ml/nn/test/test_hetero_nn_algo_val.py | 4 ++-- python/fate/ml/nn/test/test_homo_nn_binary.py | 4 ++-- python/fate/ml/utils/test/test_predict_format.py | 4 ++-- python/fate/test/test_dtensor.py | 4 ++-- python/fate/test/test_matmul.py | 4 ++-- 24 files changed, 47 insertions(+), 47 deletions(-) diff --git a/python/fate/ml/aggregator/test/test_aggregator.py b/python/fate/ml/aggregator/test/test_aggregator.py index ba74c743c6..4048f92620 100644 --- a/python/fate/ml/aggregator/test/test_aggregator.py +++ b/python/fate/ml/aggregator/test/test_aggregator.py @@ -10,8 +10,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone 
import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() logger.setLevel(logging.DEBUG) diff --git a/python/fate/ml/aggregator/test/test_fate_utils.py b/python/fate/ml/aggregator/test/test_fate_utils.py index 51ddd34311..2fc1ccade4 100644 --- a/python/fate/ml/aggregator/test/test_fate_utils.py +++ b/python/fate/ml/aggregator/test/test_fate_utils.py @@ -8,8 +8,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() logger.setLevel(logging.DEBUG) diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_goss.py b/python/fate/ml/ensemble/algo/secureboost/test/test_goss.py index 4e1172b913..4b4451180c 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_goss.py +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_goss.py @@ -15,8 +15,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py index efc311fbf9..3ca1dc525c 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary.py @@ -15,8 +15,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_multi_host.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_multi_host.py index d364c49c40..fb2c3c5bfb 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_multi_host.py +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_multi_host.py @@ -17,8 +17,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_with_goss.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_with_goss.py index b371f419ed..aef5df671d 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_with_goss.py +++ 
b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_binary_with_goss.py @@ -16,8 +16,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py index da1edb4cc6..39c13313f8 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_multi.py @@ -19,8 +19,8 @@ def get_current_datetime_str(): def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py index 1c9feec29e..5bd665d861 100644 --- a/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py +++ b/python/fate/ml/ensemble/algo/secureboost/test/test_hetero_sbt_regression.py @@ -19,8 +19,8 @@ def get_current_datetime_str(): def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git a/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py b/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py index 941a70341b..ffb75c1036 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py +++ b/python/fate/ml/ensemble/learner/decision_tree/test/test_decision_tree.py @@ -16,8 +16,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git a/python/fate/ml/ensemble/learner/decision_tree/test/test_local_decision_tree.py b/python/fate/ml/ensemble/learner/decision_tree/test/test_local_decision_tree.py index 2dbadeb393..fa8014f103 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/test/test_local_decision_tree.py +++ b/python/fate/ml/ensemble/learner/decision_tree/test/test_local_decision_tree.py @@ -14,8 +14,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git 
a/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py b/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py index ee34552976..7f1f602783 100644 --- a/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py +++ b/python/fate/ml/ensemble/learner/decision_tree/tree_core/test/test_loss.py @@ -21,8 +21,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git a/python/fate/ml/glm/homo/lr/test/test_fed_lr.py b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py index 581ecff9a1..00e03b4440 100644 --- a/python/fate/ml/glm/homo/lr/test/test_fed_lr.py +++ b/python/fate/ml/glm/homo/lr/test/test_fed_lr.py @@ -13,8 +13,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() logger.setLevel(logging.DEBUG) diff --git a/python/fate/ml/glm/homo/lr/test/test_local.py b/python/fate/ml/glm/homo/lr/test/test_local.py index 70a60d576e..c6f5fb0d3a 100644 --- a/python/fate/ml/glm/homo/lr/test/test_local.py +++ b/python/fate/ml/glm/homo/lr/test/test_local.py @@ -1,7 +1,7 @@ from fate.arch import Context -from fate.arch.computing.standalone import CSession +from fate.arch.computing.backends.standalone import CSession from fate.arch.context import Context -from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.federation.backends.standalone import StandaloneFederation import pandas as pd from fate.arch.dataframe import PandasReader from fate.ml.nn.dataset.table import TableDataset diff --git a/python/fate/ml/nn/test/test_agglayer.py b/python/fate/ml/nn/test/test_agglayer.py index 4d2ec2a4d2..5d97996839 100644 --- a/python/fate/ml/nn/test/test_agglayer.py +++ b/python/fate/ml/nn/test/test_agglayer.py @@ -14,8 +14,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_fedpass_alexnet.py b/python/fate/ml/nn/test/test_fedpass_alexnet.py index 68ac4b58ac..bf9287676e 100644 --- a/python/fate/ml/nn/test/test_fedpass_alexnet.py +++ b/python/fate/ml/nn/test/test_fedpass_alexnet.py @@ -25,7 +25,7 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging diff --git a/python/fate/ml/nn/test/test_fedpass_lenet.py b/python/fate/ml/nn/test/test_fedpass_lenet.py index 3e35924164..f11bbc8148 100644 --- a/python/fate/ml/nn/test/test_fedpass_lenet.py +++ b/python/fate/ml/nn/test/test_fedpass_lenet.py @@ -24,8 +24,8 @@ def 
get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_fedpass_tabular.py b/python/fate/ml/nn/test/test_fedpass_tabular.py index db2b3cca54..416da6e21c 100644 --- a/python/fate/ml/nn/test/test_fedpass_tabular.py +++ b/python/fate/ml/nn/test/test_fedpass_tabular.py @@ -16,8 +16,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_hetero_nn_algo.py b/python/fate/ml/nn/test/test_hetero_nn_algo.py index 1c27774f49..bcdc891278 100644 --- a/python/fate/ml/nn/test/test_hetero_nn_algo.py +++ b/python/fate/ml/nn/test/test_hetero_nn_algo.py @@ -17,8 +17,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_hetero_nn_algo_no_guest.py b/python/fate/ml/nn/test/test_hetero_nn_algo_no_guest.py index 56c27196f1..9dadd73cb6 100644 --- a/python/fate/ml/nn/test/test_hetero_nn_algo_no_guest.py +++ b/python/fate/ml/nn/test/test_hetero_nn_algo_no_guest.py @@ -17,8 +17,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_hetero_nn_algo_val.py b/python/fate/ml/nn/test/test_hetero_nn_algo_val.py index b86e198b10..b246d90a38 100644 --- a/python/fate/ml/nn/test/test_hetero_nn_algo_val.py +++ b/python/fate/ml/nn/test/test_hetero_nn_algo_val.py @@ -17,8 +17,8 @@ def get_current_datetime_str(): def create_ctx(local, context_name): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession + from fate.arch.federation.backends.standalone import StandaloneFederation import logging # prepare log diff --git a/python/fate/ml/nn/test/test_homo_nn_binary.py b/python/fate/ml/nn/test/test_homo_nn_binary.py index 94385c3b8b..ce73fc3b4a 100644 --- a/python/fate/ml/nn/test/test_homo_nn_binary.py +++ b/python/fate/ml/nn/test/test_homo_nn_binary.py @@ -12,8 +12,8 @@ def create_ctx(local): from fate.arch import Context - from fate.arch.computing.standalone import CSession - from fate.arch.federation.standalone import StandaloneFederation + from fate.arch.computing.backends.standalone import CSession 
+ from fate.arch.federation.backends.standalone import StandaloneFederation import logging logger = logging.getLogger() diff --git a/python/fate/ml/utils/test/test_predict_format.py b/python/fate/ml/utils/test/test_predict_format.py index 8281388ad2..a6a1d8a3b4 100644 --- a/python/fate/ml/utils/test/test_predict_format.py +++ b/python/fate/ml/utils/test/test_predict_format.py @@ -1,7 +1,7 @@ from fate.arch import Context -from fate.arch.computing.standalone import CSession +from fate.arch.computing.backends.standalone import CSession from fate.arch.context import Context -from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.federation.backends.standalone import StandaloneFederation import pandas as pd from fate.ml.utils.predict_tools import compute_predict_details, PREDICT_SCORE, LABEL, BINARY, REGRESSION, MULTI from fate.arch.dataframe import PandasReader diff --git a/python/fate/test/test_dtensor.py b/python/fate/test/test_dtensor.py index d83d7ff802..acf502e83e 100644 --- a/python/fate/test/test_dtensor.py +++ b/python/fate/test/test_dtensor.py @@ -1,8 +1,8 @@ import pytest import torch from fate.arch import Context -from fate.arch.computing.standalone import CSession -from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.computing.backends.standalone import CSession +from fate.arch.federation.backends.standalone import StandaloneFederation from fate.arch.tensor import DTensor from pytest import fixture diff --git a/python/fate/test/test_matmul.py b/python/fate/test/test_matmul.py index 319086c8fb..60bef1b89c 100644 --- a/python/fate/test/test_matmul.py +++ b/python/fate/test/test_matmul.py @@ -1,7 +1,7 @@ import torch -from fate.arch.computing.standalone import CSession +from fate.arch.computing.backends.standalone import CSession from fate.arch.context import Context -from fate.arch.federation.standalone import StandaloneFederation +from fate.arch.federation.backends.standalone import StandaloneFederation from fate.arch.tensor import DTensor from pytest import fixture From 6a163a8206abe2769b8d75820a8720c38231f383 Mon Sep 17 00:00:00 2001 From: sagewe Date: Sat, 9 Dec 2023 19:08:14 +0800 Subject: [PATCH 25/42] fix computing and federation import for doc Signed-off-by: sagewe --- doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb | 4 ++-- doc/2.0/fate/ml/homo_nn_tutorial.ipynb | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb b/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb index 3f2476ca06..491d075d3e 100644 --- a/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb +++ b/doc/2.0/fate/ml/hetero_secureboost_tutorial.ipynb @@ -71,8 +71,8 @@ "\n", "def create_ctx(local, context_name):\n", " from fate.arch import Context\n", - " from fate.arch.computing.standalone import CSession\n", - " from fate.arch.federation.standalone import StandaloneFederation\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", " import logging\n", "\n", " # prepare log\n", diff --git a/doc/2.0/fate/ml/homo_nn_tutorial.ipynb b/doc/2.0/fate/ml/homo_nn_tutorial.ipynb index 946ac272ff..1589f38094 100644 --- a/doc/2.0/fate/ml/homo_nn_tutorial.ipynb +++ b/doc/2.0/fate/ml/homo_nn_tutorial.ipynb @@ -351,8 +351,8 @@ "\n", "def create_ctx(local):\n", " from fate.arch import Context\n", - " from fate.arch.computing.standalone import CSession\n", - " from fate.arch.federation.standalone import 
StandaloneFederation\n", + " from fate.arch.computing.backends.standalone import CSession\n", + " from fate.arch.federation.backends.standalone import StandaloneFederation\n", " import logging\n", "\n", " logger = logging.getLogger()\n", @@ -450,7 +450,7 @@ "Guest Terminal Outputs:\n", "\n", "```\n", - "2023-09-13 20:36:13,319 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('guest', 10000)\n", + "2023-09-13 20:36:13,319 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('guest', 10000)\n", "2023-09-13 20:36:13,319 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n", "2023-09-13 20:36:13,400 - fate.ml.nn.algo.homo.fedavg - INFO - Using secure_aggregate aggregator\n", "2023-09-13 20:36:13,400 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote data, type=\n", @@ -537,7 +537,7 @@ "Host Terminal Outputs:\n", "\n", "```\n", - "2023-09-13 20:36:12,803 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('host', 9999)\n", + "2023-09-13 20:36:12,803 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('host', 9999)\n", "2023-09-13 20:36:12,803 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n", "2023-09-13 20:36:12,888 - fate.ml.nn.algo.homo.fedavg - INFO - Using secure_aggregate aggregator\n", "2023-09-13 20:36:12,888 - fate.arch._standalone - DEBUG - [federation.standalone.remote.agg_type.default]remote data, type=\n", @@ -614,7 +614,7 @@ "Arbiter Terminal Outputs:\n", "\n", "```\n", - "2023-09-13 20:36:12,315 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('arbiter', 10000)\n", + "2023-09-13 20:36:12,315 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation: standalone_session=, federation_session_id=2023-09-13-20-36, party=('arbiter', 10000)\n", "2023-09-13 20:36:12,316 - fate.arch.federation.standalone._federation - DEBUG - [federation.standalone]init federation context done\n", "2023-09-13 20:36:12,316 - fate.arch._standalone - DEBUG - [federation.standalone.get.agg_type.default]\n", "2023-09-13 20:36:13,418 - fate.arch._standalone - DEBUG - [GET] Got 2023-09-13-20-36-agg_type-default-guest-10000-arbiter-10000 type Object\n", From 72ce35153f3e656d056c0c1a01636e7a1e3c4c34 Mon Sep 17 00:00:00 2001 From: sagewe Date: Sun, 10 Dec 2023 12:00:40 +0800 Subject: [PATCH 26/42] refactor trace Signed-off-by: sagewe --- launchers/launcher.py | 9 +++++---- launchers/sshe_lr_launcher.py | 3 ++- python/fate/arch/__init__.py | 4 ++-- python/fate/arch/computing/api/_table.py | 2 +- .../arch/computing/backends/standalone/__init__.py | 3 ++- .../computing/backends/standalone/_csession.py | 2 +- .../backends/standalone}/_standalone.py | 2 +- .../arch/computing/backends/standalone/_table.py | 2 +- python/fate/arch/context/_context.py | 3 +-- python/fate/arch/dataframe/_dataframe.py | 2 +- python/fate/arch/dataframe/_frame_reader.py | 2 +- 
python/fate/arch/federation/api/_federation.py | 2 +- .../federation/backends/standalone/_federation.py | 8 ++++---- python/fate/arch/launchers/context_helper.py | 2 +- python/fate/arch/{utils => launchers}/logger.py | 0 .../fate/arch/launchers/multiprocess_launcher.py | 10 +++++----- python/fate/arch/{utils => launchers}/paths.py | 0 .../arch/protocol/mpc/communicator/communicator.py | 2 +- .../fate/arch/protocol/mpc/nn/sshe/linr_layer.py | 2 +- python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py | 2 +- python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py | 2 +- python/fate/arch/protocol/mpc/primitives/sshe.py | 2 +- python/fate/arch/tensor/distributed/_tensor.py | 2 +- python/fate/arch/tensor/phe/_tensor.py | 2 +- python/fate/arch/trace/__init__.py | 14 ++++++++++++++ .../fate/arch/{utils/trace.py => trace/_trace.py} | 5 +++++ python/fate/arch/unify/__init__.py | 3 +-- python/fate/arch/unify/_infra_def.py | 6 ------ python/fate/arch/utils/__init__.py | 0 29 files changed, 56 insertions(+), 42 deletions(-) rename python/fate/arch/{ => computing/backends/standalone}/_standalone.py (99%) rename python/fate/arch/{utils => launchers}/logger.py (100%) rename python/fate/arch/{utils => launchers}/paths.py (100%) create mode 100644 python/fate/arch/trace/__init__.py rename python/fate/arch/{utils/trace.py => trace/_trace.py} (97%) delete mode 100644 python/fate/arch/utils/__init__.py diff --git a/launchers/launcher.py b/launchers/launcher.py index 78f8178da8..84a96ab544 100644 --- a/launchers/launcher.py +++ b/launchers/launcher.py @@ -1,6 +1,7 @@ -import click import importlib +import click + @click.command() @click.option("--csession_id", type=str, help="computing session id") @@ -10,8 +11,8 @@ @click.option("--data_dir", type=str, help="data dir") @click.option("--proc", type=str, help="proc, e.g. 
fate.ml.mpc.svm:SVM", required=True) def cli(csession_id, federation_session_id, rank, parties, data_dir, proc): - from fate.arch.utils.logger import set_up_logging - from fate.arch.utils.context_helper import init_standalone_context + from fate.arch.launchers.logger import set_up_logging + from fate.arch.launchers.context_helper import init_local_context # set up logging set_up_logging(rank) @@ -23,7 +24,7 @@ def cli(csession_id, federation_session_id, rank, parties, data_dir, proc): party = parties[rank] if not csession_id: csession_id = f"{federation_session_id}_{party[0]}_{party[1]}" - ctx = init_standalone_context(csession_id, federation_session_id, party, parties, data_dir) + ctx = init_local_context(csession_id, federation_session_id, party, parties, data_dir) # init crypten from fate.ml.mpc import MPCModule diff --git a/launchers/sshe_lr_launcher.py b/launchers/sshe_lr_launcher.py index 202e2839f3..b35aab7b81 100644 --- a/launchers/sshe_lr_launcher.py +++ b/launchers/sshe_lr_launcher.py @@ -1,4 +1,5 @@ import logging +import pprint import typing from dataclasses import dataclass, field @@ -46,7 +47,7 @@ def run_sshe_lr(ctx: "Context"): } input_data = dataframe.CSVReader(**kwargs).to_frame(ctx, args.host_data) inst.fit(ctx, train_data=input_data) - print(inst.get_model()) + logger.info(f"model: {pprint.pformat(inst.get_model())}") if __name__ == "__main__": diff --git a/python/fate/arch/__init__.py b/python/fate/arch/__init__.py index 714de7533b..8d939f5966 100644 --- a/python/fate/arch/__init__.py +++ b/python/fate/arch/__init__.py @@ -15,6 +15,6 @@ # from .context import CipherKit, Context -from .unify import URI, Backend, device +from .unify import URI, device -__all__ = ["Backend", "device", "Context", "URI", "CipherKit"] +__all__ = ["device", "Context", "URI", "CipherKit"] diff --git a/python/fate/arch/computing/api/_table.py b/python/fate/arch/computing/api/_table.py index cf669e4f62..60f1ce67a3 100644 --- a/python/fate/arch/computing/api/_table.py +++ b/python/fate/arch/computing/api/_table.py @@ -5,8 +5,8 @@ from fate.arch.computing.partitioners import get_partitioner_by_type from fate.arch.computing.serdes import get_serdes_by_type +from fate.arch.trace import auto_trace from fate.arch.unify import URI -from fate.arch.utils.trace import auto_trace from ._profile import computing_profile as _compute_info logger = logging.getLogger(__name__) diff --git a/python/fate/arch/computing/backends/standalone/__init__.py b/python/fate/arch/computing/backends/standalone/__init__.py index b37a08c0a0..41a9efd3cb 100644 --- a/python/fate/arch/computing/backends/standalone/__init__.py +++ b/python/fate/arch/computing/backends/standalone/__init__.py @@ -14,7 +14,8 @@ # limitations under the License. # +from . 
import _standalone as standalone_raw from ._csession import CSession from ._table import Table -__all__ = ["Table", "CSession"] +__all__ = ["Table", "CSession", "standalone_raw"] diff --git a/python/fate/arch/computing/backends/standalone/_csession.py b/python/fate/arch/computing/backends/standalone/_csession.py index e170a79b4f..7beee710a0 100644 --- a/python/fate/arch/computing/backends/standalone/_csession.py +++ b/python/fate/arch/computing/backends/standalone/_csession.py @@ -18,8 +18,8 @@ from fate.arch.computing.api import KVTableContext, generate_computing_uuid from fate.arch.unify import URI, uuid +from ._standalone import Session from ._table import Table -from ...._standalone import Session logger = logging.getLogger(__name__) diff --git a/python/fate/arch/_standalone.py b/python/fate/arch/computing/backends/standalone/_standalone.py similarity index 99% rename from python/fate/arch/_standalone.py rename to python/fate/arch/computing/backends/standalone/_standalone.py index 8474a9ca56..32e29b838b 100644 --- a/python/fate/arch/_standalone.py +++ b/python/fate/arch/computing/backends/standalone/_standalone.py @@ -36,7 +36,7 @@ import cloudpickle as f_pickle import lmdb -from fate.arch.utils import trace +from fate.arch import trace PartyMeta = Tuple[Literal["guest", "host", "arbiter", "local"], str] diff --git a/python/fate/arch/computing/backends/standalone/_table.py b/python/fate/arch/computing/backends/standalone/_table.py index 5c5ea44637..ae6d79b4e7 100644 --- a/python/fate/arch/computing/backends/standalone/_table.py +++ b/python/fate/arch/computing/backends/standalone/_table.py @@ -18,7 +18,7 @@ from fate.arch.computing.api import ComputingEngine, KVTable, K, V from fate.arch.unify import URI -from ...._standalone import Table as StandaloneTable +from ._standalone import Table as StandaloneTable LOGGER = logging.getLogger(__name__) diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index db96eaca27..375a2a0399 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -17,13 +17,12 @@ import typing from typing import Iterable, Literal, Optional, Tuple, TypeVar, overload +from fate.arch.trace import auto_trace from ._cipher import CipherKit from ._federation import Parties, Party from ._metrics import InMemoryMetricsHandler, MetricsWrap from ._namespace import NS, default_ns from ..unify import device -from fate.arch.utils.trace import auto_trace -from fate.arch.config import cfg logger = logging.getLogger(__name__) diff --git a/python/fate/arch/dataframe/_dataframe.py b/python/fate/arch/dataframe/_dataframe.py index 7453c311c3..9b76000777 100644 --- a/python/fate/arch/dataframe/_dataframe.py +++ b/python/fate/arch/dataframe/_dataframe.py @@ -23,7 +23,7 @@ from fate.arch.tensor import DTensor from .manager import DataManager, Schema -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace if typing.TYPE_CHECKING: from fate.arch.histogram import DistributedHistogram, HistogramBuilder diff --git a/python/fate/arch/dataframe/_frame_reader.py b/python/fate/arch/dataframe/_frame_reader.py index 9828b3f929..95025a9ebf 100644 --- a/python/fate/arch/dataframe/_frame_reader.py +++ b/python/fate/arch/dataframe/_frame_reader.py @@ -21,7 +21,7 @@ from .entity import types from ._dataframe import DataFrame from .manager import DataManager -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace class TableReader(object): diff --git 
a/python/fate/arch/federation/api/_federation.py b/python/fate/arch/federation/api/_federation.py index 3355ae999c..239fcac316 100644 --- a/python/fate/arch/federation/api/_federation.py +++ b/python/fate/arch/federation/api/_federation.py @@ -17,7 +17,7 @@ import typing from typing import List -from fate.arch.utils.trace import ( +from fate.arch.trace import ( federation_push_table_trace, federation_pull_table_trace, federation_push_bytes_trace, diff --git a/python/fate/arch/federation/backends/standalone/_federation.py b/python/fate/arch/federation/backends/standalone/_federation.py index c529fd3d22..2f176f4e98 100644 --- a/python/fate/arch/federation/backends/standalone/_federation.py +++ b/python/fate/arch/federation/backends/standalone/_federation.py @@ -16,9 +16,9 @@ from typing import List from fate.arch.computing.backends.standalone import Table, CSession +from fate.arch.computing.backends.standalone import standalone_raw from fate.arch.federation.api import Federation from fate.arch.federation.api import PartyMeta -from .... import _standalone as standalone LOGGER = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def __init__( parties: List[PartyMeta], ): super().__init__(federation_session_id, party, parties) - self._federation = standalone.Federation.create( + self._federation = standalone_raw.Federation.create( standalone_session.get_standalone_session(), session_id=federation_session_id, party=party ) @@ -62,7 +62,7 @@ def _pull_table( ) -> List[Table]: rtn = self._federation.pull_table(name=name, tag=tag, parties=parties) - return [Table(r) if isinstance(r, standalone.Table) else r for r in rtn] + return [Table(r) if isinstance(r, standalone_raw.Table) else r for r in rtn] def _pull_bytes( self, @@ -72,7 +72,7 @@ def _pull_bytes( ) -> List[bytes]: rtn = self._federation.pull_bytes(name=name, tag=tag, parties=parties) - return [Table(r) if isinstance(r, standalone.Table) else r for r in rtn] + return [Table(r) if isinstance(r, standalone_raw.Table) else r for r in rtn] def destroy(self): self._federation.destroy() diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index 57544bd8d7..4a4fd9e487 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -40,7 +40,7 @@ def init_context(computing_session_id: str, federation_session_id: str): def init_local_context(computing_session_id: str, federation_session_id: str): - from fate.arch.utils.paths import get_base_dir + from .paths import get_base_dir from fate.arch.computing.backends.standalone import CSession from fate.arch.federation import FederationBuilder from fate.arch.context import Context diff --git a/python/fate/arch/utils/logger.py b/python/fate/arch/launchers/logger.py similarity index 100% rename from python/fate/arch/utils/logger.py rename to python/fate/arch/launchers/logger.py diff --git a/python/fate/arch/launchers/multiprocess_launcher.py b/python/fate/arch/launchers/multiprocess_launcher.py index b4aa98aa16..94966c326a 100644 --- a/python/fate/arch/launchers/multiprocess_launcher.py +++ b/python/fate/arch/launchers/multiprocess_launcher.py @@ -36,7 +36,7 @@ import rich.panel import rich.traceback -from fate.arch.utils import trace +from fate.arch import trace from .argparser import HfArgumentParser logger = logging.getLogger(__name__) @@ -129,9 +129,9 @@ def _run_process( ): sys.argv = argv args = HfArgumentParser(LauncherProcessArguments).parse_args_into_dataclasses(return_remaining_strings=True)[0] - from 
fate.arch.utils.logger import set_up_logging + from fate.arch.launchers.logger import set_up_logging from fate.arch.launchers.context_helper import init_context - from fate.arch.utils.trace import setup_tracing + from fate.arch.trace import setup_tracing from fate.arch.computing.api import profile_start, profile_ends if args.rank >= len(args.parties): @@ -202,7 +202,7 @@ def show_exceptions(self): self.console.print(rich.panel.Panel(tb, title=f"rank {rank} exception", expand=False, border_style="red")) def block_run(self, f): - from fate.arch.utils.trace import setup_tracing + from fate.arch.trace import setup_tracing setup_tracing("multi_process_launcher") with trace.get_tracer(__name__).start_as_current_span(self.federation_session_id): @@ -267,7 +267,7 @@ def launch(f, **kwargs): args_desc.extend(kwargs.get("extra_args_desc", [])) args, _ = HfArgumentParser(args_desc).parse_known_args(namespace=namespace) - from fate.arch.utils.logger import set_up_logging + from fate.arch.launchers.logger import set_up_logging set_up_logging(-1, args.log_level) diff --git a/python/fate/arch/utils/paths.py b/python/fate/arch/launchers/paths.py similarity index 100% rename from python/fate/arch/utils/paths.py rename to python/fate/arch/launchers/paths.py diff --git a/python/fate/arch/protocol/mpc/communicator/communicator.py b/python/fate/arch/protocol/mpc/communicator/communicator.py index 84f63f0e34..f12580f537 100644 --- a/python/fate/arch/protocol/mpc/communicator/communicator.py +++ b/python/fate/arch/protocol/mpc/communicator/communicator.py @@ -9,7 +9,7 @@ from torch.distributed import ReduceOp from fate.arch.context import Context, NS, Parties -from fate.arch.utils import trace +from fate.arch import trace from typing import List logger = logging.getLogger(__name__) diff --git a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py index e9650a0700..35651e2ed3 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py @@ -5,7 +5,7 @@ from fate.arch.context import Context from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings from fate.arch.protocol.mpc.mpc import FixedPointEncoder -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace class SSHELinearRegressionLayer: diff --git a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py index ca8f56436b..26fdff3af5 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py @@ -5,7 +5,7 @@ from fate.arch.context import Context from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings from fate.arch.protocol.mpc.mpc import FixedPointEncoder -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace class SSHELogisticRegressionLayer: diff --git a/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py index bdeeb150b4..94e1e3a7e8 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py @@ -7,7 +7,7 @@ from fate.arch.context import Context from fate.arch.protocol.mpc.common.encoding import IgnoreEncodings from fate.arch.protocol.mpc.mpc import FixedPointEncoder -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace class SSHENeuralNetworkAggregatorLayer(torch.nn.Module): diff --git 
a/python/fate/arch/protocol/mpc/primitives/sshe.py b/python/fate/arch/protocol/mpc/primitives/sshe.py index 21435de733..f6cb37badb 100644 --- a/python/fate/arch/protocol/mpc/primitives/sshe.py +++ b/python/fate/arch/protocol/mpc/primitives/sshe.py @@ -7,7 +7,7 @@ from fate.arch.protocol.mpc.common.rng import generate_random_ring_element from fate.arch.protocol.mpc.primitives.arithmetic import ArithmeticSharedTensor from fate.arch.protocol.mpc.primitives.beaver import IgnoreEncodings -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace if typing.TYPE_CHECKING: diff --git a/python/fate/arch/tensor/distributed/_tensor.py b/python/fate/arch/tensor/distributed/_tensor.py index 303e7e0978..60a8267667 100644 --- a/python/fate/arch/tensor/distributed/_tensor.py +++ b/python/fate/arch/tensor/distributed/_tensor.py @@ -21,7 +21,7 @@ from fate.arch.computing.api import KVTable from fate.arch.context import Context from fate.arch.tensor.phe import PHETensor -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace _HANDLED_FUNCTIONS = {} diff --git a/python/fate/arch/tensor/phe/_tensor.py b/python/fate/arch/tensor/phe/_tensor.py index d440963f17..aaff06c6bc 100644 --- a/python/fate/arch/tensor/phe/_tensor.py +++ b/python/fate/arch/tensor/phe/_tensor.py @@ -2,7 +2,7 @@ import torch -from fate.arch.utils.trace import auto_trace +from fate.arch.trace import auto_trace _HANDLED_FUNCTIONS = {} _PHE_TENSOR_ENCODED_HANDLED_FUNCTIONS = {} diff --git a/python/fate/arch/trace/__init__.py b/python/fate/arch/trace/__init__.py new file mode 100644 index 0000000000..ea38c0758a --- /dev/null +++ b/python/fate/arch/trace/__init__.py @@ -0,0 +1,14 @@ +from ._trace import ( + get_tracer, + auto_trace, + federation_pull_bytes_trace, + federation_push_table_trace, + federation_pull_table_trace, + federation_push_bytes_trace, + federation_auto_trace, + setup_tracing, + inject_carrier, + StatusCode, + extract_carrier, + instrument_thread_pool_executor, +) diff --git a/python/fate/arch/utils/trace.py b/python/fate/arch/trace/_trace.py similarity index 97% rename from python/fate/arch/utils/trace.py rename to python/fate/arch/trace/_trace.py index 924d3e31e4..e423ae5763 100644 --- a/python/fate/arch/utils/trace.py +++ b/python/fate/arch/trace/_trace.py @@ -253,4 +253,9 @@ def instrument_thread_pool_executor(executor): "get_tracer", "federation_auto_trace", "StatusCode", + "instrument_thread_pool_executor", + "federation_pull_bytes_trace", + "federation_pull_table_trace", + "federation_push_bytes_trace", + "federation_push_table_trace", ] diff --git a/python/fate/arch/unify/__init__.py b/python/fate/arch/unify/__init__.py index 56e19cb27b..95b56ddec8 100644 --- a/python/fate/arch/unify/__init__.py +++ b/python/fate/arch/unify/__init__.py @@ -12,12 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from ._infra_def import Backend, device +from ._infra_def import device from ._io import URI from ._uuid import uuid __all__ = [ - "Backend", "device", "uuid", "URI", diff --git a/python/fate/arch/unify/_infra_def.py b/python/fate/arch/unify/_infra_def.py index d95fd3f602..d3299e014d 100644 --- a/python/fate/arch/unify/_infra_def.py +++ b/python/fate/arch/unify/_infra_def.py @@ -41,9 +41,3 @@ def to_torch_device(self): return torch.device("cuda", self.index) else: raise ValueError(f"device type {self.type} not supported") - - -class Backend(Enum): - STANDALONE = "STANDALONE" - EGGROLL = "EGGROLL" - SPARK = "SPARK" diff --git a/python/fate/arch/utils/__init__.py b/python/fate/arch/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From bf9e8a684e33af570f199f235f1c54e5713b04e9 Mon Sep 17 00:00:00 2001 From: sagewe Date: Mon, 11 Dec 2023 11:15:02 +0800 Subject: [PATCH 27/42] fix tracing default setting Signed-off-by: sagewe --- python/fate/arch/trace/_trace.py | 8 ++++++-- python/setup.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/fate/arch/trace/_trace.py b/python/fate/arch/trace/_trace.py index e423ae5763..f6d7222fd0 100644 --- a/python/fate/arch/trace/_trace.py +++ b/python/fate/arch/trace/_trace.py @@ -13,18 +13,22 @@ logger = logging.getLogger(__name__) _ENABLE_TRACING = None -_ENABLE_TRACING_DEFAULT = True +_ENABLE_TRACING_DEFAULT = False def _is_tracing_enabled(): global _ENABLE_TRACING if _ENABLE_TRACING is None: - _ENABLE_TRACING = os.environ.get("FATE_ENABLE_TRACING", str(_ENABLE_TRACING_DEFAULT)).lower() == "false" + if (env_setting := os.environ.get("FATE_ENABLE_TRACING")) is not None: + _ENABLE_TRACING = env_setting.strip().lower() in ("1", "true", "yes") + else: + _ENABLE_TRACING = _ENABLE_TRACING_DEFAULT return _ENABLE_TRACING def setup_tracing(service_name, endpoint: str = None): if not _is_tracing_enabled(): + logger.info("tracing is disabled") return from opentelemetry.sdk.resources import SERVICE_NAME, Resource diff --git a/python/setup.py b/python/setup.py index 0d1d4dc884..01b4c7b466 100644 --- a/python/setup.py +++ b/python/setup.py @@ -22,6 +22,8 @@ "requests", "grpcio", "protobuf", + "opentelemetry-api", + "opentelemetry-sdk", ] # Extra requirements From 98a72f013888901f83e8ddee850396c66681b765 Mon Sep 17 00:00:00 2001 From: sagewe Date: Mon, 11 Dec 2023 18:48:12 +0800 Subject: [PATCH 28/42] fix bug Signed-off-by: sagewe --- python/fate/arch/computing/api/_profile.py | 5 ++++- python/fate/arch/context/_context.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/fate/arch/computing/api/_profile.py b/python/fate/arch/computing/api/_profile.py index bb8ac6b584..f7e7301928 100644 --- a/python/fate/arch/computing/api/_profile.py +++ b/python/fate/arch/computing/api/_profile.py @@ -339,7 +339,10 @@ def _call_stack_strings(): return call_stack_strings -def computing_profile(func): +T = typing.TypeVar("T", bound=typing.Callable) + + +def computing_profile(func: T) -> T: @wraps(func) def _fn(*args, **kwargs): function_call_stack = _call_stack_strings() diff --git a/python/fate/arch/context/_context.py b/python/fate/arch/context/_context.py index 375a2a0399..993ea7e764 100644 --- a/python/fate/arch/context/_context.py +++ b/python/fate/arch/context/_context.py @@ -111,7 +111,7 @@ def with_namespace(self, namespace: NS): ) @property - def computing(self): + def computing(self) -> "KVTableContext": return self._get_computing() @property @@ -259,7 +259,7 @@ def _get_federation(self): raise RuntimeError(f"federation not set") return 
self._federation - def _get_computing(self): + def _get_computing(self) -> "KVTableContext": if self._computing is None: raise RuntimeError(f"computing not set") return self._computing From 2809d8b59f98faa7f361676a7e74ed3fb4ea9550 Mon Sep 17 00:00:00 2001 From: zhihuiwan <15779896112@163.com> Date: Tue, 12 Dec 2023 10:34:08 +0800 Subject: [PATCH 29/42] fix component of reader Signed-off-by: zhihuiwan <15779896112@163.com> --- .../core/component_desc/artifacts/data/_unresolved.py | 1 + python/fate/components/core/spec/component.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/fate/components/core/component_desc/artifacts/data/_unresolved.py b/python/fate/components/core/component_desc/artifacts/data/_unresolved.py index 23ca24f9b0..2c02820ddd 100644 --- a/python/fate/components/core/component_desc/artifacts/data/_unresolved.py +++ b/python/fate/components/core/component_desc/artifacts/data/_unresolved.py @@ -14,6 +14,7 @@ class DataUnresolvedWriter(_ArtifactTypeWriter[DataUnresolvedArtifactType]): def write_metadata(self, metadata: dict, name=None, namespace=None): + self.artifact.consumed() self.artifact.metadata.metadata.update(metadata) if name is not None: self.artifact.metadata.name = name diff --git a/python/fate/components/core/spec/component.py b/python/fate/components/core/spec/component.py index c802833bbc..a53dad7c81 100644 --- a/python/fate/components/core/spec/component.py +++ b/python/fate/components/core/spec/component.py @@ -84,7 +84,7 @@ class ComponentSpecV1(BaseModel): class ArtifactTypeSpec(BaseModel): type_name: str uri_types: List[str] - path_type: Literal["file", "directory", "distributed"] + path_type: Literal["file", "directory", "distributed", "unresolved"] class ComponentIOArtifactTypeSpec(BaseModel): From 6bfa7ef68f2d72256f5e596afd698da21f547dc2 Mon Sep 17 00:00:00 2001 From: sagewe Date: Tue, 12 Dec 2023 11:56:08 +0800 Subject: [PATCH 30/42] add session info Signed-off-by: sagewe --- python/fate/arch/computing/api/_table.py | 6 ++++++ python/fate/arch/computing/backends/eggroll/_csession.py | 8 ++++++++ .../fate/arch/computing/backends/standalone/_csession.py | 7 +++++++ .../arch/computing/backends/standalone/_standalone.py | 9 +++++++++ 4 files changed, 30 insertions(+) diff --git a/python/fate/arch/computing/api/_table.py b/python/fate/arch/computing/api/_table.py index 60f1ce67a3..720561b232 100644 --- a/python/fate/arch/computing/api/_table.py +++ b/python/fate/arch/computing/api/_table.py @@ -82,6 +82,12 @@ def _parallelize( ): raise NotImplementedError(f"{self.__class__.__name__}._parallelize") + def info(self): + return self._info() + + def _info(self): + return {} + def load(self, uri: URI, schema: dict, options: dict = None): return self._load( uri=uri, diff --git a/python/fate/arch/computing/backends/eggroll/_csession.py b/python/fate/arch/computing/backends/eggroll/_csession.py index ddc0bb56fd..8de1542c16 100644 --- a/python/fate/arch/computing/backends/eggroll/_csession.py +++ b/python/fate/arch/computing/backends/eggroll/_csession.py @@ -110,6 +110,14 @@ def _parallelize( ) return Table(rp) + def _info(self): + if hasattr(self._rpc, "info"): + return self._rpc.info() + else: + return { + "session_id": self.session_id, + } + def cleanup(self, name, namespace): self._rpc.cleanup(name=name, namespace=namespace) diff --git a/python/fate/arch/computing/backends/standalone/_csession.py b/python/fate/arch/computing/backends/standalone/_csession.py index 7beee710a0..37424630cb 100644 --- 
a/python/fate/arch/computing/backends/standalone/_csession.py +++ b/python/fate/arch/computing/backends/standalone/_csession.py @@ -106,6 +106,13 @@ def _parallelize( ) return Table(table) + def _info(self): + return { + "session_id": self.session_id, + "data_dir": self._session.data_dir, + "max_workers": self._session.max_workers, + } + def cleanup(self, name, namespace): return self._session.cleanup(name=name, namespace=namespace) diff --git a/python/fate/arch/computing/backends/standalone/_standalone.py b/python/fate/arch/computing/backends/standalone/_standalone.py index 32e29b838b..f8710bb296 100644 --- a/python/fate/arch/computing/backends/standalone/_standalone.py +++ b/python/fate/arch/computing/backends/standalone/_standalone.py @@ -445,6 +445,7 @@ class Session(object): def __init__(self, session_id, data_dir: str, max_workers=None, logger_config=None): self.session_id = session_id self._data_dir = data_dir + self._max_workers = max_workers self._pool = Executor( max_workers=max_workers, initializer=_watch_thread_react_to_parent_die, @@ -455,6 +456,14 @@ def __init__(self, session_id, data_dir: str, max_workers=None, logger_config=No ) self._enable_process_logger = True + @property + def data_dir(self): + return self._data_dir + + @property + def max_workers(self): + return self._max_workers + def __getstate__(self): # session won't be pickled pass From d4e13a8d2cf4e5d5ffd0cbd1c981c93a5381746e Mon Sep 17 00:00:00 2001 From: sagewe Date: Tue, 12 Dec 2023 13:21:32 +0800 Subject: [PATCH 31/42] fix federation logger Signed-off-by: sagewe --- python/fate/arch/federation/message_queue/_federation.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/fate/arch/federation/message_queue/_federation.py b/python/fate/arch/federation/message_queue/_federation.py index 5ddb041857..a2670f5036 100644 --- a/python/fate/arch/federation/message_queue/_federation.py +++ b/python/fate/arch/federation/message_queue/_federation.py @@ -565,13 +565,7 @@ def _receive_obj(self, channel_info, name, tag): # channel_info = self._query_receive_topic(channel_info) for id, properties, body in self._get_consume_message(channel_info): - LOGGER.debug(f"[federation._receive_obj] properties: {properties}") - if properties["message_id"] != name or properties["correlation_id"] != tag: - # todo: fix this - LOGGER.warning( - f"[federation._receive_obj] require {name}.{tag}, got {properties['message_id']}.{properties['correlation_id']}" - ) - + LOGGER.debug(f"properties: {properties}") cache_key = self._get_message_cache_key( properties["message_id"], properties["correlation_id"], party_id, role ) From 041ec844e97b7562c4542ab6755bd36419bba9db Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 12 Dec 2023 14:39:03 +0800 Subject: [PATCH 32/42] always turn off reveal every epoch for SSHE LR & LinR(#5227) edit examples(#5227) Signed-off-by: Yu Wu --- .../benchmark_quality/linr/fate-sshe-linr.py | 120 ++++++++++++++++++ .../benchmark_quality/lr/breast_config.yaml | 3 +- .../lr/default_credit_config.yaml | 3 +- .../lr/epsilon_5k_sshe_config.yaml | 12 ++ .../lr/give_credit_config.yaml | 3 +- .../benchmark_quality/lr/lr_benchmark.yaml | 15 +++ .../benchmark_quality/lr/pipeline-lr-multi.py | 11 +- .../lr/pipeline-sshe-lr-binary.py | 119 +++++++++++++++++ .../lr/pipeline-sshe-lr-multi.py | 116 +++++++++++++++++ .../benchmark_quality/lr/vehicle_config.yaml | 3 +- examples/pipeline/sshe_linr/test_linr_cv.py | 2 +- .../sshe_linr/test_linr_warm_start.py | 2 +- examples/pipeline/sshe_lr/test_lr.py | 2 +- 
.../pipeline/sshe_lr/test_lr_multi_class.py | 4 +- .../sshe_lr/test_lr_predict_w_torch.py | 1 + examples/pipeline/sshe_lr/test_lr_validate.py | 2 +- .../pipeline/sshe_lr/test_lr_warm_start.py | 4 +- .../fate/components/components/sshe_linr.py | 10 +- python/fate/components/components/sshe_lr.py | 10 +- python/fate/ml/glm/hetero/sshe/sshe_linr.py | 39 ++++-- python/fate/ml/glm/hetero/sshe/sshe_lr.py | 42 ++++-- 21 files changed, 476 insertions(+), 47 deletions(-) create mode 100644 examples/benchmark_quality/linr/fate-sshe-linr.py create mode 100644 examples/benchmark_quality/lr/epsilon_5k_sshe_config.yaml create mode 100644 examples/benchmark_quality/lr/pipeline-sshe-lr-binary.py create mode 100644 examples/benchmark_quality/lr/pipeline-sshe-lr-multi.py diff --git a/examples/benchmark_quality/linr/fate-sshe-linr.py b/examples/benchmark_quality/linr/fate-sshe-linr.py new file mode 100644 index 0000000000..c53264790e --- /dev/null +++ b/examples/benchmark_quality/linr/fate-sshe-linr.py @@ -0,0 +1,120 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELinR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + +from fate_test.utils import parse_summary_result + + +def main(config="../../config.yaml", param="./linr_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_train_data = {"name": "motor_hetero_guest", "namespace": f"experiment{namespace}"} + host_train_data = {"name": "motor_hetero_host", "namespace": f"experiment{namespace}"} + + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + linr_param = { + } + + config_param = { + "epochs": param["epochs"], + # "learning_rate_scheduler": param["learning_rate_scheduler"], + # "optimizer": param["optimizer"], + "learning_rate": param["learning_rate"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5 + } + linr_param.update(config_param) + linr_0 = SSHELinR("linr_0", + train_data=psi_0.outputs["output_data"], + **config_param) + """linr_1 = SSHELinR("linr_1", + test_data=psi_0.outputs["output_data"], + input_model=linr_0.outputs["output_model"])""" + + evaluation_0 = 
Evaluation("evaluation_0", + runtime_roles=["guest"], + metrics=["r2_score", + "mse", + "rmse"], + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(linr_0) + # pipeline.add_task(linr_1) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.compile() + print(pipeline.get_dag()) + pipeline.fit() + + """linr_0_data = pipeline.get_task_info("linr_0").get_output_data()["train_output_data"] + linr_1_data = pipeline.get_task_info("linr_1").get_output_data()["test_output_data"] + linr_0_score = extract_data(linr_0_data, "predict_result") + linr_0_label = extract_data(linr_0_data, "motor_speed") + linr_1_score = extract_data(linr_1_data, "predict_result") + linr_1_label = extract_data(linr_1_data, "motor_speed") + linr_0_score_label = extract_data(linr_0_data, "predict_result", keep_id=True) + linr_1_score_label = extract_data(linr_1_data, "predict_result", keep_id=True)""" + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + print(f"result_summary") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/breast_config.yaml b/examples/benchmark_quality/lr/breast_config.yaml index 3d1747cc04..0a65aed268 100644 --- a/examples/benchmark_quality/lr/breast_config.yaml +++ b/examples/benchmark_quality/lr/breast_config.yaml @@ -19,4 +19,5 @@ optimizer: lr: 0.15 alpha: 0.01 batch_size: 240 -early_stop: "diff" \ No newline at end of file +early_stop: "diff" +learning_rate: 0.15 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/default_credit_config.yaml b/examples/benchmark_quality/lr/default_credit_config.yaml index 3bfbd67760..5b321f2513 100644 --- a/examples/benchmark_quality/lr/default_credit_config.yaml +++ b/examples/benchmark_quality/lr/default_credit_config.yaml @@ -19,4 +19,5 @@ optimizer: optimizer_params: lr: 0.12 batch_size: 10000 -early_stop: "diff" \ No newline at end of file +early_stop: "diff" +learning_rate: 0.15 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/epsilon_5k_sshe_config.yaml b/examples/benchmark_quality/lr/epsilon_5k_sshe_config.yaml new file mode 100644 index 0000000000..98ab848020 --- /dev/null +++ b/examples/benchmark_quality/lr/epsilon_5k_sshe_config.yaml @@ -0,0 +1,12 @@ +data_guest: "epsilon_5k_hetero_guest" +data_host: "epsilon_5k_hetero_host" +idx: "id" +label_name: "y" +epochs: 20 +batch_size: 3000 +init_param: + fit_intercept: True + method: "random" + random_state: 42 +early_stop: "diff" +learning_rate: 0.34 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/give_credit_config.yaml b/examples/benchmark_quality/lr/give_credit_config.yaml index 6f8656132b..94c2cda8fa 100644 --- a/examples/benchmark_quality/lr/give_credit_config.yaml +++ 
b/examples/benchmark_quality/lr/give_credit_config.yaml @@ -18,4 +18,5 @@ optimizer: optimizer_params: lr: 0.25 batch_size: null -early_stop: "diff" \ No newline at end of file +early_stop: "diff" +learning_rate: 0.25 \ No newline at end of file diff --git a/examples/benchmark_quality/lr/lr_benchmark.yaml b/examples/benchmark_quality/lr/lr_benchmark.yaml index 294c9264c1..530f31e6a1 100644 --- a/examples/benchmark_quality/lr/lr_benchmark.yaml +++ b/examples/benchmark_quality/lr/lr_benchmark.yaml @@ -176,6 +176,9 @@ hetero_lr-binary-0-breast: FATE-hetero-lr: script: "./pipeline-lr-binary.py" conf: "./breast_config.yaml" + FATE-hetero-sshe-lr: + script: "./pipeline-sshe-lr-binary.py" + conf: "./breast_config.yaml" compare_setting: relative_tol: 0.01 hetero_lr-binary-1-default-credit: @@ -185,6 +188,9 @@ hetero_lr-binary-1-default-credit: FATE-hetero-lr: script: "./pipeline-lr-binary.py" conf: "./default_credit_config.yaml" + FATE-hetero-sshe-lr: + script: "./pipeline-sshe-lr-binary.py" + conf: "./default_credit_config.yaml" compare_setting: relative_tol: 0.01 hetero_lr-binary-2-epsilon-5k: @@ -194,6 +200,9 @@ hetero_lr-binary-2-epsilon-5k: FATE-hetero-lr: script: "./pipeline-lr-binary.py" conf: "./epsilon_5k_config.yaml" + FATE-hetero-sshe-lr: + script: "./pipeline-sshe-lr-binary.py" + conf: "./epsilon_5k_sshe_config.yaml" compare_setting: relative_tol: 0.01 hetero_lr-binary-3-give-credit: @@ -203,6 +212,9 @@ hetero_lr-binary-3-give-credit: FATE-hetero-lr: script: "./pipeline-lr-binary.py" conf: "./give_credit_config.yaml" + FATE-hetero-sshe-lr: + script: "./pipeline-sshe-lr-binary.py" + conf: "./give_credit_config.yaml" compare_setting: relative_tol: 0.01 multi-vehicle: @@ -212,6 +224,9 @@ multi-vehicle: FATE-hetero-lr: script: "./pipeline-lr-multi.py" conf: "./vehicle_config.yaml" + FATE-hetero-sshe-lr: + script: "./pipeline-sshe-lr-multi.py" + conf: "./vehicle_config.yaml" compare_setting: relative_tol: 0.01 diff --git a/examples/benchmark_quality/lr/pipeline-lr-multi.py b/examples/benchmark_quality/lr/pipeline-lr-multi.py index 5f7fd9bfb2..e10259e8af 100644 --- a/examples/benchmark_quality/lr/pipeline-lr-multi.py +++ b/examples/benchmark_quality/lr/pipeline-lr-multi.py @@ -21,7 +21,8 @@ from fate_client.pipeline.components.fate import Evaluation from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils -from fate_test.utils import extract_data, parse_summary_result + +from fate_test.utils import parse_summary_result def main(config="../../config.yaml", param="./vehicle_config.yaml", namespace=""): @@ -87,12 +88,12 @@ def main(config="../../config.yaml", param="./vehicle_config.yaml", namespace="" pipeline.compile() pipeline.fit() - lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] - lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] + # lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] + # lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) - lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) - lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True) + # lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) + #lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True) data_summary = {"train": {"guest": guest_train_data["name"], 
"host": host_train_data["name"]}, "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} diff --git a/examples/benchmark_quality/lr/pipeline-sshe-lr-binary.py b/examples/benchmark_quality/lr/pipeline-sshe-lr-binary.py new file mode 100644 index 0000000000..8db45cc8d2 --- /dev/null +++ b/examples/benchmark_quality/lr/pipeline-sshe-lr-binary.py @@ -0,0 +1,119 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + +from fate_test.utils import parse_summary_result + + +def main(config="../../config.yaml", param="./breast_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + lr_param = { + } + + config_param = { + "epochs": param["epochs"], + # "learning_rate_scheduler": param["learning_rate_scheduler"], + # "optimizer": param["optimizer"], + "learning_rate": param["learning_rate"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "tol": 1e-5, + "reveal_loss_freq": 3, + } + lr_param.update(config_param) + lr_0 = SSHELR("lr_0", + train_data=psi_0.outputs["output_data"], + **lr_param) + lr_1 = SSHELR("lr_1", + test_data=psi_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) + + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + metrics=["auc", "binary_precision", "binary_accuracy", "binary_recall"], + input_data=lr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(evaluation_0) + + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + pipeline.compile() + pipeline.fit() + + """lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] + 
lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] + lr_0_score = extract_data(lr_0_data, "predict_result") + lr_0_label = extract_data(lr_0_data, "y") + lr_1_score = extract_data(lr_1_data, "predict_result") + lr_1_label = extract_data(lr_1_data, "y") + lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) + lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True)""" + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + print(f"result_summary: {result_summary}") + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./breast_config.yaml") + args = parser.parse_args() + main(args.config, args.param) diff --git a/examples/benchmark_quality/lr/pipeline-sshe-lr-multi.py b/examples/benchmark_quality/lr/pipeline-sshe-lr-multi.py new file mode 100644 index 0000000000..083991aff9 --- /dev/null +++ b/examples/benchmark_quality/lr/pipeline-sshe-lr-multi.py @@ -0,0 +1,116 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
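The lr_benchmark.yaml entries above pair each FATE-hetero-lr script with a FATE-hetero-sshe-lr counterpart and compare their metric summaries under compare_setting.relative_tol. A minimal sketch of that comparison, assuming each script's result_summary is a flat metric-name to float mapping (the real check lives in fate_test; metrics_match below is hypothetical):

    import math

    def metrics_match(summary_a: dict, summary_b: dict, relative_tol: float = 0.01) -> bool:
        # compare only the metrics both runs report, e.g. auc, binary_precision
        shared = set(summary_a) & set(summary_b)
        return all(math.isclose(summary_a[k], summary_b[k], rel_tol=relative_tol) for k in shared)
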
+# + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import SSHELR, PSI +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + +from fate_test.utils import parse_summary_result + + +def main(config="../../config.yaml", param="./vehicle_config.yaml", namespace=""): + # obtain config + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + + if isinstance(param, str): + param = test_utils.JobConfig.load_from_file(param) + + assert isinstance(param, dict) + guest_data_table = param.get("data_guest") + host_data_table = param.get("data_host") + + guest_train_data = {"name": guest_data_table, "namespace": f"experiment{namespace}"} + host_train_data = {"name": host_data_table, "namespace": f"experiment{namespace}"} + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name=guest_train_data["name"], + namespace=guest_train_data["namespace"])) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name=host_train_data["name"], + namespace=host_train_data["namespace"])) + + lr_param = { + } + + config_param = { + "epochs": param["epochs"], + # "learning_rate_scheduler": param["learning_rate_scheduler"], + # "optimizer": param["optimizer"], + "learning_rate": param["learning_rate"], + "batch_size": param["batch_size"], + "early_stop": param["early_stop"], + "init_param": param["init_param"], + "reveal_loss_freq": 3, + "tol": 1e-5, + } + lr_param.update(config_param) + lr_0 = SSHELR("lr_0", + train_data=psi_0.outputs["output_data"], + **config_param) + lr_1 = SSHELR("lr_1", + test_data=psi_0.outputs["output_data"], + input_model=lr_0.outputs["output_model"]) + + evaluation_0 = Evaluation('evaluation_0', + runtime_roles=['guest'], + input_data=lr_0.outputs["train_output_data"], + predict_column_name='predict_result', + metrics=['multi_recall', 'multi_accuracy', 'multi_precision']) + pipeline.add_task(psi_0) + pipeline.add_task(lr_0) + pipeline.add_task(lr_1) + pipeline.add_task(evaluation_0) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + pipeline.compile() + pipeline.fit() + + # lr_0_data = pipeline.get_task_info("lr_0").get_output_data()["train_output_data"] + # lr_1_data = pipeline.get_task_info("lr_1").get_output_data()["test_output_data"] + + result_summary = parse_summary_result(pipeline.get_task_info("evaluation_0").get_output_metric()[0]["data"]) + # lr_0_score_label = extract_data(lr_0_data, "predict_result", keep_id=True) + # lr_1_score_label = extract_data(lr_1_data, "predict_result", keep_id=True) + + data_summary = {"train": {"guest": guest_train_data["name"], "host": host_train_data["name"]}, + "test": {"guest": guest_train_data["name"], "host": host_train_data["name"]} + } + return data_summary, result_summary + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("BENCHMARK-QUALITY PIPELINE JOB") + parser.add_argument("-c", "--config", type=str, + help="config file", default="../../config.yaml") + parser.add_argument("-p", "--param", type=str, + help="config file for params", default="./vehicle_config.yaml") + + args = parser.parse_args() + if args.config is not None: + main(args.config, 
args.param) + else: + main() diff --git a/examples/benchmark_quality/lr/vehicle_config.yaml b/examples/benchmark_quality/lr/vehicle_config.yaml index fdbfadf47d..d519dc31bd 100644 --- a/examples/benchmark_quality/lr/vehicle_config.yaml +++ b/examples/benchmark_quality/lr/vehicle_config.yaml @@ -21,4 +21,5 @@ optimizer: batch_size: 18 early_stop: "diff" task_cores: null -timeout: 3600 \ No newline at end of file +timeout: 3600 +learning_rate: 0.3 \ No newline at end of file diff --git a/examples/pipeline/sshe_linr/test_linr_cv.py b/examples/pipeline/sshe_linr/test_linr_cv.py index 5b61fef983..7b4f31cd30 100644 --- a/examples/pipeline/sshe_linr/test_linr_cv.py +++ b/examples/pipeline/sshe_linr/test_linr_cv.py @@ -45,7 +45,7 @@ def main(config="../config.yaml", namespace=""): init_param={"fit_intercept": True}, cv_data=psi_0.outputs["output_data"], cv_param={"n_splits": 3}, - reveal_every_epoch=True, + reveal_every_epoch=False, early_stop="diff", reveal_loss_freq=1, ) diff --git a/examples/pipeline/sshe_linr/test_linr_warm_start.py b/examples/pipeline/sshe_linr/test_linr_warm_start.py index 0e9a5a302e..d62f1546ed 100644 --- a/examples/pipeline/sshe_linr/test_linr_warm_start.py +++ b/examples/pipeline/sshe_linr/test_linr_warm_start.py @@ -64,7 +64,7 @@ def main(config="../config.yaml", namespace=""): init_param={"fit_intercept": True, "method": "zeros"}, train_data=psi_0.outputs["output_data"], learning_rate=0.05, - reveal_every_epoch=True, + reveal_every_epoch=False, early_stop="diff", reveal_loss_freq=1, ) diff --git a/examples/pipeline/sshe_lr/test_lr.py b/examples/pipeline/sshe_lr/test_lr.py index 4a5f47263f..ed375881fc 100644 --- a/examples/pipeline/sshe_lr/test_lr.py +++ b/examples/pipeline/sshe_lr/test_lr.py @@ -46,7 +46,7 @@ def main(config="../config.yaml", namespace=""): batch_size=300, init_param={"fit_intercept": True, "method": "zeros"}, train_data=psi_0.outputs["output_data"], - reveal_every_epoch=True, + reveal_every_epoch=False, early_stop="diff", reveal_loss_freq=1, ) diff --git a/examples/pipeline/sshe_lr/test_lr_multi_class.py b/examples/pipeline/sshe_lr/test_lr_multi_class.py index 5a64db054f..c739a86631 100644 --- a/examples/pipeline/sshe_lr/test_lr_multi_class.py +++ b/examples/pipeline/sshe_lr/test_lr_multi_class.py @@ -44,8 +44,8 @@ def main(config="../config.yaml", namespace=""): learning_rate=0.15, epochs=10, batch_size=None, - reveal_every_epoch=True, - early_stop="weight_diff", + reveal_every_epoch=False, + early_stop="diff", reveal_loss_freq=1, init_param={"fit_intercept": True, "method": "random_uniform"}, train_data=psi_0.outputs["output_data"]) diff --git a/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py index eb8a67382f..da3d2ba7a6 100644 --- a/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py +++ b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py @@ -56,6 +56,7 @@ def main(config="../config.yaml", namespace=""): batch_size=300, init_param={"fit_intercept": True, "method": "zeros"}, train_data=psi_0.outputs["output_data"], + reveal_every_epoch=False ) evaluation_0 = Evaluation("evaluation_0", diff --git a/examples/pipeline/sshe_lr/test_lr_validate.py b/examples/pipeline/sshe_lr/test_lr_validate.py index e1bcc5d154..29e1c3ef9e 100644 --- a/examples/pipeline/sshe_lr/test_lr_validate.py +++ b/examples/pipeline/sshe_lr/test_lr_validate.py @@ -47,7 +47,7 @@ def main(config="../config.yaml", namespace=""): lr_0 = SSHELR("lr_0", epochs=10, batch_size=300, - reveal_every_epoch=True, + 
reveal_every_epoch=False,
                   early_stop="diff",
                   reveal_loss_freq=1,
                   init_param={"fit_intercept": True, "method": "random_uniform"},
diff --git a/examples/pipeline/sshe_lr/test_lr_warm_start.py b/examples/pipeline/sshe_lr/test_lr_warm_start.py
index 55cdce9ebf..c17e75722e 100644
--- a/examples/pipeline/sshe_lr/test_lr_warm_start.py
+++ b/examples/pipeline/sshe_lr/test_lr_warm_start.py
@@ -54,7 +54,7 @@ def main(config="../config.yaml", namespace=""):
                    epochs=2,
                    batch_size=None,
                    learning_rate=0.05,
-                   reveal_every_epoch=True,
+                   reveal_every_epoch=False,
                    early_stop="diff",
                    reveal_loss_freq=1,
                    )
@@ -64,7 +64,7 @@ def main(config="../config.yaml", namespace=""):
                    learning_rate=0.05,
                    init_param={"fit_intercept": True, "method": "zeros"},
                    train_data=psi_0.outputs["output_data"],
-                   reveal_every_epoch=True,
+                   reveal_every_epoch=False,
                    early_stop="diff",
                    reveal_loss_freq=1,
                    )
diff --git a/python/fate/components/components/sshe_linr.py b/python/fate/components/components/sshe_linr.py
index eeb546f887..ec19781f1f 100644
--- a/python/fate/components/components/sshe_linr.py
+++ b/python/fate/components/components/sshe_linr.py
@@ -44,12 +44,13 @@ def train(
     early_stop: cpn.parameter(
         type=params.string_choice(["weight_diff", "diff", "abs"]),
         default="diff",
-        desc="early stopping criterion, choose from {weight_diff, diff, abs}",
+        desc="early stopping criterion, choose from {weight_diff, diff, abs}; if weight_diff is used, "
+             "weights will be revealed every epoch",
    ),
    learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"),
    reveal_every_epoch: cpn.parameter(type=bool, default=False,
                                      desc="whether reveal encrypted result every epoch, "
-                                           "if False, only reveal at the end of training"),
+                                           "only False is accepted for now"),
    init_param: cpn.parameter(
        type=params.init_param(),
        default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None),
@@ -116,7 +117,8 @@ def cross_validation(
     early_stop: cpn.parameter(
         type=params.string_choice(["weight_diff", "diff", "abs"]),
         default="diff",
-        desc="early stopping criterion, choose from {weight_diff, diff, abs}",
+        desc="early stopping criterion, choose from {weight_diff, diff, abs}; if weight_diff is used, "
+             "weights will be revealed every epoch",
    ),
    learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"),
    init_param: cpn.parameter(
@@ -129,7 +131,7 @@
    ),
    reveal_every_epoch: cpn.parameter(type=bool, default=False,
                                      desc="whether reveal encrypted result every epoch, "
-                                           "if False, only reveal at the end of training"),
+                                           "only False is accepted for now"),
    reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1,
                                    desc="rounds to reveal training loss, "
                                         "only effective if `early_stop` is 'loss'"),
diff --git a/python/fate/components/components/sshe_lr.py b/python/fate/components/components/sshe_lr.py
index f7cc04a5a7..692134ee0e 100644
--- a/python/fate/components/components/sshe_lr.py
+++ b/python/fate/components/components/sshe_lr.py
@@ -44,12 +44,13 @@ def train(
     early_stop: cpn.parameter(
         type=params.string_choice(["weight_diff", "diff", "abs"]),
         default="diff",
-        desc="early stopping criterion, choose from {weight_diff, diff, abs}",
+        desc="early stopping criterion, choose from {weight_diff, diff, abs}; if weight_diff is used, "
+             "weights will be revealed every epoch",
    ),
    learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"),
    reveal_every_epoch: cpn.parameter(type=bool, default=False,
                                      desc="whether reveal encrypted result every epoch, "
-                                           "if False, only reveal at the end of training"),
+                                           "only False is accepted for now"),
    init_param: cpn.parameter(
        type=params.init_param(),
        default=params.InitParam(method="random_uniform", fit_intercept=True, random_state=None),
@@ -115,7 +116,8 @@ def cross_validation(
     early_stop: cpn.parameter(
         type=params.string_choice(["weight_diff", "diff", "abs"]),
         default="diff",
-        desc="early stopping criterion, choose from {weight_diff, diff, abs}",
+        desc="early stopping criterion, choose from {weight_diff, diff, abs}; if weight_diff is used, "
+             "weights will be revealed every epoch",
    ),
    learning_rate: cpn.parameter(type=params.confloat(ge=0), default=0.05, desc="learning rate"),
    init_param: cpn.parameter(
@@ -128,7 +130,7 @@
    ),
    reveal_every_epoch: cpn.parameter(type=bool, default=False,
                                      desc="whether reveal encrypted result every epoch, "
-                                           "if False, only reveal at the end of training"),
+                                           "only False is accepted for now"),
    reveal_loss_freq: cpn.parameter(type=params.conint(ge=1), default=1,
                                    desc="rounds to reveal training loss, "
                                         "only effective if `early_stop` is 'loss'"),
diff --git a/python/fate/ml/glm/hetero/sshe/sshe_linr.py b/python/fate/ml/glm/hetero/sshe/sshe_linr.py
index 5303bf0d8f..a6e1f4c11e 100644
--- a/python/fate/ml/glm/hetero/sshe/sshe_linr.py
+++ b/python/fate/ml/glm/hetero/sshe/sshe_linr.py
@@ -45,6 +45,8 @@ def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_para
         self.learning_rate = learning_rate
         self.init_param = init_param
         self.threshold = threshold
+        if reveal_every_epoch:
+            raise ValueError("reveal_every_epoch is currently not supported in SSHELinearRegression")
         self.reveal_every_epoch = reveal_every_epoch
         self.reveal_loss_freq = reveal_loss_freq
@@ -183,17 +185,24 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
         else:
             batch_loader = dataframe.DataLoader(
                 train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
+        # if self.reveal_every_epoch:
         if self.early_stop == "weight_diff":
+            wa_p = wa.get_plain_text(dst=rank_a)
+            wb_p = wb.get_plain_text(dst=rank_b)
             if ctx.is_on_guest:
-                self.converge_func.set_pre_weight(wb.get_plain_text(dst=rank_b))
+                self.converge_func.set_pre_weight(wb_p)
             else:
-                self.converge_func.set_pre_weight(wa.get_plain_text(dst=rank_a))
+                self.converge_func.set_pre_weight(wa_p)
         for i, epoch_ctx in ctx.on_iterations.ctxs_range(self.epochs):
             epoch_loss = None
             logger.info(f"self.optimizer set epoch {i}")
             for batch_ctx, batch_data in epoch_ctx.on_batches.ctxs_zip(batch_loader):
                 h = batch_data.x
                 y = batch_ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b)
+                """if self.reveal_every_epoch:
+                    z = batch_ctx.mpc.cond_call(lambda: torch.matmul(h, wa_p.detach()),
+                                                lambda: torch.matmul(h, wb_p.detach()), dst=rank_a)
+                else:"""
                 z = layer(h)
                 loss = loss_fn(z, y)
                 if i % self.reveal_loss_freq == 0:
@@ -205,12 +214,12 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                 optimizer.step()
             if epoch_loss is not None:
                 epoch_ctx.metrics.log_loss("linr_loss", epoch_loss.tolist())
-            if self.reveal_every_epoch:
-                wa_p = wa.get_plain_text(dst=rank_a)
-                wb_p = wb.get_plain_text(dst=rank_b)
+            # if self.reveal_every_epoch:
+            #     wa_p = wa.get_plain_text(dst=rank_a)
+            #     wb_p = wb.get_plain_text(dst=rank_b)
             if ctx.is_on_guest:
                 if self.early_stop == "weight_diff":
-                    if self.reveal_every_epoch:
+                    """if self.reveal_every_epoch:
                         wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight)
                         w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0]
                         self.converge_func.set_pre_weight(wb_p)
@@ -218,20 +227,32 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
                             self.is_converged = True
                     else:
                         raise ValueError(f"early stop {self.early_stop} is not supported when "
-                                         f"reveal_every_epoch is False")
+                                         f"reveal_every_epoch is False")"""
+                    wa_p = wa.get_plain_text(dst=rank_a)
+                    wb_p = wb.get_plain_text(dst=rank_b)
+                    wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight)
+                    w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0]
+                    self.converge_func.set_pre_weight(wb_p)
+                    if w_diff < self.tol:
+                        self.is_converged = True
                 else:
                     if i % self.reveal_loss_freq == 0:
                         self.is_converged = self.converge_func.is_converge(epoch_loss)
                 epoch_ctx.hosts.put("converge_flag", self.is_converged)
             else:
                 if self.early_stop == "weight_diff":
-                    if self.reveal_every_epoch:
+                    """if self.reveal_every_epoch:
                         wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight)
                         epoch_ctx.guest.put("wa_p_delta", wa_p_delta)
                         self.converge_func.set_pre_weight(wa_p)
                     else:
                         raise ValueError(f"early stop {self.early_stop} is not supported when "
-                                         f"reveal_every_epoch is False")
+                                         f"reveal_every_epoch is False")"""
+                    wa_p = wa.get_plain_text(dst=rank_a)
+                    wb_p = wb.get_plain_text(dst=rank_b)
+                    wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight)
+                    epoch_ctx.guest.put("wa_p_delta", wa_p_delta)
+                    self.converge_func.set_pre_weight(wa_p)
                 self.is_converged = epoch_ctx.guest.get("converge_flag")
             if self.is_converged:
                 self.end_epoch = i
diff --git a/python/fate/ml/glm/hetero/sshe/sshe_lr.py b/python/fate/ml/glm/hetero/sshe/sshe_lr.py
index 330da786a7..e4d7c3defb 100644
--- a/python/fate/ml/glm/hetero/sshe/sshe_lr.py
+++ b/python/fate/ml/glm/hetero/sshe/sshe_lr.py
@@ -44,6 +44,8 @@ def __init__(self, epochs, batch_size, tol, early_stop, learning_rate, init_para
         self.learning_rate = learning_rate
         self.init_param = init_param
         self.threshold = threshold
+        if reveal_every_epoch:
+            raise ValueError("reveal_every_epoch is currently not supported in SSHELogisticRegression")
         self.reveal_every_epoch = reveal_every_epoch
         self.reveal_loss_freq = reveal_loss_freq
@@ -231,8 +233,6 @@ def predict(self, ctx, test_data) -> DataFrame:
         self.estimator.predict(ctx, test_data)
-
-
 class SSHELREstimator(HeteroModule):
     def __init__(self, epochs=None, batch_size=None, optimizer=None, learning_rate=None, init_param=None,
                  reveal_every_epoch=True, reveal_loss_freq=3, early_stop=None, tol=None):
@@ -285,17 +285,24 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data
         else:
             batch_loader = dataframe.DataLoader(
                 train_data, ctx=ctx, batch_size=self.batch_size, mode="hetero", role="host")
+        # if self.reveal_every_epoch:
         if self.early_stop == "weight_diff":
+            wa_p = wa.get_plain_text(dst=rank_a)
+            wb_p = wb.get_plain_text(dst=rank_b)
             if ctx.is_on_guest:
-                self.converge_func.set_pre_weight(wb.get_plain_text(dst=rank_b))
+                self.converge_func.set_pre_weight(wb_p)
             else:
-                self.converge_func.set_pre_weight(wa.get_plain_text(dst=rank_a))
+                self.converge_func.set_pre_weight(wa_p)
         for i, epoch_ctx in ctx.on_iterations.ctxs_range(self.epochs):
             epoch_loss = None
             logger.info(f"self.optimizer set epoch {i}")
             for batch_ctx, batch_data in epoch_ctx.on_batches.ctxs_zip(batch_loader):
                 h = batch_data.x
                 y = batch_ctx.mpc.cond_call(lambda: batch_data.label, lambda: None, dst=rank_b)
+                # if self.reveal_every_epoch:
+                #     z =
batch_ctx.mpc.cond_call(lambda: torch.matmul(h, wa_p.detach()), + # lambda: torch.matmul(h, wb_p.detach()), dst=rank_a) + # else: z = layer(h) loss = loss_fn(z, y) if i % self.reveal_loss_freq == 0: @@ -307,12 +314,12 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data optimizer.step() if epoch_loss is not None: epoch_ctx.metrics.log_loss("lr_loss", epoch_loss.tolist()) - if self.reveal_every_epoch: - wa_p = wa.get_plain_text(dst=rank_a) - wb_p = wb.get_plain_text(dst=rank_b) + # if self.reveal_every_epoch: + # wa_p = wa.get_plain_text(dst=rank_a) + # wb_p = wb.get_plain_text(dst=rank_b) if ctx.is_on_guest: if self.early_stop == "weight_diff": - if self.reveal_every_epoch: + """if self.reveal_every_epoch: wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight) w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0] self.converge_func.set_pre_weight(wb_p) @@ -320,17 +327,26 @@ def fit_single_model(self, ctx: Context, train_data: DataFrame, valid_data: Data self.is_converged = True else: raise ValueError(f"early stop {self.early_stop} is not supported when " - f"reveal_every_epoch is False") + f"reveal_every_epoch is False")""" + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + wb_p_delta = self.converge_func.compute_weight_diff(wb_p - self.converge_func.pre_weight) + w_diff = wb_p_delta + epoch_ctx.hosts.get("wa_p_delta")[0] + self.converge_func.set_pre_weight(wb_p) + if w_diff < self.tol: + self.is_converged = True else: if i % self.reveal_loss_freq == 0: self.is_converged = self.converge_func.is_converge(epoch_loss) epoch_ctx.hosts.put("converge_flag", self.is_converged) else: if self.early_stop == "weight_diff": - if self.reveal_every_epoch: - wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight) - epoch_ctx.guest.put("wa_p_delta", wa_p_delta) - self.converge_func.set_pre_weight(wa_p) + # if self.reveal_every_epoch: + wa_p = wa.get_plain_text(dst=rank_a) + wb_p = wb.get_plain_text(dst=rank_b) + wa_p_delta = self.converge_func.compute_weight_diff(wa_p - self.converge_func.pre_weight) + epoch_ctx.guest.put("wa_p_delta", wa_p_delta) + self.converge_func.set_pre_weight(wa_p) self.is_converged = epoch_ctx.guest.get("converge_flag") if self.is_converged: self.end_epoch = i From 765ec9f638074f0f6a09088cd201beb1d852e410 Mon Sep 17 00:00:00 2001 From: Yu Wu Date: Tue, 12 Dec 2023 20:13:25 +0800 Subject: [PATCH 33/42] edit examples(#5227) Signed-off-by: Yu Wu --- .../linr/hetero_linr_benchmark.yaml | 3 + .../benchmark_quality/linr/linr_config.yaml | 1 + .../sshe_linr/sshe_linr_testsuite.yaml | 2 + .../pipeline/sshe_linr/test_linr_validate.py | 94 +++++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 examples/pipeline/sshe_linr/test_linr_validate.py diff --git a/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml b/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml index d3089cc26b..844d8ebcd5 100644 --- a/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml +++ b/examples/benchmark_quality/linr/hetero_linr_benchmark.yaml @@ -41,5 +41,8 @@ hetero_linr: FATE-hetero-linr: script: "./fate-linr.py" conf: "./linr_config.yaml" + FATE-hetero-sshe-linr: + script: "./fate-sshe-linr.py" + conf: "./linr_config.yaml" compare_setting: relative_tol: 0.01 diff --git a/examples/benchmark_quality/linr/linr_config.yaml b/examples/benchmark_quality/linr/linr_config.yaml index 13f5199e90..3b072de4ea 100644 --- 
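For context on the fit_single_model edits above: with reveal_every_epoch retired, the weight_diff early stop now always reveals each party's weight share at epoch end, the host ships its weight delta to the guest, and the guest sums both deltas against tol. A local sketch of the guest-side decision, assuming compute_weight_diff reduces to a tensor norm (the actual metric comes from the configured converge function):

    import torch

    def guest_weight_diff_converged(wb_prev: torch.Tensor, wb_now: torch.Tensor,
                                    wa_delta_from_host: float, tol: float = 1e-4) -> bool:
        # guest's own weight movement; the norm here stands in for compute_weight_diff
        wb_delta = torch.norm(wb_now - wb_prev).item()
        # the host contributes the same quantity for its share via put("wa_p_delta", ...)
        return (wb_delta + wa_delta_from_host) < tol
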
a/examples/benchmark_quality/linr/linr_config.yaml +++ b/examples/benchmark_quality/linr/linr_config.yaml @@ -20,3 +20,4 @@ optimizer: alpha: 0.01 batch_size: 100 early_stop: "diff" +learning_rate: 0.13 diff --git a/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml b/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml index 14366bff98..ce1b008610 100644 --- a/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml +++ b/examples/pipeline/sshe_linr/sshe_linr_testsuite.yaml @@ -40,3 +40,5 @@ tasks: script: test_linr_cv.py linr-warm-start: script: test_linr_warm_start.py + linr-validate: + script: test_linr_validate.py \ No newline at end of file diff --git a/examples/pipeline/sshe_linr/test_linr_validate.py b/examples/pipeline/sshe_linr/test_linr_validate.py new file mode 100644 index 0000000000..bb76fe1b56 --- /dev/null +++ b/examples/pipeline/sshe_linr/test_linr_validate.py @@ -0,0 +1,94 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from fate_client.pipeline import FateFlowPipeline +from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation, DataSplit +from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.utils import test_utils + + +def main(config="../config.yaml", namespace=""): + if isinstance(config, str): + config = test_utils.load_job_config(config) + parties = config.parties + guest = parties.guest[0] + host = parties.host[0] + pipeline = FateFlowPipeline().set_parties(guest=guest, host=host) + if config.task_cores: + pipeline.conf.set("task_cores", config.task_cores) + if config.timeout: + pipeline.conf.set("timeout", config.timeout) + + psi_0 = PSI("psi_0") + psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", + namespace=f"experiment{namespace}")) + psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + data_split_0 = DataSplit("data_split_0", + train_size=0.8, + validate_size=0.2, + input_data=psi_0.outputs["output_data"]) + linr_0 = SSHELinR("linr_0", + epochs=2, + batch_size=100, + init_param={"fit_intercept": True}, + train_data=data_split_0.outputs["train_output_data"], + validate_data=data_split_0.outputs["validate_output_data"], + reveal_every_epoch=False, + early_stop="diff", + reveal_loss_freq=1, + learning_rate=0.1) + evaluation_0 = Evaluation("evaluation_0", + runtime_roles=["guest"], + default_eval_setting="regression", + input_data=linr_0.outputs["train_output_data"]) + + pipeline.add_task(psi_0) + pipeline.add_task(data_split_0) + pipeline.add_task(linr_0) + pipeline.add_task(evaluation_0) + pipeline.compile() + # print(pipeline.get_dag()) + pipeline.fit() + + pipeline.deploy([psi_0, linr_0]) + + predict_pipeline = FateFlowPipeline() + + deployed_pipeline = pipeline.get_deployed_pipeline() + deployed_pipeline.psi_0.guest.task_setting( + input_data=DataWarehouseChannel(name="motor_hetero_guest", + 
namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.hosts[0].task_setting( + input_data=DataWarehouseChannel(name="motor_hetero_host", + namespace=f"experiment{namespace}")) + + predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.compile() + # print("\n\n\n") + # print(predict_pipeline.compile().get_dag()) + predict_pipeline.predict() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("PIPELINE DEMO") + parser.add_argument("--config", type=str, default="../config.yaml", + help="config file") + parser.add_argument("--namespace", type=str, default="", + help="namespace for data stored in FATE") + args = parser.parse_args() + main(config=args.config, namespace=args.namespace) From e71ee18fd1c5daf11967ed5b03f55cabd2400db6 Mon Sep 17 00:00:00 2001 From: sagewe Date: Tue, 12 Dec 2023 23:02:51 +0800 Subject: [PATCH 34/42] add default config for protocols Signed-off-by: sagewe --- configs/default.yaml | 20 +++++++++++++++++++- python/fate/arch/context/_cipher.py | 21 +++++++++++++++++++-- python/fate/arch/protocol/psi/_psi_run.py | 5 +++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index 438828fef2..f0e303ba99 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -49,4 +49,22 @@ nn: safety: serdes: # supported types: unrestricted, restricted, restricted_catch_miss - restricted_type: "unrestricted" \ No newline at end of file + restricted_type: "unrestricted" + + phe: + paillier: + allow: True + minimum_key_size: 1024 + + ou: + allow: True + minimum_key_size: 1024 + + mock: + allow: False + + psi: + ecdh: + allow: True + curve_type: + - curve25519 diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index d496a3fa26..cbde240952 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -16,6 +16,7 @@ import logging import typing +from fate.arch.config import cfg from ..unify import device if typing.TYPE_CHECKING: @@ -49,11 +50,14 @@ def _set_default_phe(self): self._cipher_mapping["phe"] = {} if self._device not in self._cipher_mapping["phe"]: if self._device == device.CPU: - self._cipher_mapping["phe"][device.CPU] = {"kind": "paillier", "key_length": 1024} + self._cipher_mapping["phe"][device.CPU] = { + "kind": "paillier", + "key_length": cfg.safety.phe.paillier.minimum_key_size, + } else: logger.warning(f"no impl exists for device {self._device}, fallback to CPU") self._cipher_mapping["phe"][device.CPU] = self._cipher_mapping["phe"].get( - device.CPU, {"kind": "paillier", "key_length": 1024} + device.CPU, {"kind": "paillier", "key_length": cfg.safety.phe.paillier.minimum_key_size} ) @property @@ -105,14 +109,25 @@ def setup(self, options: typing.Optional[dict] = None): key_size = options.get("key_length", 1024) if kind == "paillier": + if not cfg.safety.phe.paillier.allow: + raise ValueError("paillier is not allowed in config") + if key_size < cfg.safety.phe.paillier.minimum_key_size: + raise ValueError( + f"key size {key_size} is too small, minimum is {cfg.safety.phe.paillier.minimum_key_size}" + ) from fate.arch.protocol.phe.paillier import evaluator, keygen from fate.arch.tensor.phe import PHETensorCipher sk, pk, coder = keygen(key_size) tensor_cipher = PHETensorCipher.from_raw_cipher(pk, coder, sk, evaluator) + return PHECipher(kind, key_size, pk, sk, evaluator, coder, tensor_cipher, True, True, True) if kind == "ou": + if not cfg.safety.phe.ou.allow: + raise ValueError("ou is not allowed in config") + if key_size 
< cfg.safety.phe.ou.minimum_key_size: + raise ValueError(f"key size {key_size} is too small, minimum is {cfg.safety.phe.ou.minimum_key_size}") from fate.arch.protocol.phe.ou import evaluator, keygen from fate.arch.tensor.phe import PHETensorCipher @@ -121,6 +136,8 @@ def setup(self, options: typing.Optional[dict] = None): return PHECipher(kind, key_size, pk, sk, evaluator, coder, tensor_cipher, False, False, True) elif kind == "mock": + if not cfg.safety.phe.mock.allow: + raise ValueError("mock is not allowed in config") from fate.arch.protocol.phe.mock import evaluator, keygen from fate.arch.tensor.phe import PHETensorCipher diff --git a/python/fate/arch/protocol/psi/_psi_run.py b/python/fate/arch/protocol/psi/_psi_run.py index a53fa27726..1daf5b1591 100644 --- a/python/fate/arch/protocol/psi/_psi_run.py +++ b/python/fate/arch/protocol/psi/_psi_run.py @@ -13,11 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from fate.arch.config import cfg from .ecdh._run import psi_ecdh def psi_run(ctx, df, protocol="ecdh_psi", curve_type="curve25519"): if protocol == "ecdh_psi": + if not cfg.safety.psi.ecdh.allow: + raise ValueError("ecdh psi is not allowed in config") + if curve_type not in cfg.safety.psi.ecdh.curve_type: + raise ValueError(f"curve_type={curve_type} is not allowed in config") return psi_ecdh(ctx, df, curve_type=curve_type) else: raise ValueError(f"PSI protocol={protocol} does not implemented yet.") From 827f726c8363cca89325f11fb87a8cb56e182c1e Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 13 Dec 2023 11:40:22 +0800 Subject: [PATCH 35/42] update examples/pipeline Signed-off-by: mgqa34 --- examples/pipeline/config.yaml | 8 +-- .../pipeline/coordinated_linr/test_linr.py | 30 ++++----- .../pipeline/coordinated_linr/test_linr_cv.py | 15 ++--- .../coordinated_linr/test_linr_multi_host.py | 35 ++++------ .../coordinated_linr/test_linr_warm_start.py | 22 +++---- examples/pipeline/coordinated_lr/test_lr.py | 42 +++++++----- .../pipeline/coordinated_lr/test_lr_cv.py | 21 +++--- .../coordinated_lr/test_lr_multi_class.py | 45 +++++++------ .../coordinated_lr/test_lr_multi_host.py | 44 +++++++------ .../coordinated_lr/test_lr_predict_w_torch.py | 24 +++---- .../coordinated_lr/test_lr_validate.py | 25 +++---- .../coordinated_lr/test_lr_warm_start.py | 26 ++++---- .../pipeline/data_split/test_data_split.py | 40 ++++++----- .../data_split/test_data_split_multi_host.py | 46 +++++++------ .../data_split/test_data_split_stratified.py | 39 ++++++----- .../test_feature_correlation.py | 22 ++++--- .../feature_scale/test_scale_min_max.py | 59 ++++++++++------- .../feature_scale/test_scale_standard.py | 58 +++++++++------- .../pipeline/feature_scale/test_scale_w_lr.py | 62 ++++++++++------- .../test_feature_binning_asymmetric.py | 43 +++++++----- .../test_feature_binning_bucket.py | 57 +++++++++------- .../test_feature_binning_multi_host.py | 66 +++++++++++-------- .../test_feature_binning_quantile.py | 46 +++++++------ .../test_feature_selection_binning.py | 40 ++++++----- .../test_feature_selection_manual.py | 42 +++++++----- .../test_feature_selection_multi_host.py | 44 +++++++------ .../test_feature_selection_multi_model.py | 41 +++++++----- .../test_feature_selection_statistics.py | 39 ++++++----- examples/pipeline/hetero_nn/test_nn_binary.py | 29 ++++---- .../hetero_nn/test_nn_binary_with_fedpass.py | 29 ++++---- .../test_hetero_sbt_binary.py | 27 ++++---- .../test_hetero_sbt_binary_cv.py | 24 +++---- .../test_hetero_sbt_multi.py | 
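The safety checks added in this patch all follow one pattern: consult configs/default.yaml under cfg.safety before honoring a requested protocol, rejecting disallowed kinds and undersized keys. A hedged sketch of that gate, with plain attribute access standing in for FATE's cfg object (check_phe_allowed is illustrative, not a FATE API):

    def check_phe_allowed(cfg, kind: str, key_size: int) -> None:
        # e.g. kind="paillier" reads cfg.safety.phe.paillier from configs/default.yaml
        section = getattr(cfg.safety.phe, kind)
        if not section.allow:
            raise ValueError(f"{kind} is not allowed in config")
        # mock has no minimum_key_size, so default to 0
        minimum = getattr(section, "minimum_key_size", 0)
        if key_size < minimum:
            raise ValueError(f"key size {key_size} is too small, minimum is {minimum}")
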
27 ++++---- .../test_hetero_sbt_regression.py | 27 ++++---- .../pipeline/homo_lr/test_homo_lr_binary.py | 22 ++++--- .../homo_lr/test_homo_lr_multi_ovr.py | 21 +++--- examples/pipeline/homo_nn/test_nn_binary.py | 33 +++++----- examples/pipeline/homo_nn/test_nn_multi.py | 29 ++++---- .../pipeline/homo_nn/test_nn_regression.py | 28 ++++---- examples/pipeline/multi_model/test_multi.py | 58 ++++++++-------- .../multi_model/test_multi_preprocessing.py | 47 ++++++------- .../multi_model/test_multi_w_predict.py | 45 +++++++------ examples/pipeline/sample/test_sample.py | 39 ++++++----- .../pipeline/sample/test_sample_multi_host.py | 43 ++++++------ .../pipeline/sample/test_sample_unilateral.py | 42 +++++++----- examples/pipeline/sshe_linr/test_linr.py | 31 ++++----- examples/pipeline/sshe_linr/test_linr_cv.py | 15 ++--- .../pipeline/sshe_linr/test_linr_validate.py | 32 ++++----- .../sshe_linr/test_linr_warm_start.py | 22 +++---- examples/pipeline/sshe_lr/test_lr.py | 45 +++++++------ examples/pipeline/sshe_lr/test_lr_cv.py | 21 +++--- .../pipeline/sshe_lr/test_lr_multi_class.py | 46 +++++++------ .../sshe_lr/test_lr_predict_w_torch.py | 24 +++---- examples/pipeline/sshe_lr/test_lr_validate.py | 25 +++---- .../pipeline/sshe_lr/test_lr_warm_start.py | 28 ++++---- .../pipeline/statistics/test_statistics.py | 21 +++--- .../statistics/test_statistics_default.py | 21 +++--- examples/pipeline/union/test_union.py | 40 ++++++----- 58 files changed, 1102 insertions(+), 920 deletions(-) diff --git a/examples/pipeline/config.yaml b/examples/pipeline/config.yaml index 394a5b7802..33b39cc4da 100644 --- a/examples/pipeline/config.yaml +++ b/examples/pipeline/config.yaml @@ -1,10 +1,10 @@ parties: # parties default id guest: - - 9999 + - '9999' host: - - 9998 - - 9999 + - '9998' + - '9999' arbiter: - - 9998 + - '9998' data_base_dir: "" # path to project base where data is located \ No newline at end of file diff --git a/examples/pipeline/coordinated_linr/test_linr.py b/examples/pipeline/coordinated_linr/test_linr.py index 02d76508c7..314cdcce7e 100644 --- a/examples/pipeline/coordinated_linr/test_linr.py +++ b/examples/pipeline/coordinated_linr/test_linr.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Evaluation, Reader from fate_client.pipeline.utils import test_utils @@ -34,11 +33,10 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) linr_0 = CoordinatedLinR("linr_0", epochs=10, batch_size=100, @@ -47,13 +45,11 @@ def main(config="../config.yaml", namespace=""): init_param={"fit_intercept": True}, train_data=psi_0.outputs["output_data"]) evaluation_0 = Evaluation("evaluation_0", - 
runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="regression", input_data=linr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(linr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, linr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() @@ -62,15 +58,13 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_1.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting( - input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/coordinated_linr/test_linr_cv.py b/examples/pipeline/coordinated_linr/test_linr_cv.py index 7c9189bc20..7b7c9e5a73 100644 --- a/examples/pipeline/coordinated_linr/test_linr_cv.py +++ b/examples/pipeline/coordinated_linr/test_linr_cv.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLinR, PSI -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Reader from fate_client.pipeline.utils import test_utils @@ -34,11 +33,10 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) linr_0 = CoordinatedLinR("linr_0", epochs=10, batch_size=None, @@ -48,8 +46,7 @@ def main(config="../config.yaml", namespace=""): cv_data=psi_0.outputs["output_data"], cv_param={"n_splits": 3}) - pipeline.add_task(psi_0) - pipeline.add_task(linr_0) + pipeline.add_tasks([reader_0, psi_0, linr_0]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() diff --git a/examples/pipeline/coordinated_linr/test_linr_multi_host.py b/examples/pipeline/coordinated_linr/test_linr_multi_host.py index 5a3789622c..0813f9b89a 100644 --- a/examples/pipeline/coordinated_linr/test_linr_multi_host.py +++ b/examples/pipeline/coordinated_linr/test_linr_multi_host.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from 
fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -32,13 +31,10 @@ def main(config="../config.yaml", namespace=""): pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"{namespace}experiment")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"{namespace}experiment")) - psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"{namespace}experiment")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_0.hosts[[0, 1]].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) linr_0 = CoordinatedLinR("linr_0", epochs=5, batch_size=None, @@ -51,13 +47,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="regression", input_data=linr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(linr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, linr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) @@ -67,15 +61,14 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_1.hosts[[0, 1]].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"{namespace}experiment")) - deployed_pipeline.psi_0.hosts[[0, 1]].task_setting( - input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"{namespace}experiment")) - - predict_pipeline.add_task(deployed_pipeline) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] + + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/coordinated_linr/test_linr_warm_start.py b/examples/pipeline/coordinated_linr/test_linr_warm_start.py index dcb97524da..f93e7314eb 100644 --- a/examples/pipeline/coordinated_linr/test_linr_warm_start.py +++ b/examples/pipeline/coordinated_linr/test_linr_warm_start.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLinR, PSI +from fate_client.pipeline.components.fate import CoordinatedLinR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -35,11 +34,11 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = 
PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest") + reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host") + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) linr_0 = CoordinatedLinR("linr_0", epochs=4, batch_size=None, @@ -67,16 +66,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="regression", input_data=[linr_1.outputs["train_output_data"], linr_2.outputs["train_output_data"]]) - pipeline.add_task(psi_0) - pipeline.add_task(linr_0) - pipeline.add_task(linr_1) - pipeline.add_task(linr_2) - pipeline.add_task(evaluation_0) - + pipeline.add_tasks([reader_0, psi_0, linr_0, linr_1, linr_2, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() diff --git a/examples/pipeline/coordinated_lr/test_lr.py b/examples/pipeline/coordinated_lr/test_lr.py index 51c9fc6217..a5b567bd9b 100644 --- a/examples/pipeline/coordinated_lr/test_lr.py +++ b/examples/pipeline/coordinated_lr/test_lr.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -36,11 +35,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = CoordinatedLR("lr_0", epochs=10, batch_size=300, @@ -51,13 +55,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() pipeline.fit() @@ -66,15 +68,19 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + 
namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting( - input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/coordinated_lr/test_lr_cv.py b/examples/pipeline/coordinated_lr/test_lr_cv.py index 42400a63fd..e483df376b 100644 --- a/examples/pipeline/coordinated_lr/test_lr_cv.py +++ b/examples/pipeline/coordinated_lr/test_lr_cv.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.utils import test_utils @@ -34,11 +33,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = CoordinatedLR("lr_0", epochs=2, batch_size=None, @@ -48,8 +52,7 @@ def main(config="../config.yaml", namespace=""): cv_data=psi_0.outputs["output_data"], cv_param={"n_splits": 3}) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) + pipeline.add_tasks([reader_0, psi_0, lr_0]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() diff --git a/examples/pipeline/coordinated_lr/test_lr_multi_class.py b/examples/pipeline/coordinated_lr/test_lr_multi_class.py index c1a9319a8d..e558fb8e4d 100644 --- a/examples/pipeline/coordinated_lr/test_lr_multi_class.py +++ b/examples/pipeline/coordinated_lr/test_lr_multi_class.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -36,11 +35,17 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", 
runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = CoordinatedLR("lr_0", epochs=10, batch_size=None, @@ -52,14 +57,12 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="multi", predict_column_name='predict_result', input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) @@ -69,15 +72,19 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting( - input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", - namespace=f"experiment{namespace}")) - - predict_pipeline.add_task(deployed_pipeline) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/coordinated_lr/test_lr_multi_host.py b/examples/pipeline/coordinated_lr/test_lr_multi_host.py index 580614b863..32bce75253 100644 --- a/examples/pipeline/coordinated_lr/test_lr_multi_host.py +++ b/examples/pipeline/coordinated_lr/test_lr_multi_host.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -32,13 +31,16 @@ def main(config="../config.yaml", namespace=""): pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"{namespace}experiment")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"{namespace}experiment")) - psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"{namespace}experiment")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 =
CoordinatedLR("lr_0", epochs=5, batch_size=None, @@ -51,13 +53,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) @@ -67,15 +67,19 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"{namespace}experiment")) - deployed_pipeline.psi_0.hosts[[0, 1]].task_setting( - input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"{namespace}experiment")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/coordinated_lr/test_lr_predict_w_torch.py b/examples/pipeline/coordinated_lr/test_lr_predict_w_torch.py index 4f697195ce..d83f40833c 100644 --- a/examples/pipeline/coordinated_lr/test_lr_predict_w_torch.py +++ b/examples/pipeline/coordinated_lr/test_lr_predict_w_torch.py @@ -17,9 +17,8 @@ import torch from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -47,11 +46,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = CoordinatedLR("lr_0", epochs=10, batch_size=300, @@ -62,13 +66,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() pipeline.fit() diff --git 
a/examples/pipeline/coordinated_lr/test_lr_validate.py b/examples/pipeline/coordinated_lr/test_lr_validate.py index f378daa522..6f2e774a81 100644 --- a/examples/pipeline/coordinated_lr/test_lr_validate.py +++ b/examples/pipeline/coordinated_lr/test_lr_validate.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI, DataSplit +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, DataSplit, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -36,11 +35,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.8, validate_size=0.2, @@ -57,14 +61,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(data_split_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, data_split_0, lr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/coordinated_lr/test_lr_warm_start.py b/examples/pipeline/coordinated_lr/test_lr_warm_start.py index de328cd1f5..9fa8e6eba0 100644 --- a/examples/pipeline/coordinated_lr/test_lr_warm_start.py +++ b/examples/pipeline/coordinated_lr/test_lr_warm_start.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, Reader from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -35,11 +34,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = CoordinatedLR("lr_0", epochs=4, 
batch_size=None, @@ -65,15 +69,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=[lr_1.outputs["train_output_data"], lr_2.outputs["train_output_data"]]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(lr_1) - pipeline.add_task(lr_2) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, lr_1, lr_2, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/data_split/test_data_split.py b/examples/pipeline/data_split/test_data_split.py index dcc67359c4..b1bb999143 100644 --- a/examples/pipeline/data_split/test_data_split.py +++ b/examples/pipeline/data_split/test_data_split.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import DataSplit, PSI -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import DataSplit, PSI, Reader from fate_client.pipeline.utils import test_utils @@ -34,17 +33,27 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.6, @@ -56,13 +65,10 @@ def main(config="../config.yaml", namespace=""): data_split_1 = DataSplit("data_split_1", train_size=200, test_size=50, - input_data=psi_0.outputs["output_data"] + input_data=psi_1.outputs["output_data"] ) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(data_split_0) - pipeline.add_task(data_split_1) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1]) pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/data_split/test_data_split_multi_host.py b/examples/pipeline/data_split/test_data_split_multi_host.py index 1d5694e696..55b09f6dc7 100644 --- a/examples/pipeline/data_split/test_data_split_multi_host.py +++ b/examples/pipeline/data_split/test_data_split_multi_host.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import DataSplit, PSI -from fate_client.pipeline.interface import DataWarehouseChannel +from 
fate_client.pipeline.components.fate import DataSplit, PSI, Reader from fate_client.pipeline.utils import test_utils @@ -34,20 +33,28 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_1.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.6, validate_size=0.1, @@ -58,13 +65,10 @@ def main(config="../config.yaml", namespace=""): data_split_1 = DataSplit("data_split_1", train_size=200, test_size=50, - input_data=psi_0.outputs["output_data"] + input_data=psi_1.outputs["output_data"] ) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(data_split_0) - pipeline.add_task(data_split_1) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1]) pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/data_split/test_data_split_stratified.py b/examples/pipeline/data_split/test_data_split_stratified.py index cc0ca8f4fd..e26024f191 100644 --- a/examples/pipeline/data_split/test_data_split_stratified.py +++ b/examples/pipeline/data_split/test_data_split_stratified.py @@ -16,7 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import DataSplit, PSI +from fate_client.pipeline.components.fate import DataSplit, PSI, Reader from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -34,17 +34,28 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) - psi_1 = 
PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.6, @@ -61,11 +72,7 @@ def main(config="../config.yaml", namespace=""): input_data=psi_1.outputs["output_data"] ) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(data_split_0) - pipeline.add_task(data_split_1) - + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() diff --git a/examples/pipeline/feature_correlation/test_feature_correlation.py b/examples/pipeline/feature_correlation/test_feature_correlation.py index 05d07b89b1..708b7ef44d 100644 --- a/examples/pipeline/feature_correlation/test_feature_correlation.py +++ b/examples/pipeline/feature_correlation/test_feature_correlation.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, FeatureCorrelation -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, FeatureCorrelation, Reader from fate_client.pipeline.utils import test_utils @@ -33,17 +32,22 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) feature_corr_0 = FeatureCorrelation("feature_corr_0", input_data=psi_0.outputs["output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(feature_corr_0) + pipeline.add_tasks([reader_0, psi_0, feature_corr_0]) pipeline.compile() print(pipeline.get_dag()) diff --git a/examples/pipeline/feature_scale/test_scale_min_max.py b/examples/pipeline/feature_scale/test_scale_min_max.py index 99991a5fcc..bbda4e0f8f 100644 --- a/examples/pipeline/feature_scale/test_scale_min_max.py +++ b/examples/pipeline/feature_scale/test_scale_min_max.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, FeatureScale, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, FeatureScale, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,17 +32,28 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - 
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) feature_scale_0 = FeatureScale("feature_scale_0", method="min_max", @@ -59,11 +69,7 @@ def main(config="../config.yaml", namespace=""): metrics=["max", "min", "mean", "std"], input_data=feature_scale_1.outputs["test_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(feature_scale_0) - pipeline.add_task(feature_scale_1) - pipeline.add_task(statistics_0) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, feature_scale_0, feature_scale_1, statistics_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -76,13 +82,20 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_2 = Reader("reader_2") + reader_2.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_2.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_2.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_2, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/feature_scale/test_scale_standard.py b/examples/pipeline/feature_scale/test_scale_standard.py index d5127b4430..05632d8739 100644 --- a/examples/pipeline/feature_scale/test_scale_standard.py +++ b/examples/pipeline/feature_scale/test_scale_standard.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, FeatureScale -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, FeatureScale, Reader from fate_client.pipeline.utils import test_utils @@ -34,17 +33,27 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", 
- namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) feature_scale_0 = FeatureScale("feature_scale_0", method="standard", @@ -54,10 +63,7 @@ def main(config="../config.yaml", namespace=""): test_data=psi_1.outputs["output_data"], input_model=feature_scale_0.outputs["output_model"]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(feature_scale_0) - pipeline.add_task(feature_scale_1) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, feature_scale_0, feature_scale_1]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -71,13 +77,21 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_2 = Reader("reader_2", runtime_parties=dict(guest=guest, host=host)) + reader_2.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_2.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_2.outputs["output_data"] + + predict_pipeline.add_tasks([reader_2, deployed_pipeline]) - predict_pipeline.add_task(deployed_pipeline) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/feature_scale/test_scale_w_lr.py b/examples/pipeline/feature_scale/test_scale_w_lr.py index fb96ea266c..a9e266f6f6 100644 --- a/examples/pipeline/feature_scale/test_scale_w_lr.py +++ b/examples/pipeline/feature_scale/test_scale_w_lr.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import CoordinatedLR, PSI, FeatureScale, Evaluation -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import CoordinatedLR, PSI, FeatureScale, Evaluation, Reader from fate_client.pipeline.utils import test_utils @@ -34,17 +33,28 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - 
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host)) + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host)) + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) feature_scale_0 = FeatureScale("feature_scale_0", method="standard", @@ -60,15 +70,11 @@ def main(config="../config.yaml", namespace=""): "total_iters": 100}}) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(feature_scale_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, feature_scale_0, lr_0, evaluation_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -79,13 +85,21 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_2 = Reader("reader_2", runtime_parties=dict(guest=guest, host=host)) + reader_2.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_2.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_2.outputs["output_data"] + + predict_pipeline.add_tasks([reader_2, deployed_pipeline]) - predict_pipeline.add_task(deployed_pipeline) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py index 83648cada9..12d398123e 100644 --- a/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_asymmetric.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate
import PSI, HeteroFeatureBinning, Reader from fate_client.pipeline.utils import test_utils @@ -33,11 +32,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="quantile", @@ -45,17 +49,15 @@ def main(config="../config.yaml", namespace=""): train_data=psi_0.outputs["output_data"], local_only=True ) - binning_0.guest.task_setting(bin_col=["x0"], transform_method="bin_idx") + binning_0.guest.task_parameters(bin_col=["x0"], transform_method="bin_idx") binning_1 = HeteroFeatureBinning("binning_1", transform_method="bin_idx", method="quantile", train_data=binning_0.outputs["train_output_data"]) - binning_1.guest.task_setting(category_col=["x0"], transform_method="woe") + binning_1.guest.task_parameters(category_col=["x0"], transform_method="woe") - pipeline.add_task(psi_0) - pipeline.add_task(binning_0) - pipeline.add_task(binning_1) + pipeline.add_tasks([reader_0, psi_0, binning_0, binning_1]) pipeline.compile() # print(pipeline.get_dag()) @@ -65,13 +67,20 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py index 6036603db9..a732b95743 100644 --- a/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_bucket.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning, Reader from fate_client.pipeline.utils import test_utils @@ -33,17 +32,27 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - 
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="bucket", @@ -57,10 +66,7 @@ def main(config="../config.yaml", namespace=""): input_model=binning_0.outputs["output_model"], test_data=psi_1.outputs["output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(binning_0) - pipeline.add_task(binning_1) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, binning_0, binning_1]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -74,13 +80,20 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_2 = Reader("reader_2") + reader_2.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_2.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_2.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_2, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py index 2a1a802109..f63856b09e 100644 --- a/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_multi_host.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning, Reader from fate_client.pipeline.utils import test_utils @@ -33,20 +32,28 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - 
psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_1.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"]) + binning_0 = HeteroFeatureBinning("binning_0", method="bucket", n_bins=10, @@ -59,10 +66,7 @@ def main(config="../config.yaml", namespace=""): input_model=binning_0.outputs["output_model"], test_data=psi_1.outputs["output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(binning_0) - pipeline.add_task(binning_1) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, binning_0, binning_1]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -75,15 +79,21 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_2 = Reader("reader_2") + reader_2.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_2.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - predict_pipeline.add_task(deployed_pipeline) + deployed_pipeline.psi_0.input_data = reader_2.outputs["output_data"] + + predict_pipeline.add_tasks([reader_2, deployed_pipeline]) + predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py b/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py index 2998142140..0a27451475 100644 --- a/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py +++ b/examples/pipeline/hetero_feature_binning/test_feature_binning_quantile.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning 
-from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureBinning, Reader from fate_client.pipeline.utils import test_utils @@ -33,11 +32,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="quantile", @@ -45,18 +49,16 @@ def main(config="../config.yaml", namespace=""): transform_method="bin_idx", train_data=psi_0.outputs["output_data"] ) - binning_0.hosts[0].task_setting(bin_idx=[1]) - binning_0.guest.task_setting(bin_col=["x0"]) + binning_0.hosts[0].task_parameters(bin_idx=[1]) + binning_0.guest.task_parameters(bin_col=["x0"]) binning_1 = HeteroFeatureBinning("binning_1", transform_method="bin_idx", method="quantile", train_data=binning_0.outputs["train_output_data"]) - binning_1.hosts[0].task_setting(category_idx=[1]) - binning_1.guest.task_setting(category_col=["x0"]) + binning_1.hosts[0].task_parameters(category_idx=[1]) + binning_1.guest.task_parameters(category_col=["x0"]) - pipeline.add_task(psi_0) - pipeline.add_task(binning_0) - pipeline.add_task(binning_1) + pipeline.add_tasks([reader_0, psi_0, binning_0, binning_1]) pipeline.compile() # print(pipeline.get_dag()) @@ -66,13 +68,20 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py index 1012a2d737..40728f4075 100644 --- a/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_binning.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Reader from fate_client.pipeline.utils import
test_utils @@ -33,11 +32,16 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="quantile", @@ -51,9 +55,7 @@ def main(config=".../config.yaml", namespace=""): input_models=[binning_0.outputs["output_model"]], iv_param={"metrics": "iv", "filter_type": "threshold", "threshold": 0.1}) - pipeline.add_task(psi_0) - pipeline.add_task(binning_0) - pipeline.add_task(selection_0) + pipeline.add_tasks([reader_0, psi_0, binning_0, selection_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -66,13 +68,21 @@ def main(config=".../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] + + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) - predict_pipeline.add_task(deployed_pipeline) predict_pipeline.compile() predict_pipeline.predict() diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py index 467f09cdf4..c8b56bfefd 100644 --- a/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_manual.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, Reader from fate_client.pipeline.utils import test_utils @@ -33,20 +32,24 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0",
input_data=reader_0.outputs["output_data"]) selection_0 = HeteroFeatureSelection("selection_0", method=["manual"], train_data=psi_0.outputs["output_data"]) - selection_0.guest.task_setting(manual_param={"keep_col": ["x0", "x1"]}) - selection_0.hosts[0].task_setting(manual_param={"filter_out_col": ["x0", "x1"]}) + selection_0.guest.task_parameters(manual_param={"keep_col": ["x0", "x1"]}) + selection_0.hosts[0].task_parameters(manual_param={"filter_out_col": ["x0", "x1"]}) - pipeline.add_task(psi_0) - pipeline.add_task(selection_0) + pipeline.add_tasks([reader_0, psi_0, selection_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -59,13 +62,20 @@ def main(config=".../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() predict_pipeline.predict() diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py index 062a256f0c..468e14423e 100644 --- a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_host.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,13 +32,16 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="quantile", n_bins=10, @@ -58,10 +60,7 @@ def main(config=".../config.yaml", namespace=""): manual_param={"keep_col": ["x0", "x1"]} ) - pipeline.add_task(psi_0) - pipeline.add_task(binning_0) - pipeline.add_task(statistics_0) - 
pipeline.add_task(selection_0) + pipeline.add_tasks([reader_0, psi_0, binning_0, statistics_0, selection_0]) pipeline.compile() pipeline.fit() @@ -72,15 +71,20 @@ def main(config=".../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[[0, 1]].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() predict_pipeline.predict() diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py index 1f70c75d80..7249a7e642 100644 --- a/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_multi_model.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,12 +32,16 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) binning_0 = HeteroFeatureBinning("binning_0", method="quantile", n_bins=10, @@ -57,11 +60,7 @@ def main(config=".../config.yaml", namespace=""): manual_param={"keep_col": ["x0", "x1"]} ) - pipeline.add_task(psi_0) - pipeline.add_task(binning_0) - pipeline.add_task(statistics_0) - pipeline.add_task(selection_0) - + pipeline.add_tasks([reader_0, psi_0, binning_0, statistics_0, selection_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() # print(pipeline.get_dag()) @@ -72,14 +71,20 @@ def main(config=".../config.yaml", namespace=""): pipeline.deploy([psi_0, selection_0]) predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + 
reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - predict_pipeline.add_task(deployed_pipeline) + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) predict_pipeline.compile() predict_pipeline.predict() diff --git a/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py b/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py index 2ad69ea04a..e74ec84739 100644 --- a/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py +++ b/examples/pipeline/hetero_feature_selection/test_feature_selection_statistics.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,12 +32,16 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) statistics_0 = Statistics("statistics_0", input_data=psi_0.outputs["output_data"], metrics=["min", "max", "25%", "mean", "median"]) selection_0 = HeteroFeatureSelection("selection_0", @@ -48,9 +51,7 @@ def main(config=".../config.yaml", namespace=""): statistic_param={"metrics": ["max", "mean", "25%"], "filter_type": "top_k", "threshold": 5}) - pipeline.add_task(psi_0) - pipeline.add_task(statistics_0) - pipeline.add_task(selection_0) + pipeline.add_tasks([reader_0, psi_0, statistics_0, selection_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() @@ -62,14 +63,20 @@ def main(config=".../config.yaml", namespace=""): pipeline.deploy([psi_0, selection_0]) predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] - 
-    predict_pipeline.add_task(deployed_pipeline)
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
     predict_pipeline.compile()
     predict_pipeline.predict()
 
diff --git a/examples/pipeline/hetero_nn/test_nn_binary.py b/examples/pipeline/hetero_nn/test_nn_binary.py
index 8423e96bae..74b402e2b1 100644
--- a/examples/pipeline/hetero_nn/test_nn_binary.py
+++ b/examples/pipeline/hetero_nn/test_nn_binary.py
@@ -15,13 +15,12 @@
 #
 import argparse
 
-from fate_test.utils import parse_summary_result
 from fate_client.pipeline.utils import test_utils
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.components.fate.nn.torch import nn, optim
 from fate_client.pipeline.components.fate.nn.torch.base import Sequential
 from fate_client.pipeline.components.fate.hetero_nn import HeteroNN, get_config_of_default_runner
+from fate_client.pipeline.components.fate.reader import Reader
 from fate_client.pipeline.components.fate.psi import PSI
 from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments
 from fate_client.pipeline.components.fate import Evaluation
@@ -38,11 +37,16 @@ def main(config="../../config.yaml", namespace=""):
 
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     training_args = TrainingArguments(
         num_train_epochs=5,
@@ -73,8 +77,8 @@ def main(config="../../config.yaml", namespace=""):
         train_data=psi_0.outputs['output_data'],
         validate_data=psi_0.outputs['output_data']
     )
-    hetero_nn_0.guest.task_setting(runner_conf=guest_conf)
-    hetero_nn_0.hosts[0].task_setting(runner_conf=host_conf)
+    hetero_nn_0.guest.task_parameters(runner_conf=guest_conf)
+    hetero_nn_0.hosts[0].task_parameters(runner_conf=host_conf)
 
     hetero_nn_1 = HeteroNN(
         'hetero_nn_1',
@@ -84,15 +88,12 @@ def main(config="../../config.yaml", namespace=""):
 
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         metrics=['auc'],
         input_data=[hetero_nn_0.outputs['train_data_output']]
     )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_nn_0)
-    pipeline.add_task(hetero_nn_1)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_nn_0, hetero_nn_1, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/hetero_nn/test_nn_binary_with_fedpass.py b/examples/pipeline/hetero_nn/test_nn_binary_with_fedpass.py
index a090b106cd..adce69f219 100644
--- a/examples/pipeline/hetero_nn/test_nn_binary_with_fedpass.py
+++ b/examples/pipeline/hetero_nn/test_nn_binary_with_fedpass.py
@@ -15,14 +15,13 @@
 #
 import argparse
 
-from fate_test.utils import parse_summary_result
 from fate_client.pipeline.utils import test_utils
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.components.fate.nn.torch import nn, optim
 from fate_client.pipeline.components.fate.nn.torch.base import Sequential
 from fate_client.pipeline.components.fate.hetero_nn import HeteroNN, get_config_of_default_runner
 from fate_client.pipeline.components.fate.psi import PSI
+from fate_client.pipeline.components.fate.reader import Reader
 from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments
 from fate_client.pipeline.components.fate import Evaluation
 from fate_client.pipeline.components.fate.nn.algo_params import FedPassArgument
@@ -39,11 +38,16 @@ def main(config="../../config.yaml", namespace=""):
 
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     training_args = TrainingArguments(
         num_train_epochs=1,
@@ -82,8 +86,8 @@ def main(config="../../config.yaml", namespace=""):
         train_data=psi_0.outputs['output_data']
     )
 
-    hetero_nn_0.guest.task_setting(runner_conf=guest_conf)
-    hetero_nn_0.hosts[0].task_setting(runner_conf=host_conf)
+    hetero_nn_0.guest.task_parameters(runner_conf=guest_conf)
+    hetero_nn_0.hosts[0].task_parameters(runner_conf=host_conf)
 
     hetero_nn_1 = HeteroNN(
         'hetero_nn_1',
@@ -93,15 +97,12 @@ def main(config="../../config.yaml", namespace=""):
 
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         metrics=['auc'],
         input_data=[hetero_nn_1.outputs['predict_data_output'],
                     hetero_nn_0.outputs['train_data_output']]
     )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_nn_0)
-    pipeline.add_task(hetero_nn_1)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_nn_0, hetero_nn_1, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py
index 6b9ec11f2b..9645e7a9e3 100644
--- a/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py
+++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation
+from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -11,28 +10,30 @@ def main(config="../config.yaml", namespace=""):
     parties = config.parties
     guest = parties.guest[0]
     host = parties.host[0]
-    arbiter = parties.arbiter[0]
 
-    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
+    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=2, max_bin=32, max_depth=3, goss=True, top_rate=0.2,
                                      he_param={'kind': 'paillier', 'key_length': 1024},
                                      train_data=psi_0.outputs['output_data'],)
 
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         metrics=['auc'],
         input_data=[hetero_sbt_0.outputs['train_data_output']]
     )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_sbt_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_sbt_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary_cv.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary_cv.py
index dcfee1280b..4964cfc765 100644
--- a/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary_cv.py
+++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_binary_cv.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation
+from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -11,15 +10,19 @@ def main(config="../config.yaml", namespace=""):
     parties = config.parties
     guest = parties.guest[0]
     host = parties.host[0]
-    arbiter = parties.arbiter[0]
 
-    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
+    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3,
                                      cv_param={"n_splits": 3},
@@ -27,8 +30,7 @@ def main(config="../config.yaml", namespace=""):
                                      cv_data=psi_0.outputs['output_data']
                                      )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_sbt_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_sbt_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_multi.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_multi.py
index 42ef32a37b..20c7d5d16d 100644
--- a/examples/pipeline/hetero_secureboost/test_hetero_sbt_multi.py
+++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_multi.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation
+from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -11,28 +10,30 @@ def main(config="../config.yaml", namespace=""):
     parties = config.parties
     guest = parties.guest[0]
     host = parties.host[0]
-    arbiter = parties.arbiter[0]
 
-    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
+    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3,
                                      objective='multi:ce', num_class=4,
                                      he_param={'kind': 'paillier', 'key_length': 1024},
                                      train_data=psi_0.outputs['output_data'],)
 
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         default_eval_setting='multi',
         input_data=[hetero_sbt_0.outputs['train_data_output']]
     )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_sbt_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_sbt_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py b/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py
index ee01e490cb..16d6ffbb40 100644
--- a/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py
+++ b/examples/pipeline/hetero_secureboost/test_hetero_sbt_regression.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation
+from fate_client.pipeline.components.fate import HeteroSecureBoost, PSI, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -11,28 +10,30 @@ def main(config="../config.yaml", namespace=""):
     parties = config.parties
     guest = parties.guest[0]
     host = parties.host[0]
-    arbiter = parties.arbiter[0]
 
-    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
+    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="student_hetero_guest",
-                                                             namespace="experiment"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="student_hetero_host",
-                                                                namespace="experiment"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="student_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="student_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     hetero_sbt_0 = HeteroSecureBoost('sbt_0', num_trees=3, max_bin=32, max_depth=3,
                                      objective='regression:l2',
                                      he_param={'kind': 'paillier', 'key_length': 1024},
                                      train_data=psi_0.outputs['output_data'],)
 
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         metrics=['rmse'],
         input_data=[hetero_sbt_0.outputs['train_data_output']]
     )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(hetero_sbt_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, hetero_sbt_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/homo_lr/test_homo_lr_binary.py b/examples/pipeline/homo_lr/test_homo_lr_binary.py
index fa26353717..1f4b1c02d4 100644
--- a/examples/pipeline/homo_lr/test_homo_lr_binary.py
+++ b/examples/pipeline/homo_lr/test_homo_lr_binary.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HomoLR, Evaluation
+from fate_client.pipeline.components.fate import HomoLR, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -14,23 +13,30 @@ def main(config="../config.yaml", namespace=""):
     arbiter = parties.arbiter[0]
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_homo_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_homo_host"
+    )
+
     homo_lr_0 = HomoLR(
         "homo_lr_0",
         epochs=10,
-        batch_size=16
+        batch_size=16,
+        train_data=reader_0.outputs["output_data"]
     )
 
-    homo_lr_0.guest.task_setting(train_data=DataWarehouseChannel(name="breast_homo_guest", namespace="experiment"))
-    homo_lr_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name="breast_homo_host", namespace="experiment"))
     evaluation_0 = Evaluation(
         'eval_0',
         metrics=['auc'],
         input_data=[homo_lr_0.outputs['train_output_data']]
     )
-
-    pipeline.add_task(homo_lr_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, homo_lr_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
     print (pipeline.get_task_info("homo_lr_0").get_output_data())
 
diff --git a/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py b/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py
index eafa50c65b..b33cb7fc1c 100644
--- a/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py
+++ b/examples/pipeline/homo_lr/test_homo_lr_multi_ovr.py
@@ -1,7 +1,6 @@
 import argparse
-from fate_client.pipeline.components.fate import HomoLR, Evaluation
+from fate_client.pipeline.components.fate import HomoLR, Evaluation, Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -14,25 +13,31 @@ def main(config="../config.yaml", namespace=""):
     arbiter = parties.arbiter[0]
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_homo_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_homo_host"
+    )
 
     homo_lr_0 = HomoLR(
         "homo_lr_0",
         epochs=10,
         batch_size=16,
         ovr=True,
-        label_num=4
+        label_num=4,
+        train_data=reader_0.outputs["output_data"]
     )
 
-    homo_lr_0.guest.task_setting(train_data=DataWarehouseChannel(name="vehicle_scale_homo_guest", namespace="experiment"))
-    homo_lr_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name="vehicle_scale_homo_host", namespace="experiment"))
     evaluation_0 = Evaluation(
         'eval_0',
         default_eval_setting='multi',
         input_data=[homo_lr_0.outputs['train_output_data']]
     )
-
-    pipeline.add_task(homo_lr_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, homo_lr_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
     print (pipeline.get_task_info("homo_lr_0").get_output_data())
 
diff --git a/examples/pipeline/homo_nn/test_nn_binary.py b/examples/pipeline/homo_nn/test_nn_binary.py
index fd3e7bca40..5442de48fd 100644
--- a/examples/pipeline/homo_nn/test_nn_binary.py
+++ b/examples/pipeline/homo_nn/test_nn_binary.py
@@ -17,8 +17,8 @@
 import argparse
 
 from fate_client.pipeline.utils import test_utils
 from fate_client.pipeline.components.fate.evaluation import Evaluation
+from fate_client.pipeline.components.fate.reader import Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.components.fate.nn.torch import nn, optim
 from fate_client.pipeline.components.fate.nn.torch.base import Sequential
 from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner
@@ -40,10 +40,18 @@ def main(config="../../config.yaml", namespace=""):
     out_feat = 16
     lr = 0.01
 
-    guest_train_data = {"name": "breast_homo_guest", "namespace": f"experiment{namespace}"}
-    host_train_data = {"name": "breast_homo_host", "namespace": f"experiment{namespace}"}
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_homo_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_homo_host"
+    )
+
     conf = get_config_of_default_runner(
         algo='fedavg',
         model=Sequential(
@@ -62,31 +70,24 @@ def main(config="../../config.yaml", namespace=""):
 
     homo_nn_0 = HomoNN(
         'nn_0',
-        runner_conf=conf
+        runner_conf=conf,
+        train_data=reader_0.outputs["output_data"]
     )
 
-    homo_nn_0.guest.task_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
-    homo_nn_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))
-
     homo_nn_1 = HomoNN(
         'nn_1',
-        predict_model_input=homo_nn_0.outputs['train_model_output']
+        predict_model_input=homo_nn_0.outputs['train_model_output'],
+        test_data=reader_0.outputs["output_data"]
     )
 
-    homo_nn_1.guest.task_setting(test_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
-    homo_nn_1.hosts[0].task_setting(test_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))
-
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest, host=host),
         metrics=['auc'],
         input_data=[homo_nn_1.outputs['predict_data_output']]
     )
-
-    pipeline.add_task(homo_nn_0)
-    pipeline.add_task(homo_nn_1)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, homo_nn_0, homo_nn_1, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/homo_nn/test_nn_multi.py b/examples/pipeline/homo_nn/test_nn_multi.py
index bf0ca773ee..8529171e03 100644
--- a/examples/pipeline/homo_nn/test_nn_multi.py
+++ b/examples/pipeline/homo_nn/test_nn_multi.py
@@ -16,14 +16,11 @@
 import argparse
 
 from fate_test.utils import parse_summary_result
-from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 from fate_client.pipeline.components.fate.evaluation import Evaluation
+from fate_client.pipeline.components.fate.reader import Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.components.fate.nn.torch import nn, optim
-from fate_client.pipeline.components.fate.nn.torch.base import Sequential
 from fate_client.pipeline.components.fate.nn.loader import ModelLoader
 from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner
 from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments
@@ -44,10 +41,18 @@ def main(config="../../config.yaml", namespace=""):
     lr = 0.01
     class_num=4
 
-    guest_train_data = {"name": "vehicle_scale_homo_guest", "namespace": f"experiment{namespace}"}
-    host_train_data = {"name": "vehicle_scale_homo_host", "namespace": f"experiment{namespace}"}
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_homo_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="vehicle_scale_homo_host"
+    )
+
     conf = get_config_of_default_runner(
         algo='fedavg',
         model=ModelLoader('multi_model', 'Multi', feat=in_feat, class_num=class_num),
@@ -61,22 +66,18 @@ def main(config="../../config.yaml", namespace=""):
 
     homo_nn_0 = HomoNN(
         'nn_0',
-        runner_conf=conf
+        runner_conf=conf,
+        train_data=reader_0.outputs["output_data"]
     )
 
-    homo_nn_0.guest.task_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
-    homo_nn_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))
-
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         default_eval_setting='multi',
         input_data=[homo_nn_0.outputs['train_data_output']]
     )
-
-    pipeline.add_task(homo_nn_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, homo_nn_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/homo_nn/test_nn_regression.py b/examples/pipeline/homo_nn/test_nn_regression.py
index 2a945ece2c..7d4e5e72c5 100644
--- a/examples/pipeline/homo_nn/test_nn_regression.py
+++ b/examples/pipeline/homo_nn/test_nn_regression.py
@@ -16,12 +16,10 @@
 import argparse
 
 from fate_test.utils import parse_summary_result
-from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 from fate_client.pipeline.components.fate.evaluation import Evaluation
+from fate_client.pipeline.components.fate.reader import Reader
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.components.fate.nn.torch import nn, optim
 from fate_client.pipeline.components.fate.nn.torch.base import Sequential
 from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner
@@ -43,10 +41,18 @@ def main(config="../../config.yaml", namespace=""):
     out_feat = 10
     lr = 0.01
 
-    guest_train_data = {"name": "student_homo_guest", "namespace": f"experiment{namespace}"}
-    host_train_data = {"name": "student_homo_host", "namespace": f"experiment{namespace}"}
     pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
 
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="student_homo_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="student_homo_host"
+    )
+
     conf = get_config_of_default_runner(
         algo='fedavg',
         model=Sequential(
@@ -64,22 +70,18 @@ def main(config="../../config.yaml", namespace=""):
 
     homo_nn_0 = HomoNN(
         'nn_0',
-        runner_conf=conf
+        runner_conf=conf,
+        train_data=reader_0.outputs["output_data"]
    )
 
-    homo_nn_0.guest.task_setting(train_data=DataWarehouseChannel(name=guest_train_data["name"], namespace=guest_train_data["namespace"]))
-    homo_nn_0.hosts[0].task_setting(train_data=DataWarehouseChannel(name=host_train_data["name"], namespace=host_train_data["namespace"]))
-
     evaluation_0 = Evaluation(
         'eval_0',
-        runtime_roles=['guest'],
+        runtime_parties=dict(guest=guest),
         metrics=['rmse'],
         input_data=[homo_nn_0.outputs['train_data_output']]
     )
-
-    pipeline.add_task(homo_nn_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, homo_nn_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
 
diff --git a/examples/pipeline/multi_model/test_multi.py b/examples/pipeline/multi_model/test_multi.py
index d65b6b4b67..98b7ba2661 100644
--- a/examples/pipeline/multi_model/test_multi.py
+++ b/examples/pipeline/multi_model/test_multi.py
@@ -16,8 +16,7 @@
 
 from fate_client.pipeline import FateFlowPipeline
 from fate_client.pipeline.components.fate import PSI, HeteroFeatureSelection, HeteroFeatureBinning, \
-    FeatureScale, Union, DataSplit, CoordinatedLR, CoordinatedLinR, Statistics, Sample, Evaluation
-from fate_client.pipeline.interface import DataWarehouseChannel
+    FeatureScale, Union, DataSplit, CoordinatedLR, CoordinatedLinR, Statistics, Sample, Evaluation, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -35,11 +34,16 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     data_split_0 = DataSplit("data_split_0", input_data=psi_0.outputs["output_data"],
                              train_size=0.8, test_size=0.2, random_state=42)
@@ -78,23 +82,12 @@ def main(config="../config.yaml", namespace=""):
     evaluation_0 = Evaluation("evaluation_0",
                               input_data=lr_0.outputs["train_output_data"],
                               default_eval_setting="binary",
-                              runtime_roles=["guest"])
+                              runtime_parties=dict(guest=guest))
     evaluation_1 = Evaluation("evaluation_1",
                               input_data=linr_0.outputs["train_output_data"],
                               default_eval_setting="regression",
-                              runtime_roles=["guest"])
-    pipeline.add_task(psi_0)
-    pipeline.add_task(data_split_0)
-    pipeline.add_task(union_0)
-    pipeline.add_task(sample_0)
-    pipeline.add_task(binning_0)
-    pipeline.add_task(statistics_0)
-    pipeline.add_task(selection_0)
-    pipeline.add_task(scale_0)
-    pipeline.add_task(selection_1)
-    pipeline.add_task(lr_0)
-    pipeline.add_task(linr_0)
-    pipeline.add_task(evaluation_0)
-    pipeline.add_task(evaluation_1)
+                              runtime_parties=dict(guest=guest))
+    pipeline.add_tasks([reader_0, psi_0, data_split_0, union_0, sample_0, binning_0, statistics_0, selection_0,
+                        scale_0, selection_1, lr_0, linr_0, evaluation_0, evaluation_1])
 
     # pipeline.add_task(hetero_feature_binning_0)
     pipeline.compile()
@@ -106,16 +99,19 @@ def main(config="../config.yaml", namespace=""):
     pipeline.deploy([psi_0, selection_0])
 
     predict_pipeline = FateFlowPipeline()
-
+    reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host))
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
     deployed_pipeline = pipeline.get_deployed_pipeline()
-    deployed_pipeline.psi_0.guest.task_setting(
-        input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                        namespace=f"experiment{namespace}"))
-    deployed_pipeline.psi_0.hosts[0].task_setting(
-        input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                        namespace=f"experiment{namespace}"))
-
-    predict_pipeline.add_task(deployed_pipeline)
+    deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
+
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
 
     predict_pipeline.compile()
     predict_pipeline.predict()
 
diff --git a/examples/pipeline/multi_model/test_multi_preprocessing.py b/examples/pipeline/multi_model/test_multi_preprocessing.py
index b1e404c791..6bc46f4479 100644
--- a/examples/pipeline/multi_model/test_multi_preprocessing.py
+++ b/examples/pipeline/multi_model/test_multi_preprocessing.py
@@ -16,8 +16,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import DataSplit, PSI, Sample, FeatureScale
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import DataSplit, PSI, Sample, FeatureScale, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -34,17 +33,28 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-
-    psi_1 = PSI("psi_1")
-    psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+
+    reader_1 = Reader("reader_1")
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
+    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])
 
     data_split_0 = DataSplit("data_split_0",
                              train_size=0.6,
@@ -77,15 +87,8 @@ def main(config="../config.yaml", namespace=""):
                                    feature_range={"x0": [-1, 1]},
                                    scale_col=["x0", "x1", "x3"],
                                    train_data=psi_0.outputs["output_data"])
-    pipeline.add_task(psi_0)
-    pipeline.add_task(psi_1)
-    pipeline.add_task(data_split_0)
-    pipeline.add_task(data_split_1)
-    pipeline.add_task(sample_0)
-    pipeline.add_task(sample_1)
-    pipeline.add_task(feature_scale_0)
-
-    # pipeline.add_task(hetero_feature_binning_0)
+
+    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1, sample_0, sample_1, feature_scale_0])
     pipeline.compile()
     # print(pipeline.get_dag())
     pipeline.fit()
 
diff --git a/examples/pipeline/multi_model/test_multi_w_predict.py b/examples/pipeline/multi_model/test_multi_w_predict.py
index e80d83b0d6..49c1e008a8 100644
--- a/examples/pipeline/multi_model/test_multi_w_predict.py
+++ b/examples/pipeline/multi_model/test_multi_w_predict.py
@@ -17,8 +17,7 @@
 
 from fate_client.pipeline import FateFlowPipeline
 from fate_client.pipeline.components.fate import PSI, CoordinatedLR, Evaluation, \
-    HeteroFeatureBinning, HeteroFeatureSelection
-from fate_client.pipeline.interface import DataWarehouseChannel
+    HeteroFeatureBinning, HeteroFeatureSelection, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -36,11 +35,16 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     binning_0 = HeteroFeatureBinning("binning_0",
                                      method="quantile",
@@ -63,17 +67,11 @@ def main(config="../config.yaml", namespace=""):
                                              "total_iters": 100}})
 
     evaluation_0 = Evaluation("evaluation_0",
-                              runtime_roles=["guest"],
+                              runtime_parties=dict(guest=guest),
                               default_eval_setting="binary",
                               input_data=lr_0.outputs["train_output_data"])
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(binning_0)
-    pipeline.add_task(selection_0)
-    pipeline.add_task(lr_0)
-    pipeline.add_task(evaluation_0)
-
-    # pipeline.add_task(hetero_feature_binning_0)
+    pipeline.add_tasks([reader_0, psi_0, binning_0, selection_0, lr_0, evaluation_0])
     pipeline.compile()
     # print(pipeline.get_dag())
     pipeline.fit()
@@ -81,14 +79,19 @@ def main(config="../config.yaml", namespace=""):
     pipeline.deploy([psi_0, selection_0, lr_0])
 
     predict_pipeline = FateFlowPipeline()
-
+    reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host))
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
     deployed_pipeline = pipeline.get_deployed_pipeline()
-    deployed_pipeline.psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                                               namespace=f"experiment{namespace}"))
-    deployed_pipeline.psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                                  namespace=f"experiment{namespace}"))
+    deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
 
-    predict_pipeline.add_task(deployed_pipeline)
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
     predict_pipeline.compile()
     predict_pipeline.predict()
 
diff --git a/examples/pipeline/sample/test_sample.py b/examples/pipeline/sample/test_sample.py
index cf5b3ba583..d37996c786 100644
--- a/examples/pipeline/sample/test_sample.py
+++ b/examples/pipeline/sample/test_sample.py
@@ -15,8 +15,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import Sample, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import Sample, PSI, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,17 +32,26 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-
-    psi_1 = PSI("psi_1")
-    psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    reader_1 = Reader("reader_1")
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
+    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])
 
     sample_0 = Sample("sample_0",
                       frac={0: 0.5},
@@ -58,10 +66,7 @@ def main(config="../config.yaml", namespace=""):
                       input_data=psi_0.outputs["output_data"]
                       )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(psi_1)
-    pipeline.add_task(sample_0)
-    pipeline.add_task(sample_1)
+    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, sample_0, sample_1])
 
     # pipeline.add_task(hetero_feature_binning_0)
     pipeline.compile()
 
diff --git a/examples/pipeline/sample/test_sample_multi_host.py b/examples/pipeline/sample/test_sample_multi_host.py
index 58e63d80db..d0472c1c44 100644
--- a/examples/pipeline/sample/test_sample_multi_host.py
+++ b/examples/pipeline/sample/test_sample_multi_host.py
@@ -15,8 +15,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import Sample, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import Sample, PSI, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,21 +32,26 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-    psi_0.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-
-    psi_1 = PSI("psi_1")
-    psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-    psi_1.hosts[1].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[[0, 1]].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    reader_1 = Reader("reader_1")
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[[0, 1]].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
+    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])
 
     sample_0 = Sample("sample_0",
                       frac={0: 0.8, 1: 0.5},
@@ -62,10 +66,7 @@ def main(config="../config.yaml", namespace=""):
                       input_data=psi_0.outputs["output_data"]
                       )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(psi_1)
-    pipeline.add_task(sample_0)
-    pipeline.add_task(sample_1)
+    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, sample_0, sample_1])
 
     # pipeline.add_task(hetero_feature_binning_0)
     pipeline.compile()
 
diff --git a/examples/pipeline/sample/test_sample_unilateral.py b/examples/pipeline/sample/test_sample_unilateral.py
index 804deca9e0..81cce74ec0 100644
--- a/examples/pipeline/sample/test_sample_unilateral.py
+++ b/examples/pipeline/sample/test_sample_unilateral.py
@@ -15,8 +15,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import Sample, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import Sample, PSI, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -32,37 +31,44 @@ def main(config="../config.yaml", namespace=""):
         pipeline.conf.set("task_cores", config.task_cores)
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
-    psi_1 = PSI("psi_1")
-    psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    reader_1 = Reader("reader_1")
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
+    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])
 
     sample_0 = Sample("sample_0",
-                      runtime_roles=["guest"],
+                      runtime_parties=dict(guest=guest),
                       frac={0: 0.5},
                       replace=False,
                       hetero_sync=False,
                       input_data=psi_0.outputs["output_data"])
 
     sample_1 = Sample("sample_1",
-                      runtime_roles=["host"],
+                      runtime_parties=dict(host=host),
                       n=1000,
                       replace=True,
                       hetero_sync=False,
                       input_data=psi_0.outputs["output_data"]
                       )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(psi_1)
-    pipeline.add_task(sample_0)
-    pipeline.add_task(sample_1)
+    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, sample_0, sample_1])
 
     # pipeline.add_task(hetero_feature_binning_0)
     pipeline.compile()
 
diff --git a/examples/pipeline/sshe_linr/test_linr.py b/examples/pipeline/sshe_linr/test_linr.py
index f990f17cce..4991f208e7 100644
--- a/examples/pipeline/sshe_linr/test_linr.py
+++ b/examples/pipeline/sshe_linr/test_linr.py
@@ -16,8 +16,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,11 +32,10 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     linr_0 = SSHELinR("linr_0",
                       epochs=10,
                       batch_size=100,
@@ -48,13 +46,11 @@ def main(config="../config.yaml", namespace=""):
                       reveal_loss_freq=1,
                       learning_rate=0.1)
 
     evaluation_0 = Evaluation("evaluation_0",
-                              runtime_roles=["guest"],
+                              runtime_parties=dict(guest=guest),
                               default_eval_setting="regression",
                               input_data=linr_0.outputs["train_output_data"])
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(linr_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, linr_0, evaluation_0])
     pipeline.compile()
     # print(pipeline.get_dag())
     pipeline.fit()
@@ -63,15 +59,12 @@ def main(config="../config.yaml", namespace=""):
 
     predict_pipeline = FateFlowPipeline()
+    reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host))
+    reader_1.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_1.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
     deployed_pipeline = pipeline.get_deployed_pipeline()
-    deployed_pipeline.psi_0.guest.task_setting(
-        input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                        namespace=f"experiment{namespace}"))
-    deployed_pipeline.psi_0.hosts[0].task_setting(
-        input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                        namespace=f"experiment{namespace}"))
-
-    predict_pipeline.add_task(deployed_pipeline)
+    deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
     predict_pipeline.compile()
     # print("\n\n\n")
     # print(predict_pipeline.compile().get_dag())
 
diff --git a/examples/pipeline/sshe_linr/test_linr_cv.py b/examples/pipeline/sshe_linr/test_linr_cv.py
index 7b4f31cd30..b0ef69afb5 100644
--- a/examples/pipeline/sshe_linr/test_linr_cv.py
+++ b/examples/pipeline/sshe_linr/test_linr_cv.py
@@ -16,8 +16,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import SSHELinR, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import SSHELinR, PSI, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,11 +32,10 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     linr_0 = SSHELinR("linr_0",
                       epochs=10,
                       batch_size=None,
@@ -50,8 +48,7 @@ def main(config="../config.yaml", namespace=""):
                       reveal_loss_freq=1,
                       )
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(linr_0)
+    pipeline.add_tasks([reader_0, psi_0, linr_0])
     pipeline.compile()
     # print(pipeline.get_dag())
     pipeline.fit()
 
diff --git a/examples/pipeline/sshe_linr/test_linr_validate.py b/examples/pipeline/sshe_linr/test_linr_validate.py
index bb76fe1b56..2af85bb579 100644
--- a/examples/pipeline/sshe_linr/test_linr_validate.py
+++ b/examples/pipeline/sshe_linr/test_linr_validate.py
@@ -16,8 +16,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation, DataSplit
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import SSHELinR, PSI, Evaluation, DataSplit, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,11 +32,10 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     data_split_0 = DataSplit("data_split_0",
                              train_size=0.8,
                              validate_size=0.2,
@@ -53,14 +51,11 @@ def main(config="../config.yaml", namespace=""):
                       reveal_loss_freq=1,
                       learning_rate=0.1)
 
     evaluation_0 = Evaluation("evaluation_0",
-                              runtime_roles=["guest"],
+                              runtime_parties=dict(guest=guest),
                               default_eval_setting="regression",
                               input_data=linr_0.outputs["train_output_data"])
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(data_split_0)
-    pipeline.add_task(linr_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, data_split_0, linr_0, evaluation_0])
     pipeline.compile()
     # print(pipeline.get_dag())
     pipeline.fit()
@@ -69,15 +64,12 @@ def main(config="../config.yaml", namespace=""):
 
     predict_pipeline = FateFlowPipeline()
+    reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host))
+    reader_1.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_1.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
     deployed_pipeline = pipeline.get_deployed_pipeline()
-    deployed_pipeline.psi_0.guest.task_setting(
-        input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                        namespace=f"experiment{namespace}"))
-    deployed_pipeline.psi_0.hosts[0].task_setting(
-        input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                        namespace=f"experiment{namespace}"))
-
-    predict_pipeline.add_task(deployed_pipeline)
+    deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
     predict_pipeline.compile()
     # print("\n\n\n")
     # print(predict_pipeline.compile().get_dag())
 
diff --git a/examples/pipeline/sshe_linr/test_linr_warm_start.py b/examples/pipeline/sshe_linr/test_linr_warm_start.py
index d62f1546ed..e43936779b 100644
--- a/examples/pipeline/sshe_linr/test_linr_warm_start.py
+++ b/examples/pipeline/sshe_linr/test_linr_warm_start.py
@@ -16,9 +16,8 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import Evaluation
+from fate_client.pipeline.components.fate import Evaluation, Reader
 from fate_client.pipeline.components.fate import SSHELinR, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -34,18 +33,17 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="motor_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="motor_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
+    reader_0.guest.task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_guest")
+    reader_0.hosts[0].task_parameters(namespace=f"experiment{namespace}", name="motor_hetero_host")
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     linr_0 = SSHELinR("linr_0",
                       epochs=4,
                       batch_size=None,
                       init_param={"fit_intercept": True, "method": "zeros"},
                       train_data=psi_0.outputs["output_data"],
                       learning_rate=0.05,
-                      reveal_every_epoch=True,
+                      reveal_every_epoch=False,
                       early_stop="diff",
                       reveal_loss_freq=1,
                       )
@@ -70,15 +68,11 @@ def main(config="../config.yaml", namespace=""):
                       )
 
     evaluation_0 = Evaluation("evaluation_0",
-                              runtime_roles=["guest"],
+                              runtime_parties=dict(guest=guest),
                               default_eval_setting="regression",
                               input_data=[linr_1.outputs["train_output_data"],
                                           linr_2.outputs["train_output_data"]])
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(linr_0)
-    pipeline.add_task(linr_1)
-    pipeline.add_task(linr_2)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, linr_0, linr_1, linr_2, evaluation_0])
 
     pipeline.compile()
     # print(pipeline.get_dag())
 
diff --git a/examples/pipeline/sshe_lr/test_lr.py b/examples/pipeline/sshe_lr/test_lr.py
index ed375881fc..ccfeea78e7 100644
--- a/examples/pipeline/sshe_lr/test_lr.py
+++ b/examples/pipeline/sshe_lr/test_lr.py
@@ -16,9 +16,8 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import Evaluation
+from fate_client.pipeline.components.fate import Evaluation, Reader
 from fate_client.pipeline.components.fate import SSHELR, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
 from fate_client.pipeline.utils import test_utils
 
 
@@ -35,11 +34,16 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
 
     lr_0 = SSHELR("lr_0",
                   learning_rate=0.05,
                   epochs=10,
@@ -51,13 +55,11 @@ def main(config="../config.yaml", namespace=""):
                   reveal_loss_freq=1,
                   )
 
     evaluation_0 = Evaluation("evaluation_0",
-                              runtime_roles=["guest"],
+                              runtime_parties=dict(guest=guest),
                               default_eval_setting="binary",
                               input_data=lr_0.outputs["train_output_data"])
 
-    pipeline.add_task(psi_0)
-    pipeline.add_task(lr_0)
-    pipeline.add_task(evaluation_0)
+    pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0])
 
     pipeline.compile()
     pipeline.fit()
@@ -66,15 +68,20 @@ def main(config="../config.yaml", namespace=""):
 
     predict_pipeline = FateFlowPipeline()
 
+    reader_1 = Reader("reader_1")
+    reader_1.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_1.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
+
     deployed_pipeline = pipeline.get_deployed_pipeline()
-    deployed_pipeline.psi_0.guest.task_setting(
-        input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                        namespace=f"experiment{namespace}"))
-    deployed_pipeline.psi_0.hosts[0].task_setting(
-        input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                        namespace=f"experiment{namespace}"))
-
-    predict_pipeline.add_task(deployed_pipeline)
+    deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
+
+    predict_pipeline.add_tasks([reader_1, deployed_pipeline])
     predict_pipeline.compile()
     # print("\n\n\n")
     # print(predict_pipeline.compile().get_dag())
 
diff --git a/examples/pipeline/sshe_lr/test_lr_cv.py b/examples/pipeline/sshe_lr/test_lr_cv.py
index fabbd10b19..f10e6c870a 100644
--- a/examples/pipeline/sshe_lr/test_lr_cv.py
+++ b/examples/pipeline/sshe_lr/test_lr_cv.py
@@ -16,8 +16,7 @@
 import argparse
 
 from fate_client.pipeline import FateFlowPipeline
-from fate_client.pipeline.components.fate import SSHELR, PSI
-from fate_client.pipeline.interface import DataWarehouseChannel
+from fate_client.pipeline.components.fate import SSHELR, PSI, Reader
 from fate_client.pipeline.utils import test_utils
 
 
@@ -33,11 +32,16 @@ def main(config="../config.yaml", namespace=""):
     if config.timeout:
         pipeline.conf.set("timeout", config.timeout)
 
-    psi_0 = PSI("psi_0")
-    psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest",
-                                                             namespace=f"experiment{namespace}"))
-    psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host",
-                                                                namespace=f"experiment{namespace}"))
+    reader_0 = Reader("reader_0")
+    reader_0.guest.task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_guest"
+    )
+    reader_0.hosts[0].task_parameters(
+        namespace=f"experiment{namespace}",
+        name="breast_hetero_host"
+    )
input_data=reader_0.outputs["output_data"]) lr_0 = SSHELR("lr_0", learning_rate=0.15, epochs=2, @@ -50,8 +54,7 @@ def main(config="../config.yaml", namespace=""): reveal_loss_freq=1, ) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) + pipeline.add_tasks([reader_0, psi_0, lr_0]) pipeline.compile() # print(pipeline.get_dag()) pipeline.fit() diff --git a/examples/pipeline/sshe_lr/test_lr_multi_class.py b/examples/pipeline/sshe_lr/test_lr_multi_class.py index c739a86631..607220ab77 100644 --- a/examples/pipeline/sshe_lr/test_lr_multi_class.py +++ b/examples/pipeline/sshe_lr/test_lr_multi_class.py @@ -17,8 +17,7 @@ from fate_client.pipeline import FateFlowPipeline from fate_client.pipeline.components.fate import Evaluation -from fate_client.pipeline.components.fate import SSHELR, PSI -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import SSHELR, PSI, Reader from fate_client.pipeline.utils import test_utils @@ -35,11 +34,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = SSHELR("lr_0", learning_rate=0.15, epochs=10, @@ -51,14 +55,12 @@ def main(config="../config.yaml", namespace=""): train_data=psi_0.outputs["output_data"]) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="multi", predict_column_name='predict_result', input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) @@ -68,15 +70,21 @@ def main(config="../config.yaml", namespace=""): predict_pipeline = FateFlowPipeline() + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="vehicle_scale_hetero_host" + ) + deployed_pipeline = pipeline.get_deployed_pipeline() - deployed_pipeline.psi_0.guest.task_setting( - input_data=DataWarehouseChannel(name="vehicle_scale_hetero_guest", - namespace=f"experiment{namespace}")) - deployed_pipeline.psi_0.hosts[0].task_setting( - input_data=DataWarehouseChannel(name="vehicle_scale_hetero_host", - namespace=f"experiment{namespace}")) - - predict_pipeline.add_task(deployed_pipeline) + deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"] + + predict_pipeline.add_tasks([reader_1, deployed_pipeline]) + predict_pipeline.compile() # print("\n\n\n") # print(predict_pipeline.compile().get_dag()) diff --git a/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py index da3d2ba7a6..0d349bf978 100644 --- a/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py +++ 
b/examples/pipeline/sshe_lr/test_lr_predict_w_torch.py @@ -17,9 +17,8 @@ import torch from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import Evaluation, Reader from fate_client.pipeline.components.fate import SSHELR, PSI -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -46,11 +45,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = SSHELR("lr_0", epochs=10, batch_size=300, @@ -60,13 +64,11 @@ def main(config="../config.yaml", namespace=""): ) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, evaluation_0]) pipeline.compile() pipeline.fit() diff --git a/examples/pipeline/sshe_lr/test_lr_validate.py b/examples/pipeline/sshe_lr/test_lr_validate.py index 29e1c3ef9e..5827766d2e 100644 --- a/examples/pipeline/sshe_lr/test_lr_validate.py +++ b/examples/pipeline/sshe_lr/test_lr_validate.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import Evaluation, Reader from fate_client.pipeline.components.fate import SSHELR, PSI, DataSplit -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -35,11 +34,16 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.8, validate_size=0.2, @@ -56,14 +60,11 @@ def main(config="../config.yaml", namespace=""): ) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=lr_0.outputs["train_output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(data_split_0) - pipeline.add_task(lr_0) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, data_split_0, lr_0, evaluation_0]) 
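# All of the example migrations in this series follow the same shape: the
# DataWarehouseChannel input on PSI is replaced by a dedicated Reader task,
# and the per-task add_task chain collapses into a single add_tasks call.
# A minimal sketch of the resulting training pipeline, using only names
# visible in these hunks (set_parties is assumed from the unshown parts of
# the examples; table name/namespace values are illustrative):
#
#     from fate_client.pipeline import FateFlowPipeline
#     from fate_client.pipeline.components.fate import SSHELR, PSI, Reader
#
#     pipeline = FateFlowPipeline().set_parties(guest="9999", host="10000")
#     reader_0 = Reader("reader_0")
#     reader_0.guest.task_parameters(namespace="experiment", name="breast_hetero_guest")
#     reader_0.hosts[0].task_parameters(namespace="experiment", name="breast_hetero_host")
#     psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
#     lr_0 = SSHELR("lr_0", epochs=10, batch_size=None, learning_rate=0.05,
#                   init_param={"fit_intercept": True, "method": "zeros"},
#                   train_data=psi_0.outputs["output_data"])
#     pipeline.add_tasks([reader_0, psi_0, lr_0])
#     pipeline.compile()
#     pipeline.fit()
#
# For prediction, the deployed PSI no longer carries its own channel; it is
# rewired to a fresh Reader by attribute assignment, exactly as the predict
# blocks in these hunks do:
#
#     reader_1 = Reader("reader_1")  # task_parameters set as for reader_0
#     deployed_pipeline = pipeline.get_deployed_pipeline()
#     deployed_pipeline.psi_0.input_data = reader_1.outputs["output_data"]
#     predict_pipeline = FateFlowPipeline()
#     predict_pipeline.add_tasks([reader_1, deployed_pipeline])
#     predict_pipeline.compile()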
pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/sshe_lr/test_lr_warm_start.py b/examples/pipeline/sshe_lr/test_lr_warm_start.py index c17e75722e..4bf3420605 100644 --- a/examples/pipeline/sshe_lr/test_lr_warm_start.py +++ b/examples/pipeline/sshe_lr/test_lr_warm_start.py @@ -16,9 +16,8 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import Evaluation +from fate_client.pipeline.components.fate import Evaluation, Reader from fate_client.pipeline.components.fate import SSHELR, PSI -from fate_client.pipeline.interface import DataWarehouseChannel from fate_client.pipeline.utils import test_utils @@ -34,18 +33,23 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) lr_0 = SSHELR("lr_0", epochs=4, batch_size=None, learning_rate=0.05, init_param={"fit_intercept": True, "method": "zeros"}, train_data=psi_0.outputs["output_data"], - reveal_every_epoch=True, + reveal_every_epoch=False, early_stop="diff", reveal_loss_freq=1, ) @@ -70,15 +74,11 @@ def main(config="../config.yaml", namespace=""): ) evaluation_0 = Evaluation("evaluation_0", - runtime_roles=["guest"], + runtime_parties=dict(guest=guest), default_eval_setting="binary", input_data=[lr_1.outputs["train_output_data"], lr_2.outputs["train_output_data"]]) - pipeline.add_task(psi_0) - pipeline.add_task(lr_0) - pipeline.add_task(lr_1) - pipeline.add_task(lr_2) - pipeline.add_task(evaluation_0) + pipeline.add_tasks([reader_0, psi_0, lr_0, lr_1, lr_2, evaluation_0]) pipeline.compile() # print(pipeline.get_dag()) diff --git a/examples/pipeline/statistics/test_statistics.py b/examples/pipeline/statistics/test_statistics.py index 6d014f71ff..5a358a45f4 100644 --- a/examples/pipeline/statistics/test_statistics.py +++ b/examples/pipeline/statistics/test_statistics.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,18 +32,22 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) statistics_0 = 
Statistics("statistics_0", input_data=psi_0.outputs["output_data"], metrics=["mean", "std", "0%", "25%", "median", "75%", "100%", "missing_ratio"]) - pipeline.add_task(psi_0) - pipeline.add_task(statistics_0) + pipeline.add_tasks([reader_0, psi_0, statistics_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() diff --git a/examples/pipeline/statistics/test_statistics_default.py b/examples/pipeline/statistics/test_statistics_default.py index f756506e0d..85ec9fc8bc 100644 --- a/examples/pipeline/statistics/test_statistics_default.py +++ b/examples/pipeline/statistics/test_statistics_default.py @@ -15,8 +15,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import PSI, Statistics -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import PSI, Statistics, Reader from fate_client.pipeline.utils import test_utils @@ -33,18 +32,22 @@ def main(config=".../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) statistics_0 = Statistics("statistics_0", skip_col=["x0", "x3"], input_data=psi_0.outputs["output_data"]) - pipeline.add_task(psi_0) - pipeline.add_task(statistics_0) + pipeline.add_tasks([reader_0, psi_0, statistics_0]) pipeline.compile() pipeline.fit() diff --git a/examples/pipeline/union/test_union.py b/examples/pipeline/union/test_union.py index 3c88275cf5..ff6c6efbd9 100644 --- a/examples/pipeline/union/test_union.py +++ b/examples/pipeline/union/test_union.py @@ -16,8 +16,7 @@ import argparse from fate_client.pipeline import FateFlowPipeline -from fate_client.pipeline.components.fate import DataSplit, PSI, Union -from fate_client.pipeline.interface import DataWarehouseChannel +from fate_client.pipeline.components.fate import DataSplit, PSI, Union, Reader from fate_client.pipeline.utils import test_utils @@ -34,17 +33,26 @@ def main(config="../config.yaml", namespace=""): if config.timeout: pipeline.conf.set("timeout", config.timeout) - psi_0 = PSI("psi_0") - psi_0.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_0.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) - - psi_1 = PSI("psi_1") - psi_1.guest.task_setting(input_data=DataWarehouseChannel(name="breast_hetero_guest", - namespace=f"experiment{namespace}")) - psi_1.hosts[0].task_setting(input_data=DataWarehouseChannel(name="breast_hetero_host", - namespace=f"experiment{namespace}")) + reader_0 = Reader("reader_0") + reader_0.guest.task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_guest" + ) + reader_0.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + reader_1 = Reader("reader_1") + reader_1.guest.task_parameters( + namespace=f"experiment{namespace}", + 
name="breast_hetero_guest" + ) + reader_1.hosts[0].task_parameters( + namespace=f"experiment{namespace}", + name="breast_hetero_host" + ) + psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"]) + psi_1 = PSI("psi_1", input_data=reader_0.outputs["output_data"]) data_split_0 = DataSplit("data_split_0", train_size=0.6, @@ -59,11 +67,7 @@ def main(config="../config.yaml", namespace=""): union_0 = Union("union_0", input_data_list=[data_split_0.outputs["train_output_data"], data_split_0.outputs["test_output_data"]]) - pipeline.add_task(psi_0) - pipeline.add_task(psi_1) - pipeline.add_task(data_split_0) - pipeline.add_task(data_split_1) - pipeline.add_task(union_0) + pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1, data_split_0, data_split_1, union_0]) # pipeline.add_task(hetero_feature_binning_0) pipeline.compile() From 7857043151933d91ae9ee2213bdd53f866c6cd60 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Wed, 13 Dec 2023 11:42:30 +0800 Subject: [PATCH 36/42] reader cpn update Signed-off-by: mgqa34 --- python/fate/components/components/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/fate/components/components/reader.py b/python/fate/components/components/reader.py index d686f8f198..c7c8059dea 100644 --- a/python/fate/components/components/reader.py +++ b/python/fate/components/components/reader.py @@ -21,6 +21,6 @@ def reader( role: Role, name: cpn.parameter(type=str, default=None, optional=False), namespace: cpn.parameter(type=str, default=None, optional=False), - data_output: cpn.data_unresolved_output(), + output_data: cpn.data_unresolved_output(), ): - data_output.write_metadata({}, name=name, namespace=namespace) + output_data.write_metadata({}, name=name, namespace=namespace) From 937ba76135d36a34169c71c8d6456ac3403e7d4c Mon Sep 17 00:00:00 2001 From: sagewe Date: Wed, 13 Dec 2023 13:18:14 +0800 Subject: [PATCH 37/42] fix protocol config and expose cipher options for sshe layers Signed-off-by: sagewe --- python/fate/arch/context/_cipher.py | 2 +- .../arch/protocol/mpc/nn/sshe/linr_layer.py | 7 ++-- .../arch/protocol/mpc/nn/sshe/lr_layer.py | 7 ++-- .../arch/protocol/mpc/nn/sshe/nn_layer.py | 3 +- python/fate/arch/protocol/phe/ou.py | 35 ++----------------- python/fate/arch/protocol/phe/paillier.py | 35 ++----------------- python/fate/test/test_ou.py | 34 ++++++++++++++++++ python/fate/test/test_paillier.py | 33 +++++++++++++++++ 8 files changed, 84 insertions(+), 72 deletions(-) create mode 100644 python/fate/test/test_ou.py create mode 100644 python/fate/test/test_paillier.py diff --git a/python/fate/arch/context/_cipher.py b/python/fate/arch/context/_cipher.py index cbde240952..781fc6888e 100644 --- a/python/fate/arch/context/_cipher.py +++ b/python/fate/arch/context/_cipher.py @@ -106,7 +106,7 @@ def setup(self, options: typing.Optional[dict] = None): key_size = self.key_length else: kind = options.get("kind", self.kind) - key_size = options.get("key_length", 1024) + key_size = options.get("key_length", self.key_length) if kind == "paillier": if not cfg.safety.phe.paillier.allow: diff --git a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py index 35651e2ed3..44df85c605 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/linr_layer.py @@ -21,6 +21,7 @@ def __init__( wb_init_fn: typing.Callable[[typing.Tuple], torch.Tensor], precision_bits=None, sync_shape=True, + cipher_options=None, ): self.ctx = ctx self.rank_a = rank_a 
@@ -49,7 +50,7 @@ def __init__( self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a) self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b) - self.phe_cipher = ctx.cipher.phe.setup() + self.phe_cipher = ctx.cipher.phe.setup(options=cipher_options) self.precision_bits = precision_bits @auto_trace(annotation="[z|rank_b] = [xa|rank_a] * + [xb|rank_b] * ") @@ -145,12 +146,12 @@ def __call__(self, dz): class SSHELinearRegressionLossLayer: - def __init__(self, ctx: Context, rank_a, rank_b): + def __init__(self, ctx: Context, rank_a, rank_b, cipher_options=None): self.ctx = ctx self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], f"{ctx.namespace.federation_tag}.sshe_loss_layer") self.rank_a = rank_a self.rank_b = rank_b - self.phe_cipher = ctx.cipher.phe.setup() + self.phe_cipher = ctx.cipher.phe.setup(options=cipher_options) def forward(self, z, y): dz = z.clone() diff --git a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py index 26fdff3af5..cf7e994c36 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/lr_layer.py @@ -21,6 +21,7 @@ def __init__( wb_init_fn: typing.Callable[[typing.Tuple], torch.Tensor], precision_bits=None, sync_shape=True, + cipher_options=None, ): self.ctx = ctx self.rank_a = rank_a @@ -49,7 +50,7 @@ def __init__( self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a) self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b) - self.phe_cipher = ctx.cipher.phe.setup() + self.phe_cipher = ctx.cipher.phe.setup(options=cipher_options) self.precision_bits = precision_bits @auto_trace(annotation="[z|rank_b] = 0.25 * ([xa|rank_a] * + [xb|rank_b] * ) + 0.5") @@ -147,12 +148,12 @@ def __call__(self, dz): class SSHELogisticRegressionLossLayer: - def __init__(self, ctx: Context, rank_a, rank_b): + def __init__(self, ctx: Context, rank_a, rank_b, cipher_options=None): self.ctx = ctx self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], f"{ctx.namespace.federation_tag}.sshe_loss_layer") self.rank_a = rank_a self.rank_b = rank_b - self.phe_cipher = ctx.cipher.phe.setup() + self.phe_cipher = ctx.cipher.phe.setup(options=cipher_options) @auto_trace(annotation=" = - y") def forward(self, z, y): diff --git a/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py b/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py index 94e1e3a7e8..3b264a50d4 100644 --- a/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py +++ b/python/fate/arch/protocol/mpc/nn/sshe/nn_layer.py @@ -80,12 +80,13 @@ def __init__( wa_init_fn, wb_init_fn, precision_bits=None, + cipher_options=None, ): self.ctx = ctx self.group = ctx.mpc.communicator.new_group([rank_a, rank_b], "sshe_nn_aggregator_layer") self.wa = ctx.mpc.init_tensor(shape=(in_features_a, out_features), init_func=wa_init_fn, src=rank_a) self.wb = ctx.mpc.init_tensor(shape=(in_features_b, out_features), init_func=wb_init_fn, src=rank_b) - self.phe_cipher = ctx.cipher.phe.setup() + self.phe_cipher = ctx.cipher.phe.setup(options=cipher_options) self.rank_a = rank_a self.rank_b = rank_b self.precision_bits = precision_bits diff --git a/python/fate/arch/protocol/phe/ou.py b/python/fate/arch/protocol/phe/ou.py index 5223c8985a..e2a7412084 100644 --- a/python/fate/arch/protocol/phe/ou.py +++ b/python/fate/arch/protocol/phe/ou.py @@ -1,11 +1,11 @@ from typing import List, 
Optional, Tuple import torch -from fate_utils.ou import PK as _PK -from fate_utils.ou import SK as _SK +from fate_utils.ou import CiphertextVector, PlaintextVector from fate_utils.ou import Coder as _Coder from fate_utils.ou import Evaluator as _Evaluator -from fate_utils.ou import CiphertextVector, PlaintextVector +from fate_utils.ou import PK as _PK +from fate_utils.ou import SK as _SK from fate_utils.ou import keygen as _keygen from .type import TensorEvaluator @@ -384,32 +384,3 @@ def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: return a.pack_squeeze(pack_num, shift_bit, pk.pk) -def test_pack_float(): - offset_bit = 32 - precision = 16 - coder = Coder(_Coder()) - vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - packed = coder.pack_floats(vec, offset_bit, 2, precision) - unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) - assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) - - -def test_pack_squeeze(): - offset_bit = 32 - precision = 16 - pack_num = 2 - pack_packed_num = 2 - vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) - sk, pk, coder = keygen(1024) - a = coder.pack_floats(vec1, offset_bit, pack_num, precision) - ea = pk.encrypt_encoded(a, obfuscate=False) - b = coder.pack_floats(vec2, offset_bit, pack_num, precision) - eb = pk.encrypt_encoded(b, obfuscate=False) - ec = evaluator.add(ea, eb, pk) - - # pack packed encrypted - ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) - c_pack = sk.decrypt_to_encoded(ec_pack) - c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) - assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/python/fate/arch/protocol/phe/paillier.py b/python/fate/arch/protocol/phe/paillier.py index 18c29453b5..0a8494df95 100644 --- a/python/fate/arch/protocol/phe/paillier.py +++ b/python/fate/arch/protocol/phe/paillier.py @@ -1,11 +1,11 @@ from typing import List, Optional, Tuple import torch -from fate_utils.paillier import PK as _PK -from fate_utils.paillier import SK as _SK +from fate_utils.paillier import CiphertextVector, PlaintextVector from fate_utils.paillier import Coder as _Coder from fate_utils.paillier import Evaluator as _Evaluator -from fate_utils.paillier import CiphertextVector, PlaintextVector +from fate_utils.paillier import PK as _PK +from fate_utils.paillier import SK as _SK from fate_utils.paillier import keygen as _keygen from .type import TensorEvaluator @@ -384,32 +384,3 @@ def pack_squeeze(a: EV, pack_num: int, shift_bit: int, pk: PK) -> EV: return a.pack_squeeze(pack_num, shift_bit, pk.pk) -def test_pack_float(): - offset_bit = 32 - precision = 16 - coder = Coder(_Coder()) - vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - packed = coder.pack_floats(vec, offset_bit, 2, precision) - unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) - assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) - - -def test_pack_squeeze(): - offset_bit = 32 - precision = 16 - pack_num = 2 - pack_packed_num = 2 - vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) - sk, pk, coder = keygen(1024) - a = coder.pack_floats(vec1, offset_bit, pack_num, precision) - ea = pk.encrypt_encoded(a, obfuscate=False) - b = coder.pack_floats(vec2, offset_bit, pack_num, precision) - eb = pk.encrypt_encoded(b, obfuscate=False) - ec = evaluator.add(ea, eb, pk) - - # pack packed encrypted - ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 
2, pk) - c_pack = sk.decrypt_to_encoded(ec_pack) - c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) - assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/python/fate/test/test_ou.py b/python/fate/test/test_ou.py new file mode 100644 index 0000000000..0fda66945f --- /dev/null +++ b/python/fate/test/test_ou.py @@ -0,0 +1,34 @@ +from fate_utils.ou import Coder as _Coder + +from fate.arch.protocol.phe.ou import * + + +def test_pack_float(): + offset_bit = 32 + precision = 16 + coder = Coder(_Coder()) + vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + packed = coder.pack_floats(vec, offset_bit, 2, precision) + unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) + assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) + + +def test_pack_squeeze(): + offset_bit = 32 + precision = 16 + pack_num = 2 + pack_packed_num = 2 + vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) + sk, pk, coder = keygen(1024) + a = coder.pack_floats(vec1, offset_bit, pack_num, precision) + ea = pk.encrypt_encoded(a, obfuscate=False) + b = coder.pack_floats(vec2, offset_bit, pack_num, precision) + eb = pk.encrypt_encoded(b, obfuscate=False) + ec = evaluator.add(ea, eb, pk) + + # pack packed encrypted + ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) + c_pack = sk.decrypt_to_encoded(ec_pack) + c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) + assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) diff --git a/python/fate/test/test_paillier.py b/python/fate/test/test_paillier.py new file mode 100644 index 0000000000..52caa21ed3 --- /dev/null +++ b/python/fate/test/test_paillier.py @@ -0,0 +1,33 @@ +from fate_utils.paillier import Coder as _Coder +from fate.arch.protocol.phe.paillier import * + + +def test_pack_float(): + offset_bit = 32 + precision = 16 + coder = Coder(_Coder()) + vec = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + packed = coder.pack_floats(vec, offset_bit, 2, precision) + unpacked = coder.unpack_floats(packed, offset_bit, 2, precision, 5) + assert torch.allclose(vec, unpacked, rtol=1e-3, atol=1e-3) + + +def test_pack_squeeze(): + offset_bit = 32 + precision = 16 + pack_num = 2 + pack_packed_num = 2 + vec1 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + vec2 = torch.tensor([0.6, 0.7, 0.8, 0.9, 1.0]) + sk, pk, coder = keygen(1024) + a = coder.pack_floats(vec1, offset_bit, pack_num, precision) + ea = pk.encrypt_encoded(a, obfuscate=False) + b = coder.pack_floats(vec2, offset_bit, pack_num, precision) + eb = pk.encrypt_encoded(b, obfuscate=False) + ec = evaluator.add(ea, eb, pk) + + # pack packed encrypted + ec_pack = evaluator.pack_squeeze(ec, pack_packed_num, offset_bit * 2, pk) + c_pack = sk.decrypt_to_encoded(ec_pack) + c = coder.unpack_floats(c_pack, offset_bit, pack_num * pack_packed_num, precision, 5) + assert torch.allclose(vec1 + vec2, c, rtol=1e-3, atol=1e-3) From 6cf438b96af90642ce304963d667c6f4b63c6c9a Mon Sep 17 00:00:00 2001 From: sagewe Date: Wed, 13 Dec 2023 13:27:41 +0800 Subject: [PATCH 38/42] refactor config level Signed-off-by: sagewe --- configs/default.yaml | 92 +++++++++---------- python/fate/arch/protocol/mpc/__init__.py | 4 +- .../protocol/mpc/primitives/arithmetic.py | 6 +- .../arch/protocol/mpc/primitives/beaver.py | 2 +- 4 files changed, 50 insertions(+), 54 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index f0e303ba99..c3a20f4521 100644 --- a/configs/default.yaml +++ 
b/configs/default.yaml @@ -1,51 +1,3 @@ -communicator: - verbose: False -debug: - debug_mode: False - validation_mode: False -encoder: - precision_bits: 16 -functions: - max_method: "log_reduction" - - # exponential function - exp_iterations: 8 - - # reciprocal configuration - reciprocal_method: "NR" - reciprocal_nr_iters: 10 - reciprocal_log_iters: 1 - reciprocal_all_pos: False - reciprocal_initial: null - - # sqrt configuration - sqrt_nr_iters: 3 - sqrt_nr_initial: null - - # sigmoid / tanh configuration - sigmoid_tanh_method: "reciprocal" - sigmoid_tanh_terms: 32 - - # log configuration - log_iterations: 2 - log_exp_iterations: 8 - log_order: 8 - - # trigonometry configuration - trig_iterations: 10 - - # error function configuration: - erf_iterations: 8 -mpc: - active_security: False - provider: "TFP" - protocol: "beaver" -nn: - dpsmpc: - protocol: "layer_estimation" - skip_loss_forward: True - cache_pred_size: True - safety: serdes: # supported types: unrestricted, restricted, restricted_catch_miss @@ -68,3 +20,47 @@ safety: allow: True curve_type: - curve25519 + + mpc: + active_security: False + provider: "TFP" + protocol: "beaver" + functions: + max_method: "log_reduction" + + # exponential function + exp_iterations: 8 + + # reciprocal configuration + reciprocal_method: "NR" + reciprocal_nr_iters: 10 + reciprocal_log_iters: 1 + reciprocal_all_pos: False + reciprocal_initial: null + + # sqrt configuration + sqrt_nr_iters: 3 + sqrt_nr_initial: null + + # sigmoid / tanh configuration + sigmoid_tanh_method: "reciprocal" + sigmoid_tanh_terms: 32 + + # log configuration + log_iterations: 2 + log_exp_iterations: 8 + log_order: 8 + + # trigonometry configuration + trig_iterations: 10 + + # error function configuration: + erf_iterations: 8 + + communicator: + verbose: False + debug: + debug_mode: False + validation_mode: False + encoder: + precision_bits: 16 \ No newline at end of file diff --git a/python/fate/arch/protocol/mpc/__init__.py b/python/fate/arch/protocol/mpc/__init__.py index 443a5fa50a..1b84d94763 100644 --- a/python/fate/arch/protocol/mpc/__init__.py +++ b/python/fate/arch/protocol/mpc/__init__.py @@ -48,7 +48,7 @@ def get_default_provider(): - return __SUPPORTED_PROVIDERS[cfg.mpc.provider] + return __SUPPORTED_PROVIDERS[cfg.safety.mpc.provider] def cryptensor(ctx, *args, cryptensor_type=None, **kwargs): @@ -387,7 +387,7 @@ def fill_cache(): def ttp_required(): - return cfg.mpc.provider == "TTP" + return cfg.safety.mpc.provider == "TTP" def _setup_prng(ctx: "Context", generators): diff --git a/python/fate/arch/protocol/mpc/primitives/arithmetic.py b/python/fate/arch/protocol/mpc/primitives/arithmetic.py index 389baed0ba..dacc840091 100644 --- a/python/fate/arch/protocol/mpc/primitives/arithmetic.py +++ b/python/fate/arch/protocol/mpc/primitives/arithmetic.py @@ -383,7 +383,7 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C result.encode_as_(y) result.share = getattr(result.share, op)(y.share) else: # ['mul', 'matmul', 'convNd', 'conv_transposeNd'] - protocol = globals()[cfg.mpc.protocol] + protocol = globals()[cfg.safety.mpc.protocol] tmp = getattr(protocol, op)(self._ctx, result, y, *args, **kwargs) result.share = tmp.share else: @@ -463,7 +463,7 @@ def div_(self, y): # Truncate protocol for dividing by public integers: if comm.get().get_world_size() > 2: - protocol = globals()[cfg.mpc.protocol] + protocol = globals()[cfg.safety.mpc.protocol] protocol.truncate(self, y) else: self.share = self.share.div_(y, rounding_mode="trunc") @@ -596,7 
+596,7 @@ def neg(self): return self.clone().neg_() def square_(self): - protocol = globals()[cfg.mpc.protocol] + protocol = globals()[cfg.safety.mpc.protocol] self.share = protocol.square(self._ctx, self).div_(self.encoder.scale).share return self diff --git a/python/fate/arch/protocol/mpc/primitives/beaver.py b/python/fate/arch/protocol/mpc/primitives/beaver.py index 0d820c3bbe..9c07736a2d 100644 --- a/python/fate/arch/protocol/mpc/primitives/beaver.py +++ b/python/fate/arch/protocol/mpc/primitives/beaver.py @@ -53,7 +53,7 @@ def __beaver_protocol(ctx, op, x, y, *args, **kwargs): from .arithmetic import ArithmeticSharedTensor - if cfg.mpc.active_security: + if cfg.safety.mpc.active_security: """ Reference: "Multiparty Computation from Somewhat Homomorphic Encryption" Link: https://eprint.iacr.org/2011/535.pdf From 9454b01fd52e3a7d5bb288c759cd0f475099eb02 Mon Sep 17 00:00:00 2001 From: zhihuiwan <15779896112@163.com> Date: Wed, 13 Dec 2023 15:58:19 +0800 Subject: [PATCH 39/42] fix log Signed-off-by: zhihuiwan <15779896112@163.com> --- python/fate/components/core/spec/logger.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/fate/components/core/spec/logger.py b/python/fate/components/core/spec/logger.py index b6608e6d5b..c5e8900e25 100644 --- a/python/fate/components/core/spec/logger.py +++ b/python/fate/components/core/spec/logger.py @@ -62,4 +62,8 @@ def install(self, debug=False): root=dict(handlers=["console"], level="DEBUG"), disable_existing_loggers=False, ) + + for _name, _conf in self.config.get("handlers", {}).items(): + if _conf.get("filename"): + os.makedirs(os.path.dirname(_conf.get("filename")), exist_ok=True) logging.config.dictConfig(self.config) From 5e0d07e33d9a66a154d6c50e81508f29b0af7d3b Mon Sep 17 00:00:00 2001 From: sagewe Date: Wed, 13 Dec 2023 21:09:13 +0800 Subject: [PATCH 40/42] fix config Signed-off-by: sagewe --- python/fate/arch/launchers/context_helper.py | 2 +- .../protocol/mpc/communicator/communicator.py | 3 +- .../fate/arch/protocol/mpc/debug/__init__.py | 2 +- python/fate/arch/protocol/mpc/encoder.py | 2 +- .../protocol/mpc/functions/approximations.py | 32 +++++++++---------- .../arch/protocol/mpc/functions/maximum.py | 4 +-- python/fate/arch/protocol/mpc/mpc.py | 2 +- .../protocol/mpc/primitives/arithmetic.py | 2 +- 8 files changed, 25 insertions(+), 24 deletions(-) diff --git a/python/fate/arch/launchers/context_helper.py b/python/fate/arch/launchers/context_helper.py index 4a4fd9e487..2837d76ad2 100644 --- a/python/fate/arch/launchers/context_helper.py +++ b/python/fate/arch/launchers/context_helper.py @@ -21,7 +21,7 @@ class LauncherClusterContextArgs: csession_id: str = field(default=None) federation_address: str = field(default="127.0.0.1:9377") cluster_address: str = field(default="127.0.0.1:4670") - federation_mode: str = field(default="message_queue") + federation_mode: str = field(default="stream") @dataclass diff --git a/python/fate/arch/protocol/mpc/communicator/communicator.py b/python/fate/arch/protocol/mpc/communicator/communicator.py index f12580f537..40039fddbe 100644 --- a/python/fate/arch/protocol/mpc/communicator/communicator.py +++ b/python/fate/arch/protocol/mpc/communicator/communicator.py @@ -369,6 +369,7 @@ def _get_parties(self, parties, namespace: NS): return Parties( self.ctx, self.ctx.federation, + self.ctx.computing, [(i, p) for i, p in enumerate(parties)], namespace, ) @@ -464,7 +465,7 @@ def logging_wrapper(self, *args, **kwargs): return args[0] # only log communication if needed: - if cfg.communicator.verbose: 
+ if cfg.safety.mpc.communicator.verbose: rank = self.get_rank() _log = self._log_communication diff --git a/python/fate/arch/protocol/mpc/debug/__init__.py b/python/fate/arch/protocol/mpc/debug/__init__.py index d2284537b2..8d5d7abac0 100644 --- a/python/fate/arch/protocol/mpc/debug/__init__.py +++ b/python/fate/arch/protocol/mpc/debug/__init__.py @@ -22,7 +22,7 @@ def validate_attribute(self, name): # Get dispatched function call function = getattr_function(self, name) - if not cfg.debug.validation_mode: + if not cfg.safety.mpc.debug.validation_mode: return function # Run validation diff --git a/python/fate/arch/protocol/mpc/encoder.py b/python/fate/arch/protocol/mpc/encoder.py index 784134d755..75c158507e 100644 --- a/python/fate/arch/protocol/mpc/encoder.py +++ b/python/fate/arch/protocol/mpc/encoder.py @@ -35,7 +35,7 @@ class FixedPointEncoder: def __init__(self, precision_bits=None): if precision_bits is None: - precision_bits = cfg.encoder.precision_bits + precision_bits = cfg.safety.mpc.encoder.precision_bits self._precision_bits = precision_bits self._scale = int(2**precision_bits) diff --git a/python/fate/arch/protocol/mpc/functions/approximations.py b/python/fate/arch/protocol/mpc/functions/approximations.py index 8686cc2735..3aaab482c7 100644 --- a/python/fate/arch/protocol/mpc/functions/approximations.py +++ b/python/fate/arch/protocol/mpc/functions/approximations.py @@ -45,7 +45,7 @@ def exp(self): Set the number of iterations for the limit approximation with config.exp_iterations. """ # noqa: W605 - iters = cfg.functions.exp_iterations + itecs = cfg.safety.mpc.functions.exp_iterations result = 1 + self.div(2**iters) for _ in range(iters): @@ -89,9 +89,9 @@ def log(self, input_in_01=False): # Initialization to a decent estimate (found by qualitative inspection): # ln(x) = x/120 - 20exp(-2x - 1.0) + 3.0 - iterations = cfg.functions.log_iterations - exp_iterations = cfg.functions.log_exp_iterations - order = cfg.functions.log_order + iterations = cfg.safety.mpc.functions.log_iterations + exp_iterations = cfg.safety.mpc.functions.log_exp_iterations + order = cfg.safety.mpc.functions.log_order term1 = self.div(120) term2 = exp(self.mul(2).add(1.0).neg()).mul(20) @@ -144,9 +144,9 @@ def reciprocal(self, input_in_01=False): return rec # Get config options - method = cfg.functions.reciprocal_method - all_pos = cfg.functions.reciprocal_all_pos - initial = cfg.functions.reciprocal_initial + method = cfg.safety.mpc.functions.reciprocal_method + all_pos = cfg.safety.mpc.functions.reciprocal_all_pos + initial = cfg.safety.mpc.functions.reciprocal_initial if not all_pos: sgn = self.sign() @@ -155,7 +155,7 @@ def reciprocal(self, input_in_01=False): return sgn * reciprocal(pos) if method == "NR": - nr_iters = cfg.functions.reciprocal_nr_iters + nr_iters = cfg.safety.mpc.functions.reciprocal_nr_iters if initial is None: # Initialization to a decent estimate (found by qualitative inspection): # 1/x = 3exp(1 - 2x) + 0.003 @@ -169,7 +169,7 @@ def reciprocal(self, input_in_01=False): result = 2 * result - result * result * self return result elif method == "log": - log_iters = cfg.functions.reciprocal_log_iters + log_iters = cfg.safety.mpc.functions.reciprocal_log_iters with cfg.temp_override({"functions.log_iters": log_iters}): return exp(-log(self)) else: @@ -189,8 +189,8 @@ def inv_sqrt(self): .. 
_Newton-Raphson: https://en.wikipedia.org/wiki/Fast_inverse_square_root#Newton's_method """ - initial = cfg.functions.sqrt_nr_initial - iters = cfg.functions.sqrt_nr_iters + initial = cfg.safety.mpc.functions.sqrt_nr_initial + iters = cfg.safety.mpc.functions.sqrt_nr_iters # Initialize using decent approximation if initial is None: @@ -226,7 +226,7 @@ def _eix(self): r"""Computes e^(i * self) where i is the imaginary unit. Returns (Re{e^(i * self)}, Im{e^(i * self)} = cos(self), sin(self) """ - iterations = cfg.functions.trig_iterations + iterations = cfg.safety.mpc.functions.trig_iterations re = 1 im = self.div(2**iterations) @@ -293,7 +293,7 @@ def sigmoid(self): the reciprocal """ # noqa: W605 - method = cfg.functions.sigmoid_tanh_method + method = cfg.safety.mpc.functions.sigmoid_tanh_method if method == "chebyshev": tanh_approx = tanh(self.div(2)) @@ -344,12 +344,12 @@ def tanh(self): terms (int): highest degree of Chebyshev polynomials. Must be even and at least 6. """ - method = cfg.functions.sigmoid_tanh_method + method = cfg.safety.mpc.functions.sigmoid_tanh_method if method == "reciprocal": return self.mul(2).sigmoid().mul(2).sub(1) elif method == "chebyshev": - terms = cfg.functions.sigmoid_tanh_terms + terms = cfg.safety.mpc.functions.sigmoid_tanh_terms coeffs = chebyshev_series(torch.tanh, 1, terms)[1::2] tanh_polys = _chebyshev_polynomials(self, terms) tanh_polys_flipped = tanh_polys.unsqueeze(dim=-1).transpose(0, -1).squeeze(dim=0) @@ -395,7 +395,7 @@ def erf(tensor): r""" Approximates the error function of the input tensor using a Taylor approximation. """ - iters = cfg.functions.erf_iterations + iters = cfg.safety.mpc.functions.erf_iterations output = tensor.clone() for n in range(1, iters + 1): diff --git a/python/fate/arch/protocol/mpc/functions/maximum.py b/python/fate/arch/protocol/mpc/functions/maximum.py index e892a97ebc..fd5e345e28 100644 --- a/python/fate/arch/protocol/mpc/functions/maximum.py +++ b/python/fate/arch/protocol/mpc/functions/maximum.py @@ -24,7 +24,7 @@ def argmax(self, dim=None, keepdim=False, one_hot=True): """Returns the indices of the maximum value of all elements in the `input` tensor. 
""" - method = cfg.functions.max_method + method = cfg.safety.mpc.functions.max_method if self.dim() == 0: result = ( @@ -50,7 +50,7 @@ def argmin(self, dim=None, keepdim=False, one_hot=True): def max(self, dim=None, keepdim=False, one_hot=True): """Returns the maximum value of all elements in the input tensor.""" - method = cfg.functions.max_method + method = cfg.safety.mpc.functions.max_method if dim is None: if method in ["log_reduction", "double_log_reduction"]: # max_result can be obtained directly diff --git a/python/fate/arch/protocol/mpc/mpc.py b/python/fate/arch/protocol/mpc/mpc.py index f69b7428ef..e2b034ecf6 100644 --- a/python/fate/arch/protocol/mpc/mpc.py +++ b/python/fate/arch/protocol/mpc/mpc.py @@ -196,7 +196,7 @@ def reveal(self, dst=None, group=None): def __repr__(self): """Returns a representation of the tensor useful for debugging.""" - debug_mode = cfg.debug.debug_mode + debug_mode = cfg.safety.mpc.debug.debug_mode share = self.share plain_text = self._tensor.get_plain_text() if debug_mode else "HIDDEN" diff --git a/python/fate/arch/protocol/mpc/primitives/arithmetic.py b/python/fate/arch/protocol/mpc/primitives/arithmetic.py index dacc840091..b09b17380a 100644 --- a/python/fate/arch/protocol/mpc/primitives/arithmetic.py +++ b/python/fate/arch/protocol/mpc/primitives/arithmetic.py @@ -455,7 +455,7 @@ def div_(self, y): y = y.long() if isinstance(y, int) or is_int_tensor(y): - validate = cfg.debug.validation_mode + validate = cfg.safety.mpc.debug.validation_mode if validate: tolerance = 1.0 From fbdfe247939ac26c3ed8b579cab69d989551c253 Mon Sep 17 00:00:00 2001 From: Xiongli <740332065@qq.com> Date: Wed, 13 Dec 2023 21:56:41 +0800 Subject: [PATCH 41/42] merge rc-new Signed-off-by: Xiongli <740332065@qq.com> --- java/osx/bin/service.sh | 1 + java/osx/osx-broker/pom.xml | 2 +- .../java/org/fedai/osx/broker/Bootstrap.java | 4 -- .../eggroll/PutBatchSinkPushRespSO.java | 4 +- .../grpc/QueuePushReqStreamObserver.java | 29 +++----- .../osx/broker/http/HttpsClientPool.java | 2 +- .../osx/broker/provider/FateTechProvider.java | 5 -- .../router/DefaultFateRouterServiceImpl.java | 8 ++- .../broker/router/RouterTableAddService.java | 1 + .../fedai/osx/broker/server/OsxServer.java | 59 ++++++++++------ .../main/resources/broker/broker.properties | 14 ++-- .../src/test/resources/cert4test/ca.crt | 29 -------- .../src/test/resources/cert4test/ca.key | 54 -------------- .../src/test/resources/cert4test/client.crt | 27 ------- .../src/test/resources/cert4test/client.csr | 26 ------- .../src/test/resources/cert4test/client.key | 51 -------------- .../src/test/resources/cert4test/client.pem | 52 -------------- .../src/test/resources/cert4test/server.crt | 27 ------- .../src/test/resources/cert4test/server.csr | 26 ------- .../src/test/resources/cert4test/server.key | 51 -------------- .../src/test/resources/cert4test/server.pem | 52 -------------- .../test/resources/keystore/client/client.cer | 20 ------ .../resources/keystore/client/identity.jks | Bin 2551 -> 0 bytes .../resources/keystore/client/truststore.jks | Bin 976 -> 0 bytes .../resources/keystore/server/identity.jks | Bin 2607 -> 0 bytes .../test/resources/keystore/server/server.cer | 22 ------ .../resources/keystore/server/truststore.jks | Bin 922 -> 0 bytes java/osx/osx-core/pom.xml | 24 ++++--- .../org/fedai/osx/core/config/MetaInfo.java | 66 ++++++++++-------- .../org/fedai/osx/core/constant/Dict.java | 2 + .../org/fedai/osx/core/router/RouterInfo.java | 37 ++++++++-- .../fedai/osx/core/utils/OSXCertUtils.java | 25 ++++--- 
.../osx/core/utils/OsxX509TrustManager.java | 3 +- java/osx/pom.xml | 64 ++++------------- 34 files changed, 188 insertions(+), 599 deletions(-) delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/ca.crt delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/ca.key delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/client.crt delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/client.csr delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/client.key delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/client.pem delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/server.crt delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/server.csr delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/server.key delete mode 100644 java/osx/osx-broker/src/test/resources/cert4test/server.pem delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/client/client.cer delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/client/identity.jks delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/client/truststore.jks delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/server/identity.jks delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/server/server.cer delete mode 100644 java/osx/osx-broker/src/test/resources/keystore/server/truststore.jks diff --git a/java/osx/bin/service.sh b/java/osx/bin/service.sh index 44da5df35c..219c6eb944 100644 --- a/java/osx/bin/service.sh +++ b/java/osx/bin/service.sh @@ -62,3 +62,4 @@ case "$1" in exit 1 esac + diff --git a/java/osx/osx-broker/pom.xml b/java/osx/osx-broker/pom.xml index fbd9c01c20..674a3aaabd 100644 --- a/java/osx/osx-broker/pom.xml +++ b/java/osx/osx-broker/pom.xml @@ -26,7 +26,7 @@ com.squareup.okhttp3 okhttp - 4.11.0 + 4.12.0 diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java index 81482e66b2..4f855207ee 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/Bootstrap.java @@ -62,10 +62,6 @@ public static void main(String[] args) { packages.add(Bootstrap.class.getPackage().getName()); ApplicationStartedRunnerUtils.run(injector, packages, args); - boolean startOk = injector.getInstance(OsxServer.class).start(); - if (!startOk) { - System.exit(-1); - } Thread shutDownThread = new Thread(bootstrap::stop); Runtime.getRuntime().addShutdownHook(shutDownThread); synchronized (lockObject) { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java index a5003bc80a..db58fd585d 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/eggroll/PutBatchSinkPushRespSO.java @@ -18,12 +18,14 @@ import com.webank.ai.eggroll.api.networking.proxy.Proxy; import com.webank.eggroll.core.transfer.Transfer; import io.grpc.stub.StreamObserver; +import org.fedai.osx.core.config.MetaInfo; import org.fedai.osx.core.router.RouterInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; public class PutBatchSinkPushRespSO implements 
StreamObserver { @@ -49,7 +51,7 @@ public PutBatchSinkPushRespSO(Proxy.Metadata reqHeader, @Override public void onNext(Transfer.TransferBatch resp) { try { - commandFuture.get(); + commandFuture.get(MetaInfo.BATCH_SINK_PUSH_EXECUTOR_TIMEOUT, TimeUnit.MILLISECONDS); eggSiteServicerPushRespSO.onNext(reqHeader.toBuilder().setAck(resp.getHeader().getId()).build()); eggSiteServicerPushRespSO.onCompleted(); } catch (Exception e) { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java index b468c01f0e..d2e6a69406 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/grpc/QueuePushReqStreamObserver.java @@ -52,7 +52,7 @@ import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; -public class QueuePushReqStreamObserver implements StreamObserver { +public class QueuePushReqStreamObserver implements StreamObserver { static public ConcurrentHashMap queueIdMap = new ConcurrentHashMap<>(); static AtomicInteger seq = new AtomicInteger(0); @@ -112,8 +112,8 @@ public void init(Proxy.Packet packet) throws Exception { if (!isDst) { routerInfo = routerService.route(context.getSrcNodeId(), context.getSrcComponent(), context.getDesNodeId(), context.getDesComponent()); if (routerInfo == null) { - logger.error("no router info is found for party id {}",context.getDesNodeId()); - throw new NoRouterInfoException("no router is found for party id"+context.getDesNodeId()); + logger.error("no router info is found for party id {}", context.getDesNodeId()); + throw new NoRouterInfoException("no router is found for party id" + context.getDesNodeId()); } } if (isDst) { @@ -128,7 +128,7 @@ public void init(Proxy.Packet packet) throws Exception { context.setSrcNodeId(routerInfo.getSourcePartyId()); context.setDesNodeId(routerInfo.getDesPartyId()); if (routerInfo.getProtocol().equals(Protocol.http)) { - logger.error("invalid router info {}, grpc stream is not support http1.x",routerInfo); + logger.error("invalid router info {}, grpc stream is not support http1.x", routerInfo); throw new SysException("invalid router info for grpc stream"); } else { ManagedChannel managedChannel = GrpcConnectionFactory.createManagedChannel(context.getRouterInfo()); @@ -156,7 +156,7 @@ private void mockEggroll(OsxContext context, Proxy.Packet firstRequest) { TransferServiceGrpc.TransferServiceStub stub = TransferServiceGrpc.newStub(channel); CompletableFuture commandFuture = new CompletableFuture<>(); commandFuture.complete(new ErTask()); - putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, backRespSO, finishLatch,routerInfo)); + putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, backRespSO, finishLatch, routerInfo)); } private void initEggroll(OsxContext context, Proxy.Packet firstRequest) { @@ -220,17 +220,10 @@ private void initEggroll(OsxContext context, Proxy.Packet firstRequest) { job); Future commandFuture = RollPairContext.executor.submit(() -> { - try { - CommandClient commandClient = new CommandClient(egg.getCommandEndpoint()); - Command.CommandResponse commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task); - Meta.Task taskMeta = Meta.Task.parseFrom(commandResponse.getResultsList().get(0)); - ErTask erTask = ErTask.parseFromPb(taskMeta); - long now = 
System.currentTimeMillis(); - return erTask; - } catch (Exception e) { - logger.error("submit putBatch task error", e); - throw e; - } + CommandClient commandClient = new CommandClient(egg.getCommandEndpoint()); + Command.CommandResponse commandResponse = commandClient.call(RollPair.EGG_RUN_TASK_COMMAND, task); + Meta.Task taskMeta = Meta.Task.parseFrom(commandResponse.getResultsList().get(0)); + return ErTask.parseFromPb(taskMeta); }); RouterInfo routerInfo = new RouterInfo(); routerInfo.setProtocol(Protocol.grpc); @@ -241,14 +234,14 @@ private void initEggroll(OsxContext context, Proxy.Packet firstRequest) { context.setDesNodeId(routerInfo.getDesPartyId()); ManagedChannel channel = GrpcConnectionFactory.createManagedChannel(routerInfo); TransferServiceGrpc.TransferServiceStub stub = TransferServiceGrpc.newStub(channel); - putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, backRespSO, finishLatch,routerInfo)); + putBatchSinkPushReqSO = stub.send(new PutBatchSinkPushRespSO(metadata, commandFuture, backRespSO, finishLatch, routerInfo)); } @Override public void onNext(Proxy.Packet value) { try { - if(value.getHeader()!=null){ + if (value.getHeader() != null) { context.setTraceId(Long.toString(value.getHeader().getSeq())); } // long seq = value.getHeader().getSeq(); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java index af0508d05e..19d6a82de9 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/http/HttpsClientPool.java @@ -323,7 +323,7 @@ private static SSLSocketFactory getSslFactory(String caPath, String clientCertPa TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)}; // Load client certificate KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + kmf.init(keyStore, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray()); sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom()); // Initialize the factory return sslContext.getSocketFactory(); diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/provider/FateTechProvider.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/provider/FateTechProvider.java index f9965c527b..6b46792d0c 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/provider/FateTechProvider.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/provider/FateTechProvider.java @@ -311,13 +311,11 @@ private Object handleInvoke(OsxContext context, Object request,boolean interInv private ExceptionInfo handleExceptionInfo(OsxContext context, Throwable e) { ExceptionInfo exceptionInfo = ErrorMessageUtil.handleExceptionExceptionInfo(context, e); - //this.writeHttpRespose(response, exceptionInfo.getCode(),exceptionInfo.getMessage(),null); context.setReturnCode(exceptionInfo.getCode()); context.setReturnMsg(exceptionInfo.getMessage()); return exceptionInfo; } - @Override public void processGrpcPeek(OsxContext context, Osx.PeekInbound inbound, StreamObserver responseObserver) { context.setProtocol(Protocol.grpc); @@ -372,9 +370,6 @@ public void processGrpcPush(OsxContext context, Osx.PushInbound inbound, StreamO if (MetaInfo.PROPERTY_SELF_PARTY.contains(desNodeId)) { ServiceRegisterInfo serviceRegisterInfo = 
this.serviceRegisterManager.getServiceWithLoadBalance(context, "", UriConstants.PUSH, false); AbstractServiceAdaptorNew serviceAdaptor = serviceRegisterInfo.getServiceAdaptor(); -// ProduceRequest produceRequest = new ProduceRequest(); -// produceRequest.setPayload(inbound.getPayload().toByteArray()); -// produceRequest.setTopic(inbound.getTopic()); ProduceRequest produceRequest = buildProduceRequestFromGrpc(inbound); ProduceResponse produceResponse = (ProduceResponse) serviceAdaptor.service(context, produceRequest); if (produceResponse != null) { diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java index f862264208..834e1445c6 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/DefaultFateRouterServiceImpl.java @@ -154,6 +154,7 @@ private RouterInfo buildRouterInfo(Map endpoint, String srcPartyId, String srcRo routerInfo.setProtocol(protocol); routerInfo.setUrl(endpoint.get(Dict.URL) != null ? endpoint.get(Dict.URL).toString() : ""); routerInfo.setUseSSL(endpoint.get(Dict.USE_SSL) != null && Boolean.parseBoolean(endpoint.get(Dict.USE_SSL).toString())); + routerInfo.setUseKeyStore(endpoint.get(Dict.USE_KEYSTORE) != null && Boolean.parseBoolean(endpoint.get(Dict.USE_KEYSTORE).toString())); routerInfo.setCaFile(endpoint.get(Dict.CA_FILE) != null ? endpoint.get(Dict.CA_FILE).toString() : ""); routerInfo.setCertChainFile(endpoint.get(Dict.CERT_CHAIN_FILE) != null ? endpoint.get(Dict.CERT_CHAIN_FILE).toString() : ""); routerInfo.setPrivateKeyFile(endpoint.get(Dict.PRIVATE_KEY_FILE) != null ? 
endpoint.get(Dict.PRIVATE_KEY_FILE).toString() : ""); @@ -378,6 +379,11 @@ private boolean checkCycle(String ip, int port) { return cycle; } + @Override + public int getRunnerSequenceId() { + return Integer.MAX_VALUE; + } + @Override public void run(String[] args) throws Exception { this.start(); @@ -433,7 +439,7 @@ private void validateRouterInfo(RouterInfo routerInfo){ Preconditions.checkArgument(routerInfo!=null); String desPartyId = routerInfo.getDesPartyId(); Preconditions.checkArgument(StringUtils.isNotEmpty(desPartyId),"des party id is null"); - if(routerInfo.getProtocol()!=null||Protocol.grpc.equals(routerInfo.getProtocol())){ + if(routerInfo.getProtocol()==null || Protocol.grpc.equals(routerInfo.getProtocol())){ Preconditions.checkArgument(StringUtils.isNotEmpty(routerInfo.getHost()), "route_table.json "+desPartyId+" host/ip is null"); Preconditions.checkArgument(routerInfo.getPort()!=null, "route_table.json "+desPartyId+" port is null"); } diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterTableAddService.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterTableAddService.java index 89975a51d5..74cedb971e 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterTableAddService.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/router/RouterTableAddService.java @@ -88,6 +88,7 @@ public RouterAddRequest decode(Object object) { if(result==null){ throw new ParameterException("invalid param for router operation"); } + return result; } throw new ParameterException("invalid param for router operation"); } diff --git a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java index 24887dd732..6fe6b6b497 100644 --- a/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java +++ b/java/osx/osx-broker/src/main/java/org/fedai/osx/broker/server/OsxServer.java @@ -41,6 +41,7 @@ import org.fedai.osx.broker.http.InterServlet; import org.fedai.osx.core.config.MetaInfo; import org.fedai.osx.core.frame.ContextPrepareInterceptor; +import org.fedai.osx.core.service.ApplicationStartedRunner; import org.fedai.osx.core.utils.OSXCertUtils; import org.fedai.osx.core.utils.OsxX509TrustManager; @@ -66,7 +67,7 @@ */ @Singleton @Slf4j -public class OsxServer { +public class OsxServer implements ApplicationStartedRunner { io.grpc.Server server; io.grpc.Server tlsServer; org.eclipse.jetty.server.Server httpServer; @@ -143,24 +144,24 @@ public Server buildHttpsServer() { ServerConnector connector; SslContextFactory.Server sslServer = new SslContextFactory.Server(); // //如果PROPERTY_HTTP_SSL_TRUST_STORE_PATH 为空, 则去读取证书套件,然后生成一个TRUST_STORE - if (StringUtils.isNotBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH)) { + if (StringUtils.isNotBlank(MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE) && StringUtils.isBlank(MetaInfo.PROPERTY_HTTPS_SERVER_PRIVATE_KEY_FILE)) { sslServer.setTrustStoreType(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE.toUpperCase()); - sslServer.setTrustStorePath(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH); - sslServer.setTrustStore(OSXCertUtils.getTrustStore(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PATH, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE)); - if (StringUtils.isAllBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD)) { + sslServer.setTrustStorePath(MetaInfo.PROPERTY_HTTPS_SERVER_TRUST_KEYSTORE_FILE); + 
+            sslServer.setTrustStore(OSXCertUtils.getTrustStore(MetaInfo.PROPERTY_HTTPS_SERVER_TRUST_KEYSTORE_FILE, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_TYPE));
+            if (StringUtils.isAllBlank(MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD, MetaInfo.PROPERTY_HTTPS_SERVER_TRUST_FILE_PASSWORD)) {
                 throw new IllegalArgumentException("http.ssl.key.store.password/http.ssl.trust.store.password is not set,please check config file");
             }
-            sslServer.setTrustStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD));
-            sslServer.setKeyStorePath(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PATH);
-            sslServer.setKeyStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD, MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD));
+            sslServer.setTrustStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTPS_SERVER_TRUST_FILE_PASSWORD, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD));
+            sslServer.setKeyStorePath(MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE);
+            sslServer.setKeyStorePassword(StringUtils.firstNonBlank(MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD, MetaInfo.PROPERTY_HTTPS_SERVER_TRUST_FILE_PASSWORD));
             sslServer.setTrustStoreProvider(MetaInfo.PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER);
         } else {
             SSLContext sslContext = SSLContext.getInstance("TLS");
-            KeyStore keyStore = OSXCertUtils.getKeyStore(MetaInfo.PROPERTY_SERVER_CA_FILE, MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE, MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE);
+            KeyStore keyStore = OSXCertUtils.getKeyStore2(MetaInfo.PROPERTY_HTTPS_SERVER_CA_FILE, MetaInfo.PROPERTY_HTTPS_SERVER_CERT_CHAIN_FILE, MetaInfo.PROPERTY_HTTPS_SERVER_PRIVATE_KEY_FILE);
             TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)};
             // Load client certificate
             KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
-            kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray());
+            kmf.init(keyStore, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray());
             sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom());
             sslServer.setSslContext(sslContext);
         }
@@ -183,7 +184,7 @@ public Server buildHttpsServer() {
         return null;
     }
 
-    ServletContextHandler buildServlet(HttpServlet servlet) {
+    ServletContextHandler buildServlet(HttpServlet servlet) {
         ServletContextHandler context = new ServletContextHandler();
         context.setContextPath(MetaInfo.PROPERTY_HTTP_CONTEXT_PATH);
         ServletHolder servletHolder = new ServletHolder(servlet);
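Before the TLS hunks that follow, a compact, self-contained sketch of the keystore branch that buildTlsServer() takes when grpc.server.keystore.file is set. The class name, file paths, and password are assumptions for illustration; the JSSE and Netty calls mirror the patched code, and a real gRPC server would additionally pass the builder through GrpcSslContexts.configure before use:

import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContextBuilder;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;
import java.io.FileInputStream;
import java.security.KeyStore;

public class GrpcKeystoreTlsSketch {
    public static SslContextBuilder build() throws Exception {
        char[] password = "changeit".toCharArray();                            // assumed password
        // Load the server identity (certificate chain + private key) from a JKS keystore
        KeyStore identity = KeyStore.getInstance("JKS");
        try (FileInputStream in = new FileInputStream("conf/identity.jks")) {  // assumed path
            identity.load(in, password);
        }
        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(identity, password);

        // Load the truststore used to validate peer certificates
        KeyStore trust = KeyStore.getInstance("JKS");
        try (FileInputStream in = new FileInputStream("conf/truststore.jks")) { // assumed path
            trust.load(in, password);
        }
        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
        tmf.init(trust);

        // Corresponds to grpc.ssl.open.client.validate: require and verify client certificates
        return SslContextBuilder.forServer(kmf)
                .trustManager(tmf)
                .clientAuth(ClientAuth.REQUIRE);
    }
}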
@@ -252,17 +253,17 @@ public boolean start() {
     }
 
     private io.grpc.Server buildTlsServer() {
-        String serverCertChainFile = MetaInfo.PROPERTY_SERVER_CERT_CHAIN_FILE;
-        String privateKeyFilePath = MetaInfo.PROPERTY_SERVER_PRIVATE_KEY_FILE;
-        String serverCaFilePath = MetaInfo.PROPERTY_SERVER_CA_FILE;
+        String serverCertChainFile = MetaInfo.PROPERTY_GRPC_SERVER_CERT_CHAIN_FILE;
+        String privateKeyFilePath = MetaInfo.PROPERTY_GRPC_SERVER_PRIVATE_KEY_FILE;
+        String serverCaFilePath = MetaInfo.PROPERTY_GRPC_SERVER_CA_FILE;
         // Define the JKS file and its password
-        String keyJksFilePath = MetaInfo.PROPERTY_SERVER_KEYSTORE_FILE;
-        String keyJksPassword = MetaInfo.PROPERTY_SERVER_KEYSTORE_FILE_PASSWORD;
+        String keyJksFilePath = MetaInfo.PROPERTY_GRPC_SERVER_KEYSTORE_FILE;
+        String keyJksPassword = MetaInfo.PROPERTY_GRPC_SERVER_KEYSTORE_FILE_PASSWORD;
         // Define the JKS file and its password
-        String trustFilePath = MetaInfo.PROPERTY_SERVER_TRUST_KEYSTORE_FILE;
-        String trustJksPassword = MetaInfo.PROPERTY_SERVER_TRUST_FILE_PASSWORD;
+        String trustFilePath = MetaInfo.PROPERTY_GRPC_SERVER_TRUST_KEYSTORE_FILE;
+        String trustJksPassword = MetaInfo.PROPERTY_GRPC_SERVER_TRUST_FILE_PASSWORD;
 
         if (PROPERTY_OPEN_GRPC_TLS_SERVER) {
             try {
@@ -270,7 +271,7 @@ private io.grpc.Server buildTlsServer() {
                 NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address);
                 SslContextBuilder sslContextBuilder = null;
 
-                if (StringUtils.isNotBlank(PROPERTY_SERVER_KEYSTORE_FILE)) {
+                if (StringUtils.isNotBlank(PROPERTY_GRPC_SERVER_KEYSTORE_FILE)) {
                     // Load the truststore file
                     KeyStore trustStore = loadKeyStore(trustFilePath, trustJksPassword);
                     // Create a TrustManagerFactory and initialize it with the truststore
@@ -287,7 +288,7 @@
                         .trustManager(trustManagerFactory)
                         .sessionTimeout(MetaInfo.PROPERTY_GRPC_SSL_SESSION_TIME_OUT)
                         .sessionCacheSize(MetaInfo.PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE);
-                    if(PROPERTY_GRPC_SSL_OPEN_CLIENT_VALIDATE){
+                    if (PROPERTY_GRPC_SSL_OPEN_CLIENT_VALIDATE) {
                         sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
                     }
@@ -296,14 +297,13 @@
                         .sessionTimeout(MetaInfo.PROPERTY_GRPC_SSL_SESSION_TIME_OUT)
                         .sessionCacheSize(MetaInfo.PROPERTY_HTTP_SSL_SESSION_CACHE_SIZE);
 
-                    if(PROPERTY_GRPC_SSL_OPEN_CLIENT_VALIDATE){
-                        Preconditions.checkArgument(StringUtils.isNotEmpty(serverCaFilePath),"config server.ca.file is null");
+                    if (PROPERTY_GRPC_SSL_OPEN_CLIENT_VALIDATE) {
+                        Preconditions.checkArgument(StringUtils.isNotEmpty(serverCaFilePath), "config grpc.server.ca.file is null");
                         sslContextBuilder.clientAuth(ClientAuth.REQUIRE).trustManager(new File(serverCaFilePath));
                     }
                 }
-                log.info("running in secure mode. server crt path: {}, server key path: {}, ca crt path: {}.", serverCertChainFile, privateKeyFilePath, serverCaFilePath);
                 //serverBuilder.executor(executor);
@@ -371,4 +371,17 @@ private io.grpc.Server buildServer() {
         return nettyServerBuilder.build();
     }
+
+    @Override
+    public int getRunnerSequenceId() {
+        return Integer.MAX_VALUE - 1;
+    }
+
+    @Override
+    public void run(String[] args) {
+        boolean startOk = this.start();
+        if (!startOk) {
+            log.error("osx server start failed");
+            System.exit(-1);
+        }
+    }
 }
diff --git a/java/osx/osx-broker/src/main/resources/broker/broker.properties b/java/osx/osx-broker/src/main/resources/broker/broker.properties
index 29308e0a2d..db43ee98f7 100644
--- a/java/osx/osx-broker/src/main/resources/broker/broker.properties
+++ b/java/osx/osx-broker/src/main/resources/broker/broker.properties
@@ -1,15 +1,15 @@
-grpc.port= 9377
+grpc.port= 9370
 self.party=9999
 # the IP of the cluster manager component of eggroll
-eggroll.cluster.manager.ip = localhost
+eggroll.cluster.manager.ip = 127.0.0.1
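# The commented keys below are a hedged, illustrative TLS configuration using the
# renamed properties that MetaInfo.java introduces later in this patch; every path
# and password value is a placeholder assumption, not part of this patch:
# open.grpc.tls.server=true
# grpc.server.cert.chain.file=/data/certs/server.crt
# grpc.server.private.key.file=/data/certs/server.pem
# grpc.server.ca.file=/data/certs/ca.crt
# keystore-based alternative to the three PEM entries above:
# grpc.server.keystore.file=/data/certs/identity.jks
# grpc.server.keystore.file.password=changeit
# grpc.server.trust.keystore.file=/data/certs/truststore.jks
# grpc.server.trust.keystore.file.password=changeit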
# the port of the cluster manager component of eggroll eggroll.cluster.manager.port = 4670 -open.http.server = true -open.grpc.tls.server=true -server.cert.chain.file=/Users/kaideng/work/cert/test/server.crt -server.private.key.file=/Users/kaideng/work/cert/test/server.pem -#server.ca.file=/Users/kaideng/work/cert/test/ca.crt +open.http.server = false +open.grpc.tls.server=false +server.cert.chain.file= +server.private.key.file= +#server.ca.file= diff --git a/java/osx/osx-broker/src/test/resources/cert4test/ca.crt b/java/osx/osx-broker/src/test/resources/cert4test/ca.crt deleted file mode 100644 index 6f53893d77..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/ca.crt +++ /dev/null @@ -1,29 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFCTCCAvGgAwIBAgIUSk3KOVV1sMXVINNd4DJMd/H6bTgwDQYJKoZIhvcNAQEL -BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIzMTEyMDA4MzAyOFoXDTI0MTEx -OTA4MzAyOFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF -AAOCAg8AMIICCgKCAgEA0IMy7kx4i68onPFcPWx9ymTCjrhZzkRpp+5oRImnVKwh -VIjEtYIwKhK33m887/7ZSaR32+WohtuLkGwSAbkL2yozct00lhHoFJVRvdLBe0S+ -+355UtDoMgBEE/f5336N5gp6IDbtelKdamvUjw02YzIj9o4B8CpZ3PJlLXlXAQuh -nWxpMEDw8CnSxVTG/u1uh3/Ln5wHOIKyAjsr1Wj4eqtpk5Ppo/hIx8GIXGOXp+i/ -oZXi86K8GymsjzAGRWTU03xs30pfLuqAnOcqaad5pKwECpRa95qiLnWiPvKdm9iw -bLZYKZpq5XsY7TetOP77YTtB65lCOhWsJMrljqpzoQ0qsHXkW4Fna7XCErKEqBIx -CsVhugtX34iR5OPuOllDSx4eqLhhz+76KxuNS986WJLOfYgKiYJkKi1qW52pmjH8 -8/p/YjqKyBBXpGr1RpRtu6oFnznFcdLejaeqM7jX6jFLh0C6Z4qDMiHxa48Zl82T -mVfc6DbWdQfbR+cTnCWbPKtwcffAicIxCuyiqVF0f0nTTxHkHpe8dcgYEDpuC5My -XdTXg0X61f9Va5cJ2kg4YF8fNvr2sKKsx5oJrE1SFT8BicIo/gFXtYQBnNcNX4Z9 -YAkP2oMPSWohAfSsJxUvH/aBghn3mJc4pgkquNveIbnPWLZkZcP7lCIFB4oZ5AcC -AwEAAaNTMFEwHQYDVR0OBBYEFFCTno/81Bur+9PmHB7/4naecIo8MB8GA1UdIwQY -MBaAFFCTno/81Bur+9PmHB7/4naecIo8MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZI -hvcNAQELBQADggIBAAVm5oo+GaQ/9YtcnzY5Uk223lokvicmK3TUxDqjfzP31Spp -sGHw+QBS99M0qSu1jZhx0fb1K3gxWHgtroAFAjr0P1nWkKHlDDJDXj44ql4XshsW -3ZhR1rHI6JZsnygCQHN+ik7elXzj3e3p49TVruvHIau2xqVHJSvB+3pi4ekML39n -dg3aPC2GDVqgNKBgIWYQhE/PiXd2wqxM1MtTpLXNTRdd3KWPVNkg6BBxXFmDhH1w -zcUknr/xgtWwCb3XxajvM0Bx+id1LeE7yFsNH50v4tm4zyrG2h6bU16bQrVtu0pb -MZ8OBrT7SVEH0lXDO7HvCPhFQrSUj41+lhdjyljw25yFbVOsANOQoD3df+S8XM4V -RoiBIRJ1bwWm3XFnK24a/o+MngV3Jc3rlXw/Zfqvs6/2dc4yAAajRpA2G2QfYVLT -k0tqlSNYXi/VHjdatDeFYJ/X/FMLWfIwRgwa75HdHxNl9n8eQoAtB7UTKIoBk+vl -nhxOYhvjnWaDRVCdlkWBzLd3OXCbD1Ddt0fVNfQrv064vZlRdLHJuOmbga3qoWcF -nKTFGVBwb9ffznW6GD7DSLadonhPWbScUOE9H9as6d758b2hU5u2gsjj0JMdRZaB -Tnmg2+WRsEMr/D/atxUppyIRNUCpbTAr3wncYdzoOCfA9p5z9wSYHGtWqRx4 ------END CERTIFICATE----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/ca.key b/java/osx/osx-broker/src/test/resources/cert4test/ca.key deleted file mode 100644 index 12c7ead7c2..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/ca.key +++ /dev/null @@ -1,54 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: DES-EDE3-CBC,ADFE46050AB4B6E9 - -4OBk8OqMUhcMEfcbChM/0YKgZU0OY6c0y6Fm4yM8XSXBQ4wuOtfNdcE5iBqdTGli -m4JXnlbC2U8ryWOopKa3FA+lhie6tYlHqT6IRRYDlXAFDVDstJlEjZlxAXOvBY8H -9p9mRVlGPiIyj8uf4Zfi9aIPGxyX+lLk8BGwkxeKv2BsRzaalX6t5YYNCHo5sQFa -Z/VECjIiGxbqQbR34oDJ/bVQWbkhn5tl33o0Ng+dItfunYRX6roZbpIwVwSKZl6s -4kpgCFP5PH4jPSLCs4j93ujqWPLeQlOLJk0QfmZ3Wxl/TM/QtIUZFRTLWtl2oUff -yk60yf27nCKiyaPlf14P2vWWyFgz/Lk8w6ATkBMSlwRujEhTt5wbHwHydXSafcjl -mRk3CSYHE5CunrkaY+hUi/r2/VGaLA3Oz3F8PyZOlJmr5Bp8qD0nDzApb80gkdPY -eQ0OqQVyilHA0Qw/86h6BaEKk3HbCJf9+oHP5pswOVyp9+JJ23THH+ICvjxg1AZm -fAlrQocNF7LW+u20KupRgWmMbwvkDfJmplNaCXJI0PQFLr1unZTO1pfyEA4PKxa6 
-eyT6Uy23CYvbhcKRePr0SptGCunO1GI/c2/x/P8jFxyH4ZnnNV7pjqUmmjNz2uFR -/uksAmGksbvTLgqmK1rloWdcWYwlCHecO0oq3yWQhUW49Z2DuGe6izMj8O/NzoLY -r/Ogx3LAVAHXSn02OHAkOyCUbzoqoi83vLqRCrpHp27qZazYm2m7P7s+iaTg33Kz -lB/nX54+MWSreOmGe8nLSu1BMXSvcg4iHP1YfNsrLwM9WQoZyM3uvHfc8I2CAVy1 -wc5QXLImvECwhTq3MbWMuEkOrKrCQ19dx/N5+WuzI/QpwSbuDuDsmg8cJIbcZJCV -ZPW533T8odXa5SNNThCWOG3Y+QI1MPdqNt8GNXHazt2fg2hJfxKd/yWg/+/w2qrs -IpoejgdHnoxZVXvcnqOU5+JeG0cGoJyBtii77SaNLck9pq5wIyuB85okl2RUGODr -sDY8FqJ/ABGDpqH4qqhXmLqfjVaXJ7fPme2BsE/rG4wR6lZkvT7qpTafN0BcQyiq -igIP3fmuFp76tQs0KiJcPguj0Osh/aoVhLJlCFU2xn/zDUtqZxiLH9gswpAVWQOc -nDUyMV938rjpQCDjKpXwKfsBkG1qNSpv6rzGjrYgpmMsQTQ6orfFBomv/sBPq7Zo -F3hbdaEH0VNjXIyK4o2HhUkEHWcyqaEhYepYkoP8t+YIXNSENEVjNSahg/MdYscT -jNfeBZD1Q3lXcri9M+YiGYAqR+2cdYl1I+90Ydv5WMx5frrx2H3Y8UBONbOW/kKL -ddfXQ+yk9/JDeqpbHAbPWjNdPj6NK4du/DjmKN7pjF8zwoR/6nrpcMHL5N8wkDd1 -mKOrfV7u/ggWmO/tl7fQil9jomQQzLFZ+UDfSQth6piGHtHKTEDd1ATtBXfvH9fE -xboD+7hGSqSDbeXHGqbmITZv7g+ecwq9pAHFnZsMVgonnp7i+dEPC/T8KZSqqJyP -ezzNzXXQrsw2r1ROBynm20JYbg5sSaOm+h5aaza6tHHpAKgsedPII4PaFAMSmkx7 -FvkjVh28vY8ZqnAeBytwM38OACGR4dC3aY5ilCbmlF6lhMePZkRtvY0rvG8LxKI0 -/g/4GF6cCnl3r6kiF8vrtPzwdXGcS1w8w+klqudRnJfF1iv9sL4eTp0ylxA3/bQf -FHV0OAyp/TmlzOG3yK4sUJ8s+1kM/a1kijb6F6hUPKKg13dLMpUmrZ1Fd+iDNmz4 -G5MwIgZo7u0SIhxs/J1sy946uWjw+jyalcTPJib5gV7s/icqDrbVmDM2MnwBSbqy -uN5YfuAhQxRfe0zynrQxXr/OcGCRZAPR3d+aiCBnPEKjn18OqEKNOusKf7BQubTH -g4ltstg6/W6XfPlu8IPdf39/xDVhzmYUczvYjeUXUkSCOLJXU7aHvw+L8XnAb0VI -yizqlPV6kcs5nPV37S0JuCr0smvxNbGvs7dUhrqRDMUrcx4nvdJ03aZ5jaiS24Ma -7v/za8DwmRpaQEeeFTtbHez4GBmLzNS9Tmq+1Qlk72i3WB9iE4sMmGoxnLO1utxL -RztderhGBCccD23jWolV4jXfFGF1LvWxqXo35FKpgncPQc0B8YGxB0/06TfLb8tE -PMpjZw2hz3kTAGhLcLOhwT73VNHZhlV4xYB/1P6jluMqvTBZMDp/Ppv4edDDD+q/ -HDcV9oLJV0j2Ndam5ZHd/HH8/gC3hOIbLQZLujVlsqCcNmjeGWANTExg0rTRTsVi -d9MO0b4ipXT9xT80sn0GwWTngd+buTTQVWKQVr0FV4wY29qSHeKBrH27Jm4QloOf -HZzl64UlOEUq3/b5GP5vIRAiAA4tf2E479UknOAlIbqsP9yocwtBt2+d1r3N26RO -OP2Mb2ta9qXHshhIEMiOfNy72y3veWfZp3idkGf93qWntfhiglhHgsrpeut6a52R -2XijLX8yYKTmPYL7q+v5cdXNmRpVUhqar4tj/Foxt8Aj/sCNmFptnIX5xmeDpwr7 -PI2n71eo7edXyqiIVlgpWP6RXJURzS7K1Xb++ewyHcfqJGOnqkkf30PtiiyNlFR9 -7W7mjtD54DEXh+Zn6W32seyprUhy4lTLqMrPY2aMrWBs0LxeIw3n+SVwdkFp47W2 -2akz9sBiGHR5Zp55grTC4B09EM4NXhPdiZKRukiTH9WtrWPpeaYHSIERS/hVCo6G -b6qkLsgcWzhfoxXC+b8CyYc1YyaPr2DsQNfniJflU+OrMcD0aaWQkWL9G7iuinmO -S/w70zXVXJ7jy6vsX9DYJnvSUea0twEXlRxKZ1tP1kO5IWZI0gi127DBnVAtmffw -VTP3U6aHLa5vAMPIXy8af2U99J9gAzTSGp96Nch27HzDF3iLp+Uv9zh2xDdvtbcm -Yd8ET0OnNO0ME71JeZdM/pKd6UY3cDoRhJrmWN2rdXQB1IoHhWv78fzXk/Vov7RO -rfHrCQFvbiSOeK3BihZ+dJoyM5SoZWqAZiePyaLpJLLMDUmvI/qEehgoLkfd0mOZ -vL5ErC0Vdd+OXCtoMV+TSS26mtdMO4bTlg3To7Bh1tKDjA7wxMMwFLkWIsu8kqBp ------END RSA PRIVATE KEY----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/client.crt b/java/osx/osx-broker/src/test/resources/cert4test/client.crt deleted file mode 100644 index 614312fccd..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/client.crt +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIEnDCCAoQCAQEwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0 -MB4XDTIzMTEyMDA4MzQ0OVoXDTI0MTExOTA4MzQ0OVowFDESMBAGA1UEAwwJbG9j -YWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA6wxQajp/MMTr -zo1wNAUXPtf23ZAGFSQNG5D1sostCQYomVHHChaDQDbmREth8bNK10g3zBf7J14p -dJL3JZZtjX3dTzJHIPtSeFomCktbgKYzQjvgRdswPHJS/0Gz5VoohefnafKJecen -StMZ9T3XtKbHgc1z0uE+KVkr7PHPiigDX060a+UqAKdQi5wWTkWrVnaZxkF3PVze -pSe/0ixcGESmM98vEzxq8MU7EhRAeBRKOMMEovUBBkNi9IyM8egSjHknK3Qk/ulb -wPLbXhPSYvYhAMMc1jxTVDCh/WJcmXziHvwzgyfLQO4TfP+JgcE8+laCTOvMWhnR 
-AjnjYsxuKap+0CJsZUWCdp+81Ffzdum7nZo1UuneFB88Y4BTjTvJnZADlbIoffMN -NnbP9DD0JMEZEYSJqleeGdxonVij2Dvo//rtE+6HTuhjX1G7XDEkQ7+WOR6rNgFt -2l8IRmMGqt6yQObsIebXLJ92Tpe5uh/N75Zq7qX5uodosIEPclAy/7Eb9MAqRYrF -BUIgxAvyM2CoDWMYgozHYFZXVAyj5K4wkmpHItMMH5TYMh708ngO9SFXDnM6L03p -sKwKmeWCn2H9zhKE0LJ+HZUYtX04lXJ6GtiETfJBhdcWdTsLrL/aXRI5e+0UMJnp -rmmQ/n/zTNp36yxyy5d0176dZRNS/3sCAwEAATANBgkqhkiG9w0BAQsFAAOCAgEA -Pr7P2gXMZMUOw4TiIRY4LUBnmfaU0TrQn0cx3bEHf06EZqFoIUy3CzchGOE7QQ9n -Ev7Xo3OtayKILYZdEMOpiR+LdWJjq2LW1xaBfaS0CfcKSPCBHQlbWzrLOioUVm2i -01TY7ELXyCE/YaBY+YVhONq+wVnG/csMI9mT+M0mgy2ODEGKVf90F96Kmg7FeIUH -K4/cRvzA2Zn582XR36qV2gHn4vbX3Mf/GeFn1aQ4jJajGoDM1LOwCbNaHWMNkhv4 -CUaIqSV+qk9J6uC6j21AMsh2TQGG7pP6abxzbYOo0uvc3zsuZGcN9Csp+7Ulz6lx -IMcAh9vAriVlGUcBEfzZQ3g1U9/0N+rugNn7Mov7EWpsMcPUeHpccNeFrAo7eIT8 -pAstqIodyFSLD1RoTB6qsNYAo8xA1GpIWIFegUkXKZkumFpwErWYdHmzwePUAZwq -j8E/85oMeYdlJU9pSJXiZRtUBjGFhDxslqWiwglt3T8Ya6/wlfrsYU0IkgVkFuXR -LXjGTVg4yOkMKMx8458MCMpj78jAyMM77JjvZ47qlbjmOyeR7A2nQJSNCzGefUzV -YvMmmJEx6ORqqFRDGV66GaM6iQeRwrVDxMDD9S0DtNl9DfMpyn6QySQTBsPS2QP6 -/QTPdlHV1S8i9Hw81U11fmBe7ef3fYumESpD+f7AmKs= ------END CERTIFICATE----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/client.csr b/java/osx/osx-broker/src/test/resources/cert4test/client.csr deleted file mode 100644 index 0dbbdec07c..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/client.csr +++ /dev/null @@ -1,26 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIIEWTCCAkECAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0B -AQEFAAOCAg8AMIICCgKCAgEA6wxQajp/MMTrzo1wNAUXPtf23ZAGFSQNG5D1sost -CQYomVHHChaDQDbmREth8bNK10g3zBf7J14pdJL3JZZtjX3dTzJHIPtSeFomCktb -gKYzQjvgRdswPHJS/0Gz5VoohefnafKJecenStMZ9T3XtKbHgc1z0uE+KVkr7PHP -iigDX060a+UqAKdQi5wWTkWrVnaZxkF3PVzepSe/0ixcGESmM98vEzxq8MU7EhRA -eBRKOMMEovUBBkNi9IyM8egSjHknK3Qk/ulbwPLbXhPSYvYhAMMc1jxTVDCh/WJc -mXziHvwzgyfLQO4TfP+JgcE8+laCTOvMWhnRAjnjYsxuKap+0CJsZUWCdp+81Ffz -dum7nZo1UuneFB88Y4BTjTvJnZADlbIoffMNNnbP9DD0JMEZEYSJqleeGdxonVij -2Dvo//rtE+6HTuhjX1G7XDEkQ7+WOR6rNgFt2l8IRmMGqt6yQObsIebXLJ92Tpe5 -uh/N75Zq7qX5uodosIEPclAy/7Eb9MAqRYrFBUIgxAvyM2CoDWMYgozHYFZXVAyj -5K4wkmpHItMMH5TYMh708ngO9SFXDnM6L03psKwKmeWCn2H9zhKE0LJ+HZUYtX04 -lXJ6GtiETfJBhdcWdTsLrL/aXRI5e+0UMJnprmmQ/n/zTNp36yxyy5d0176dZRNS -/3sCAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQAjuzUSidXhop6qzMMy9PTf33bf -CGiL8YtC5ttKEuLWk5O/pUv6rb5O4jTpkp0WnkL7ylmlK1TsKW4qVsP1mz2P9+hG -qKh2dIuidte5nwPrdKgO7D4vAXkXmDk2u87+hhNFVvzLmNcfZ2lFvndoIFmpcMzm -6stCpQBXbEdoVmlrXujWBW5Zs6bup2Us2jD6v+Zs/fAWI2iStDMxm2fjpcJZ5u5I -g2ch5qumtiNi8DwBZsIC5rd7k7Dky/m6KfqGPWHXNqIyIKIDGmWtVFyYCC/py+Fn -f90c2s6oAQOlUbMv7TgnMZuFsJMGJuFyBTbzSL0l0lc1TnFKlYfDIan+x26dGxYC -Vf+z53Ww8CgbMqyc6mKnSkiL2WWvbeqalabLlyVJFNgV0/DBUJ8YyTZQJUtIKBD6 -RChSE/KjRn1KvNdBjAuLTMccUYRsN9ubNjm7IZgBNKNytzqiBPQqhwKvL+AhKrCJ -Aw8+v+fUFWkLBQDqs9HRI5UeXm+JF0emU87xi0m+MQ9xXNPCvHWUdQZm5roX1j4X -WBk9Dhsq7UZ4dnT2Jjes+BUOgtqb1xzfFWFpT4dPB46hzZrrSUqMzY7I4obh5d9i -W/c8b5mSzSBZ75RxZ/qtZmIvHUXkF763fztyTFhMLdR+qPDYK3Ltx973aj1LPva2 -hwv/6AwTLwX7MhqVSg== ------END CERTIFICATE REQUEST----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/client.key b/java/osx/osx-broker/src/test/resources/cert4test/client.key deleted file mode 100644 index 7be33b2c5c..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/client.key +++ /dev/null @@ -1,51 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIJKQIBAAKCAgEA6wxQajp/MMTrzo1wNAUXPtf23ZAGFSQNG5D1sostCQYomVHH -ChaDQDbmREth8bNK10g3zBf7J14pdJL3JZZtjX3dTzJHIPtSeFomCktbgKYzQjvg -RdswPHJS/0Gz5VoohefnafKJecenStMZ9T3XtKbHgc1z0uE+KVkr7PHPiigDX060 
-a+UqAKdQi5wWTkWrVnaZxkF3PVzepSe/0ixcGESmM98vEzxq8MU7EhRAeBRKOMME -ovUBBkNi9IyM8egSjHknK3Qk/ulbwPLbXhPSYvYhAMMc1jxTVDCh/WJcmXziHvwz -gyfLQO4TfP+JgcE8+laCTOvMWhnRAjnjYsxuKap+0CJsZUWCdp+81Ffzdum7nZo1 -UuneFB88Y4BTjTvJnZADlbIoffMNNnbP9DD0JMEZEYSJqleeGdxonVij2Dvo//rt -E+6HTuhjX1G7XDEkQ7+WOR6rNgFt2l8IRmMGqt6yQObsIebXLJ92Tpe5uh/N75Zq -7qX5uodosIEPclAy/7Eb9MAqRYrFBUIgxAvyM2CoDWMYgozHYFZXVAyj5K4wkmpH -ItMMH5TYMh708ngO9SFXDnM6L03psKwKmeWCn2H9zhKE0LJ+HZUYtX04lXJ6GtiE -TfJBhdcWdTsLrL/aXRI5e+0UMJnprmmQ/n/zTNp36yxyy5d0176dZRNS/3sCAwEA -AQKCAgAe5reLv7UJDFqUBTRDIogz0uC5sD2ceejfPueOWY3KKe0cewvX363Ru2X6 -hI6T4CZutyfexShXvKFmmguz/VrZxzpZNxry0xe8it2FbPLSrwb+JjEN/gsRZ1ZS -CKlF9dxt/lcGLsS0JfNweuBmxYKeVW7VOdWIW+R4OyjzNbc7Spdm6EoABVjITTbh -o9uq3q3v6Be/YMv0XUlIHTmyv/I7norbNvRRaxgEH2nsrozrPH+lhr4NTnicAi/4 -RqIhC4mkvijQJazXdoaBj2wXqjN2nzUnjH82CyhJYTtqvIvAAhMYT7/V9l1aY/Jb -9Jx3WphRsR3gTv/GuK4pxMKIMqgowal2nvoIzv9Zgbt+JatHE/9SP6gqRVQ26OTJ -/vc/9xvwZfVPaTwoC+rou1t4d631KrBsWjV8DvPvdBiT3shYFgwLmcO/V5516PNP -lqqHt8mF0ExXh0+AW5EAOx0dPDDUsiwYpp0PGQ52iPQN64I/ldKgbHV13Lbz05ND -NI3IgbIeHj28rZASuqLtvDgDArFCgUR7tuoDQAYCw+PFOsQ/MM6HK5Z6hiXC9ZOC -FuOxQE+aEi9wKupEfn93eNGwRkP4pMOjtEqoeh4KnDR4MmHXrAZETsz8dVcKpYLk -Hcz2uwqN3hZgVoVPtxlYnm/TZZMBstuvdAwDoLlh8KRBFhK8gQKCAQEA9peWYqnm -zNtZWnVeQpXC9gq3nunskLlPKKU38fKFvjqHvBTcJag7fpWrluZFLHDg9s0gEmx1 -27ZAZ7EPf05Wwps2ubIntRHnMAriQg6sBy0gaBz/Muo1/luskbquxplU3Asc2rXM -UzbMx4yOY56Iyagq5OsL3L2fIhXhGb5wVpuQXd5xRIlFn28u0p+1EMUrHaO6aqUL -CVnXR1PmYFsvxJtbcG3BL7DLx9NHcQVn8wC/YZYVsOMsHnoABelhB3enLCdw+JII -W8x5YhbWLbAjuAHw5vgAVF49qfTrWqpEc+Y7JU114LGiew4JaYutuOvxyNihmViF -sYrwZ3L/ZMarqQKCAQEA9AP6gi05+9tkjDO8aEexmLNT1AC0oAW+J+vxNXF2TmD8 -k1jQ4uq1vGbD+Ti5JR/DahnRR0Yj83e0IvK94gDrTATDnXnhMLqmA6jQbyp6YB+e -LVZKNqhKFRh74dLuGBZP+owyd3tVofX1Inkdoa3/av/J8yq4piZ14iRB+qjOLCSh -8at7EML7UpDGiQEzRLw1lj2y5jdD32EXnt0gOiU7TIhhOSWAKeKLUqaPkfQNnBY4 -jF3IMu+htOeMv4dmRNbyPvr0hD9qyB6an+ZoBU+S1Oo/qT/n0Cz37+DWKCILmDdf -tSy5XUWbxqQIIFWaTZzjkg52CC8zN4heXd/yVFjogwKCAQBwG5yoQHwImJS39nIj -LXkUaOzwF4OQjF77qJmVqt+5C10YWhd4G1LpCtyW3xuFx8/PBJTXK24ttF71hV75 -TsFM+knYBLHetUP46InS5F67aH26N6yiFi7z8/Ox0UCSU7Vr0LWOjWZWUqyo8DLw -AWxI2eaeamnbMm49jdrn3FewWEs7Ed1G/m2jvWV5JlioRiuC7yPaRiyNVMX1zKQJ -HIvMA6F/rLZOmz8aGuj47i9DIAziLdywracqN+b4yRBu16wt+8R1jda0/XIV8THw -VYr3phJCv29O7AV21j2F27EBTCOJovy7aabn8QrAbFtPnh0vZaWaVM97VyJStcp8 -o4H5AoIBAQC9VMDwdGsDEg2IAzRyrP4Nf0bRveJoL0yF6Tn0v56N3g7bvRQGnRp5 -njr8ipiNR4H2NyX8aV3HsN9iJnpSe2gWSbQF4eVqS8g4GqnvN0RQhPfUMZnPovAo -QiEM7P60TcusmU8nCdk9m3uiTdtB8aG2wdVOCZ3PvRPGbV+MP6II+jt1KhqIvOEI -BTEmaHoBIQ9rDWBb5BGTpuAO8X/p3a2PClp1XrV8yjxT2syW8IgGze7+al3Ft8z+ -cpLwoPwm+ahoWYuTeSk/MQ/EdZ/MTxucfEz844rYKawOwaMo7JGWf5CRIKyKxFHD -5M4xWHorMkoYr4PBge35bqPZrsN10q5FAoIBAQCIPLqi/c3pjprFvw2109HiN+W3 -0YTQ/hp5bZLvPt9WxsSX9kQozLfxdTSa9wBjYWMivI4MyG0FsTVPVvNp2iJCC/iy -oUGgdylZ5K1gPsmbbnyVvUO5l/lhyziXTx2wS+3VEU6YVYen/R/qV9vetLuQRSZ3 -TBtpHOkax6C8kJkFN2WQVcGGlyvuLoKQ5QAIVF0AE5U0ywcIf7WESE20ysNsBHJY -uA3LUzPGJtlHu/Wks6sXisHGQNmeVefw2bhnmKipuvxM4w9WBRdOsar9+kN+pGNb -h6fTO432cbwFwRdBtrAsFYzjoP6ethybUfEbUzq30tNT+Coa2RFnroMz1AXv ------END RSA PRIVATE KEY----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/client.pem b/java/osx/osx-broker/src/test/resources/cert4test/client.pem deleted file mode 100644 index 376f4f5947..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/client.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQDrDFBqOn8wxOvO -jXA0BRc+1/bdkAYVJA0bkPWyiy0JBiiZUccKFoNANuZES2Hxs0rXSDfMF/snXil0 
-kvcllm2Nfd1PMkcg+1J4WiYKS1uApjNCO+BF2zA8clL/QbPlWiiF5+dp8ol5x6dK -0xn1Pde0pseBzXPS4T4pWSvs8c+KKANfTrRr5SoAp1CLnBZORatWdpnGQXc9XN6l -J7/SLFwYRKYz3y8TPGrwxTsSFEB4FEo4wwSi9QEGQ2L0jIzx6BKMeScrdCT+6VvA -8tteE9Ji9iEAwxzWPFNUMKH9YlyZfOIe/DODJ8tA7hN8/4mBwTz6VoJM68xaGdEC -OeNizG4pqn7QImxlRYJ2n7zUV/N26budmjVS6d4UHzxjgFONO8mdkAOVsih98w02 -ds/0MPQkwRkRhImqV54Z3GidWKPYO+j/+u0T7odO6GNfUbtcMSRDv5Y5Hqs2AW3a -XwhGYwaq3rJA5uwh5tcsn3ZOl7m6H83vlmrupfm6h2iwgQ9yUDL/sRv0wCpFisUF -QiDEC/IzYKgNYxiCjMdgVldUDKPkrjCSakci0wwflNgyHvTyeA71IVcOczovTemw -rAqZ5YKfYf3OEoTQsn4dlRi1fTiVcnoa2IRN8kGF1xZ1Owusv9pdEjl77RQwmemu -aZD+f/NM2nfrLHLLl3TXvp1lE1L/ewIDAQABAoICAB7mt4u/tQkMWpQFNEMiiDPS -4LmwPZx56N8+545Zjcop7Rx7C9ffrdG7ZfqEjpPgJm63J97FKFe8oWaaC7P9WtnH -Olk3GvLTF7yK3YVs8tKvBv4mMQ3+CxFnVlIIqUX13G3+VwYuxLQl83B64GbFgp5V -btU51Yhb5Hg7KPM1tztKl2boSgAFWMhNNuGj26rere/oF79gy/RdSUgdObK/8jue -its29FFrGAQfaeyujOs8f6WGvg1OeJwCL/hGoiELiaS+KNAlrNd2hoGPbBeqM3af -NSeMfzYLKElhO2q8i8ACExhPv9X2XVpj8lv0nHdamFGxHeBO/8a4rinEwogyqCjB -qXae+gjO/1mBu34lq0cT/1I/qCpFVDbo5Mn+9z/3G/Bl9U9pPCgL6ui7W3h3rfUq -sGxaNXwO8+90GJPeyFgWDAuZw79XnnXo80+Wqoe3yYXQTFeHT4BbkQA7HR08MNSy -LBimnQ8ZDnaI9A3rgj+V0qBsdXXctvPTk0M0jciBsh4ePbytkBK6ou28OAMCsUKB -RHu26gNABgLD48U6xD8wzocrlnqGJcL1k4IW47FAT5oSL3Aq6kR+f3d40bBGQ/ik -w6O0Sqh6HgqcNHgyYdesBkROzPx1VwqlguQdzPa7Co3eFmBWhU+3GVieb9NlkwGy -2690DAOguWHwpEEWEryBAoIBAQD2l5ZiqebM21ladV5ClcL2Cree6eyQuU8opTfx -8oW+Ooe8FNwlqDt+lauW5kUscOD2zSASbHXbtkBnsQ9/TlbCmza5sie1EecwCuJC -DqwHLSBoHP8y6jX+W6yRuq7GmVTcCxzatcxTNszHjI5jnojJqCrk6wvcvZ8iFeEZ -vnBWm5Bd3nFEiUWfby7Sn7UQxSsdo7pqpQsJWddHU+ZgWy/Em1twbcEvsMvH00dx -BWfzAL9hlhWw4yweegAF6WEHd6csJ3D4kghbzHliFtYtsCO4AfDm+ABUXj2p9Ota -qkRz5jslTXXgsaJ7Dglpi6246/HI2KGZWIWxivBncv9kxqupAoIBAQD0A/qCLTn7 -22SMM7xoR7GYs1PUALSgBb4n6/E1cXZOYPyTWNDi6rW8ZsP5OLklH8NqGdFHRiPz -d7Qi8r3iAOtMBMOdeeEwuqYDqNBvKnpgH54tVko2qEoVGHvh0u4YFk/6jDJ3e1Wh -9fUieR2hrf9q/8nzKrimJnXiJEH6qM4sJKHxq3sQwvtSkMaJATNEvDWWPbLmN0Pf -YRee3SA6JTtMiGE5JYAp4otSpo+R9A2cFjiMXcgy76G054y/h2ZE1vI++vSEP2rI -Hpqf5mgFT5LU6j+pP+fQLPfv4NYoIguYN1+1LLldRZvGpAggVZpNnOOSDnYILzM3 -iF5d3/JUWOiDAoIBAHAbnKhAfAiYlLf2ciMteRRo7PAXg5CMXvuomZWq37kLXRha -F3gbUukK3JbfG4XHz88ElNcrbi20XvWFXvlOwUz6SdgEsd61Q/joidLkXrtofbo3 -rKIWLvPz87HRQJJTtWvQtY6NZlZSrKjwMvABbEjZ5p5qadsybj2N2ufcV7BYSzsR -3Ub+baO9ZXkmWKhGK4LvI9pGLI1UxfXMpAkci8wDoX+stk6bPxoa6PjuL0MgDOIt -3LCtpyo35vjJEG7XrC37xHWN1rT9chXxMfBVivemEkK/b07sBXbWPYXbsQFMI4mi -/LtppufxCsBsW0+eHS9lpZpUz3tXIlK1ynyjgfkCggEBAL1UwPB0awMSDYgDNHKs -/g1/RtG94mgvTIXpOfS/no3eDtu9FAadGnmeOvyKmI1HgfY3JfxpXcew32ImelJ7 -aBZJtAXh5WpLyDgaqe83RFCE99Qxmc+i8ChCIQzs/rRNy6yZTycJ2T2be6JN20Hx -obbB1U4Jnc+9E8ZtX4w/ogj6O3UqGoi84QgFMSZoegEhD2sNYFvkEZOm4A7xf+nd -rY8KWnVetXzKPFPazJbwiAbN7v5qXcW3zP5ykvCg/Cb5qGhZi5N5KT8xD8R1n8xP -G5x8TPzjitgprA7BoyjskZZ/kJEgrIrEUcPkzjFYeisyShivg8GB7fluo9muw3XS -rkUCggEBAIg8uqL9zemOmsW/DbXT0eI35bfRhND+Gnltku8+31bGxJf2RCjMt/F1 -NJr3AGNhYyK8jgzIbQWxNU9W82naIkIL+LKhQaB3KVnkrWA+yZtufJW9Q7mX+WHL -OJdPHbBL7dURTphVh6f9H+pX2960u5BFJndMG2kc6RrHoLyQmQU3ZZBVwYaXK+4u -gpDlAAhUXQATlTTLBwh/tYRITbTKw2wEcli4DctTM8Ym2Ue79aSzqxeKwcZA2Z5V -5/DZuGeYqKm6/EzjD1YFF06xqv36Q36kY1uHp9M7jfZxvAXBF0G2sCwVjOOg/p62 -HJtR8RtTOrfS01P4KhrZEWeugzPUBe8= ------END PRIVATE KEY----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/server.crt b/java/osx/osx-broker/src/test/resources/cert4test/server.crt deleted file mode 100644 index 1d99bf4e11..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/server.crt +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIEnDCCAoQCAQEwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0 
-MB4XDTIzMTEyMDA4MzIyNFoXDTI0MTExOTA4MzIyNFowFDESMBAGA1UEAwwJbG9j -YWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuh+VBZJHrrd9 -Bph0epqT1WcfRDa5kBQI4xwrwYLpdSraBX0shW8dHI5lx5dDeae3HOmOY3eruIGC -oOUM3Fe5Anv85aXbWABdbpIWFFY1Mxhp64YGwGT+pDI3okWvGhIsQic8Qz9kTBGI -IJPNz7sponN/FJmJByCwwbBIKwtvalsXQrUK/8QjJu7KTNE39J8wG9ci5nVncMjV -cGbHsGOqLrWt4qI5mKigB/fz0Xa8FfTQC0gtgKnMEi+ntmxw/2zu5wxNDmFJu4/g -ZZQy1ilf5bZQQND9Hq/PPDHzpGMcK7PnfiJKttRTRxYSgN7P0Oavj7oJDe6ZGGqo -cdugKkOR4yvu/rKHJs06KBi3jE4ERs6z579BaAq2LdB6IOeSxmZEkxykIuGfoEfv -693CAvSP42JesZucN+Nj/eSdPe6r7o2hPYtVs+JpA21CBnBRwtIxo6X30Jp/B4at -SOHZIH2kpcaoH9/c94D3pEBSHxfCDP4kqKrwA12vprVufM1vxMbNhNZc3vX17tXY -tDpTZN5ltzKYnCvcDDnoj98Voqz1D6rh5p1PvsyF6wAebvPlyfDUyVC0E5vDovQN -qjo/ZhWgAeoUH2/v4juM7barPl5nLh7xJzeGBWAmo9l7n4TvDaR8nVBzGg9M42UD -SIBJpmQsOoVvvwtGFMQn64HqcYQUXpkCAwEAATANBgkqhkiG9w0BAQsFAAOCAgEA -sutvlAPQgxz4bCBUgIuvgyIIoEUKcy4Cc89wKsNOFkkvru9eNWhqIMzWTYeFD65T -wHIaYLcwjCx83gmdaRLtT4jYNvWAiTtCP/cDsGUTGGkCojSYr8kMGR+oILk4fSOm -xE2RhTds6lL8bMMeUxG/OKtmRLw0BOQ7ipmI2DG15mLcjR2J4wtx8eSkpZc+AWZA -lJibgCJnb6rQpiAbPD3ShXm4EOPceEjC80b+V1wbkem0xYScdGKIoUtg7YWQD3Gd -zDTYKzltH4dc3h3ZdBTkUfnZE3oJ29xY2rZjufjZPElCuolMo+QUMWZBCBhi5UUS -PViLaVcQUE77Qfaz0zLmw/dKE/FzTt44j8FQhdUBLtpKP68di1RDjy+/2rK84f2h -bmyn4NizKNvngWWWUChvjqqkHchfxTy+gTCeLEjbdUD95RZIDWUYsqFanv4FJ+ZX -XJVcsw7qnHCmvavCo2lsdRHhzVLl5GH63F73M3nsrA3ySYhJJMht+YEM8w4xISiy -p5kPeuED38IZ5Cm/vLv/ORJ/r018PTy4j8UlZczJn2juWD21d5fJYY/oKCAULX3Y -sljgsMIadp0F/O/1Muck0ODSMO8kVLQfIeekJAZUtZ/IGs7s/dhwVO3vTvZ9plRA -DWjP3FDpgPfRcZtY8VR6WJ3tT3q1/MjV96+V+77W8kU= ------END CERTIFICATE----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/server.csr b/java/osx/osx-broker/src/test/resources/cert4test/server.csr deleted file mode 100644 index 934892ffaa..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/server.csr +++ /dev/null @@ -1,26 +0,0 @@ ------BEGIN CERTIFICATE REQUEST----- -MIIEWTCCAkECAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0B -AQEFAAOCAg8AMIICCgKCAgEAuh+VBZJHrrd9Bph0epqT1WcfRDa5kBQI4xwrwYLp -dSraBX0shW8dHI5lx5dDeae3HOmOY3eruIGCoOUM3Fe5Anv85aXbWABdbpIWFFY1 -Mxhp64YGwGT+pDI3okWvGhIsQic8Qz9kTBGIIJPNz7sponN/FJmJByCwwbBIKwtv -alsXQrUK/8QjJu7KTNE39J8wG9ci5nVncMjVcGbHsGOqLrWt4qI5mKigB/fz0Xa8 -FfTQC0gtgKnMEi+ntmxw/2zu5wxNDmFJu4/gZZQy1ilf5bZQQND9Hq/PPDHzpGMc -K7PnfiJKttRTRxYSgN7P0Oavj7oJDe6ZGGqocdugKkOR4yvu/rKHJs06KBi3jE4E -Rs6z579BaAq2LdB6IOeSxmZEkxykIuGfoEfv693CAvSP42JesZucN+Nj/eSdPe6r -7o2hPYtVs+JpA21CBnBRwtIxo6X30Jp/B4atSOHZIH2kpcaoH9/c94D3pEBSHxfC -DP4kqKrwA12vprVufM1vxMbNhNZc3vX17tXYtDpTZN5ltzKYnCvcDDnoj98Voqz1 -D6rh5p1PvsyF6wAebvPlyfDUyVC0E5vDovQNqjo/ZhWgAeoUH2/v4juM7barPl5n -Lh7xJzeGBWAmo9l7n4TvDaR8nVBzGg9M42UDSIBJpmQsOoVvvwtGFMQn64HqcYQU -XpkCAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQCJuIw6mIh38H+C/q7f8i95rEpM -3x45djsh+7II9+2LhEEGO8CVRMi2rGO5GVpdseqf3GufEiMhhYh6nhID2XUPGNGk -Dz8Neqp3qZC+v4g3kS+64g6cVczQi3N+Sh+gl/vM8viNV+6wmdPRVA5BdR0mJZFs -zyUlpqAIObn8EOOLoXrZ/dEgKJLb7jitx9zxTwXeEOzNB3Gpq1FiV5QzYhoe1cXp -hyAAF1uyZvA8sdf2f1bzAMxaqlMsp5DboJATc0o4ZsH6ut1ZvbROyPZDFgkNdSE6 -cPlP/o6xl2FPh2JfNQQKBKuQxqkpWo01BJMePPUWrlWCdPmsvnwRVhCi324zs7Qx -EdvXCjIvSbFZDIdzNRJUs1TqJGmKQBGb+pZ7LpoLAjoUj4fYaSv4iYXsz68Mgoya -28QwTGxcQfx/tDj7zZr6xGZBB+yDSHiHImNEx1LxapMbHfMimVSZQJDmE4iH9QNo -1zCbTHTlcCjkjB3DwzjeCC7Yk6pEsUAzEDols8hDtQOpiEZ+yVpD03TXnRxuBmoe -ne8zWgkTYSPFri2T156Tvquw9oPYzPXyqLwKHZejJUPO1RrkUXLFQtw9QiaMHzhE -pd7Qp+mR4BNDklC1dc3qf60qNwwrY743KFonuDRJC+rpjs04CKaN/QOSNFEADZZv -u7s4bp6eF5WwDoWfVA== ------END CERTIFICATE REQUEST----- diff --git 
a/java/osx/osx-broker/src/test/resources/cert4test/server.key b/java/osx/osx-broker/src/test/resources/cert4test/server.key deleted file mode 100644 index e7a9673268..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/server.key +++ /dev/null @@ -1,51 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIJKQIBAAKCAgEAuh+VBZJHrrd9Bph0epqT1WcfRDa5kBQI4xwrwYLpdSraBX0s -hW8dHI5lx5dDeae3HOmOY3eruIGCoOUM3Fe5Anv85aXbWABdbpIWFFY1Mxhp64YG -wGT+pDI3okWvGhIsQic8Qz9kTBGIIJPNz7sponN/FJmJByCwwbBIKwtvalsXQrUK -/8QjJu7KTNE39J8wG9ci5nVncMjVcGbHsGOqLrWt4qI5mKigB/fz0Xa8FfTQC0gt -gKnMEi+ntmxw/2zu5wxNDmFJu4/gZZQy1ilf5bZQQND9Hq/PPDHzpGMcK7PnfiJK -ttRTRxYSgN7P0Oavj7oJDe6ZGGqocdugKkOR4yvu/rKHJs06KBi3jE4ERs6z579B -aAq2LdB6IOeSxmZEkxykIuGfoEfv693CAvSP42JesZucN+Nj/eSdPe6r7o2hPYtV -s+JpA21CBnBRwtIxo6X30Jp/B4atSOHZIH2kpcaoH9/c94D3pEBSHxfCDP4kqKrw -A12vprVufM1vxMbNhNZc3vX17tXYtDpTZN5ltzKYnCvcDDnoj98Voqz1D6rh5p1P -vsyF6wAebvPlyfDUyVC0E5vDovQNqjo/ZhWgAeoUH2/v4juM7barPl5nLh7xJzeG -BWAmo9l7n4TvDaR8nVBzGg9M42UDSIBJpmQsOoVvvwtGFMQn64HqcYQUXpkCAwEA -AQKCAgEAlIlL1oxtJKRO5QqaOpZOUMrhiwDZioBSr1z2FpMxWU5/fE3vT/XjF70U -wPqY4OfWHP7PodYJd0/0Pg9N+jMP9UmaBHQe3tY7ulhfwo8iGcrsDQiDLtvq1IM3 -HwvZuEa0h37kew6GLqb3KniKkbPegEUIMBpv3v5Z0dmrXp7bpddYcuYlBwUywIll -bXSy4UiBjlZdBerASKQeonuD9eM0F97qDKpGqOw5+uII9St641LjDX3mwn2/3Oun -PtDARThcWIvamxVNUKB8BtUE1SFj5OFgnrmqp+jKzFFZDeICw2Xp4yHe2pYMd/jH -f41R8HeJuHxRaYr3JyNHlsYdxlzyeSqhtvEIyZCgrbvIUMtWnfkEpalY806ADMLW -zmDOcI8tCHSCRM1Y2NTAwbUwG1YKtTxaHmCXuyFD3N1nJK6qv0gtilT6ZpJMwjJ/ -U2V8wuNRAHiS7AqXVD6QvgxtJqSieDlAb34q+7kuQMfkfCc+ePb3NpWIqaBoAZDI -SB9Rao7lk0KrxKCYdxtG/iGy9O8rUFABmqr5bURPoIXi4r4xT7yBpAVJp+g0zS3Z -+Vah20XeVMiYp9LUKVKJFtTNSOek+/olSxQC1/TERY/5MlJl8ShkvveqA/5j2hz2 -terVcPDko3xvO9BRiTgyh3e5OOe/PRb3YIV9fbNxjtuvF4gRC2ECggEBAOiyBHYW -FyDp6YNWbHBTRDL+NHuG/ywOc998PanUnwo19O4DdbrSLWkiDMVl8F4aT4hc3iV5 -WR5GtHJXobnlS68IpuFBmq3UUgucyesk+BJtiNZUq6JaSrD3lBQIDmKaYc866VUU -3FuGpCwWm8U6J6MzMaFXPSNpSPZqovkYIAwjZcVVC6O+RyVH2YeI0Wf8E/ZTpHf6 -dHcpHxOYp1uly6UuSfqvNGiNwhfa5jc1Ctsx7Q2hIw+/lu5Idfh4tovD2IHfV8Xt -z4Zh8a06wLF5ZBXjpm/GjFSZBVTEYspIeZ7yqzveVlkUUccHyHHP4lcornpOT/QO -sbjQUcYjyABV+m0CggEBAMzDhhKVe/mZSluXyPzVhxQvBAa84hVj62H95DjVl6oR -7bFtwXxfXf3goCFx7fcZWygUeJf8sXrUDXEJ1hdaIcD/vGPJqa6oHMiSjxbHvre4 -Qr5kOb6daABRWmdipudvbIZi5rocUGs4O3rT86Vp7gqF84i+cBice5mDUp8hRjjC -hnq//rmi0emHf6UNNW0IhUEOiSi3T8gClNJnAGwfTGmfwa7ni9wNHydosEDmG73Z -AHfUzg1NnOGWSYAQmywWviZi8LbLFjRR7Lwosdq90uSeEJVE/aOj+bsrbBAGvVc0 -CGoRSwCLZfpL+EwbD6EYjQB3v6wGAPr1Ygu+vJaa2V0CggEAYbrqnsH2Ys97ULsK -fj6qhRQ47KytHU7Qoctnhp2TUlGJFjIDzzwY8G/plzqSMqOwRaBjeK+3mzys6t0d -QpsoJ1Jl7HOGSH3FG7V1JLp5Khww/XvAPkbX2e2RlrwvdoBKliOy/hXt2s74wr+Y -GsSrAyMsNAwU1HuStlPhMOdOBmsTgkaOxe3Tqbe8h+0Rri+0Hp/QksdxBN8Te0KA -/7/pgO2pCo1tYIAxRZ2dVRCFB7y3SCMmO6YG+Psb+QiR+q99jkZEcg/IOjOGsm/b -oG5Qd9UOASINrDY8g/abW0QHOJfJDTL7ZxxeoE4HhK1/7YVbimi7sdA+GlX8ElDS -3jk+HQKCAQAG2HaQAn1dj9ljjISEp2LXsuawjvoD+w4wfXt2xvVGE0leCCxSyyFq -TmssExIAk9FvWpfZhPIuCA7W+5wztaixOhuDPoe0thwYPIYaHd1raPaaROGFVN/Y -OuAJ2st3q0r2hzHtgOrTWtLqPVHE8vCpW2cT6EH5IKolLLXDaipd1WsHiYmrjRz1 -cLk8vF9P4NwLm7/MI6zAJA3zpsvl0XoNgfDItyb+2VV0TNSvpsHArBOO7gdhfHnF -NPAKHwQBClWbFO21Pr7kSuTeOYIQrQ3y5LHrO547LU05C6+WLZOA6dVqLl/SidaD -8qw/Zxwzp413OYmn5596xF9dwyes5UplAoIBAQDSPZgpXFGoLiOrrCtasqXb4+bh -H3Hv2BD8d+ZbhBLFyzBliJwnHhp55iC7yk5jq6BEwtrOMEqFVXjmSRyU4XLfZaql -31jX+JCt061nfOkvNctVZxkkTvGn6N5z4UZdyu2KsqIQrHnnbguwcsMlb8VHuMab -mBhH9RYjS9bNIQJUA/uROQawAcm0jW3Il9hmX2gAQfKn5eFP/aKdyYPIQHI8EeaS -oxOu6v+BFEB7jKEZfmOlrNZ8Y9wRWH/i5AACvd1gdzxmHJjLd5TV3gixDQHnsEPH -GTi8vTkxa31yo3ZZahCqzfcBwTAl1jJ6gbHrYPQOzZXGxDVIRpsN+UeN4Yca ------END RSA 
PRIVATE KEY----- diff --git a/java/osx/osx-broker/src/test/resources/cert4test/server.pem b/java/osx/osx-broker/src/test/resources/cert4test/server.pem deleted file mode 100644 index cbff662351..0000000000 --- a/java/osx/osx-broker/src/test/resources/cert4test/server.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQC6H5UFkkeut30G -mHR6mpPVZx9ENrmQFAjjHCvBgul1KtoFfSyFbx0cjmXHl0N5p7cc6Y5jd6u4gYKg -5QzcV7kCe/zlpdtYAF1ukhYUVjUzGGnrhgbAZP6kMjeiRa8aEixCJzxDP2RMEYgg -k83Puymic38UmYkHILDBsEgrC29qWxdCtQr/xCMm7spM0Tf0nzAb1yLmdWdwyNVw -ZsewY6outa3iojmYqKAH9/PRdrwV9NALSC2AqcwSL6e2bHD/bO7nDE0OYUm7j+Bl -lDLWKV/ltlBA0P0er888MfOkYxwrs+d+Ikq21FNHFhKA3s/Q5q+PugkN7pkYaqhx -26AqQ5HjK+7+socmzTooGLeMTgRGzrPnv0FoCrYt0Hog55LGZkSTHKQi4Z+gR+/r -3cIC9I/jYl6xm5w342P95J097qvujaE9i1Wz4mkDbUIGcFHC0jGjpffQmn8Hhq1I -4dkgfaSlxqgf39z3gPekQFIfF8IM/iSoqvADXa+mtW58zW/Exs2E1lze9fXu1di0 -OlNk3mW3MpicK9wMOeiP3xWirPUPquHmnU++zIXrAB5u8+XJ8NTJULQTm8Oi9A2q -Oj9mFaAB6hQfb+/iO4zttqs+XmcuHvEnN4YFYCaj2XufhO8NpHydUHMaD0zjZQNI -gEmmZCw6hW+/C0YUxCfrgepxhBRemQIDAQABAoICAQCUiUvWjG0kpE7lCpo6lk5Q -yuGLANmKgFKvXPYWkzFZTn98Te9P9eMXvRTA+pjg59Yc/s+h1gl3T/Q+D036Mw/1 -SZoEdB7e1ju6WF/CjyIZyuwNCIMu2+rUgzcfC9m4RrSHfuR7DoYupvcqeIqRs96A -RQgwGm/e/lnR2atentul11hy5iUHBTLAiWVtdLLhSIGOVl0F6sBIpB6ie4P14zQX -3uoMqkao7Dn64gj1K3rjUuMNfebCfb/c66c+0MBFOFxYi9qbFU1QoHwG1QTVIWPk -4WCeuaqn6MrMUVkN4gLDZenjId7algx3+Md/jVHwd4m4fFFpivcnI0eWxh3GXPJ5 -KqG28QjJkKCtu8hQy1ad+QSlqVjzToAMwtbOYM5wjy0IdIJEzVjY1MDBtTAbVgq1 -PFoeYJe7IUPc3Wckrqq/SC2KVPpmkkzCMn9TZXzC41EAeJLsCpdUPpC+DG0mpKJ4 -OUBvfir7uS5Ax+R8Jz549vc2lYipoGgBkMhIH1FqjuWTQqvEoJh3G0b+IbL07ytQ -UAGaqvltRE+gheLivjFPvIGkBUmn6DTNLdn5VqHbRd5UyJin0tQpUokW1M1I56T7 -+iVLFALX9MRFj/kyUmXxKGS+96oD/mPaHPa16tVw8OSjfG870FGJODKHd7k45789 -FvdghX19s3GO268XiBELYQKCAQEA6LIEdhYXIOnpg1ZscFNEMv40e4b/LA5z33w9 -qdSfCjX07gN1utItaSIMxWXwXhpPiFzeJXlZHka0clehueVLrwim4UGardRSC5zJ -6yT4Em2I1lSrolpKsPeUFAgOYpphzzrpVRTcW4akLBabxTonozMxoVc9I2lI9mqi -+RggDCNlxVULo75HJUfZh4jRZ/wT9lOkd/p0dykfE5inW6XLpS5J+q80aI3CF9rm -NzUK2zHtDaEjD7+W7kh1+Hi2i8PYgd9Xxe3PhmHxrTrAsXlkFeOmb8aMVJkFVMRi -ykh5nvKrO95WWRRRxwfIcc/iVyiuek5P9A6xuNBRxiPIAFX6bQKCAQEAzMOGEpV7 -+ZlKW5fI/NWHFC8EBrziFWPrYf3kONWXqhHtsW3BfF9d/eCgIXHt9xlbKBR4l/yx -etQNcQnWF1ohwP+8Y8mprqgcyJKPFse+t7hCvmQ5vp1oAFFaZ2Km529shmLmuhxQ -azg7etPzpWnuCoXziL5wGJx7mYNSnyFGOMKGer/+uaLR6Yd/pQ01bQiFQQ6JKLdP -yAKU0mcAbB9MaZ/BrueL3A0fJ2iwQOYbvdkAd9TODU2c4ZZJgBCbLBa+JmLwtssW -NFHsvCix2r3S5J4QlUT9o6P5uytsEAa9VzQIahFLAItl+kv4TBsPoRiNAHe/rAYA -+vViC768lprZXQKCAQBhuuqewfZiz3tQuwp+PqqFFDjsrK0dTtChy2eGnZNSUYkW -MgPPPBjwb+mXOpIyo7BFoGN4r7ebPKzq3R1CmygnUmXsc4ZIfcUbtXUkunkqHDD9 -e8A+RtfZ7ZGWvC92gEqWI7L+Fe3azvjCv5gaxKsDIyw0DBTUe5K2U+Ew504GaxOC -Ro7F7dOpt7yH7RGuL7Qen9CSx3EE3xN7QoD/v+mA7akKjW1ggDFFnZ1VEIUHvLdI -IyY7pgb4+xv5CJH6r32ORkRyD8g6M4ayb9ugblB31Q4BIg2sNjyD9ptbRAc4l8kN -MvtnHF6gTgeErX/thVuKaLux0D4aVfwSUNLeOT4dAoIBAAbYdpACfV2P2WOMhISn -Ytey5rCO+gP7DjB9e3bG9UYTSV4ILFLLIWpOaywTEgCT0W9al9mE8i4IDtb7nDO1 -qLE6G4M+h7S2HBg8hhod3Wto9ppE4YVU39g64Anay3erSvaHMe2A6tNa0uo9UcTy -8KlbZxPoQfkgqiUstcNqKl3VaweJiauNHPVwuTy8X0/g3Aubv8wjrMAkDfOmy+XR -eg2B8Mi3Jv7ZVXRM1K+mwcCsE47uB2F8ecU08AofBAEKVZsU7bU+vuRK5N45ghCt -DfLkses7njstTTkLr5Ytk4Dp1WouX9KJ1oPyrD9nHDOnjXc5iafnn3rEX13DJ6zl -SmUCggEBANI9mClcUaguI6usK1qypdvj5uEfce/YEPx35luEEsXLMGWInCceGnnm -ILvKTmOroETC2s4wSoVVeOZJHJThct9lqqXfWNf4kK3TrWd86S81y1VnGSRO8afo -3nPhRl3K7YqyohCseeduC7BywyVvxUe4xpuYGEf1FiNL1s0hAlQD+5E5BrABybSN -bciX2GZfaABB8qfl4U/9op3Jg8hAcjwR5pKjE67q/4EUQHuMoRl+Y6Ws1nxj3BFY 
-f+LkAAK93WB3PGYcmMt3lNXeCLENAeewQ8cZOLy9OTFrfXKjdllqEKrN9wHBMCXW -MnqBsetg9A7NlcbENUhGmw35R43hhxo= ------END PRIVATE KEY----- diff --git a/java/osx/osx-broker/src/test/resources/keystore/client/client.cer b/java/osx/osx-broker/src/test/resources/keystore/client/client.cer deleted file mode 100644 index 9b97f3d6a5..0000000000 --- a/java/osx/osx-broker/src/test/resources/keystore/client/client.cer +++ /dev/null @@ -1,20 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDVzCCAj+gAwIBAgIEcOge8TANBgkqhkiG9w0BAQsFADBGMQswCQYDVQQGEwJO -TDERMA8GA1UEChMIQWx0aW5kYWcxETAPBgNVBAsTCEFsdGluZGFnMREwDwYDVQQD -EwhTdWxleW1hbjAeFw0yMzExMTUwOTQ2MDZaFw0zMzExMTIwOTQ2MDZaMEYxCzAJ -BgNVBAYTAk5MMREwDwYDVQQKEwhBbHRpbmRhZzERMA8GA1UECxMIQWx0aW5kYWcx -ETAPBgNVBAMTCFN1bGV5bWFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC -AQEApYEjB7zAQzOUyZwszGkKg95Z22J1ZkthvGXQigbQ4QnjRtZJ3HGU7L6S3iDx -hDafZTqh7YoI7B7O9ZpXef3n0fCgQy55C3bxL7F4XBRhoSMLoM8kWyyREpy7DU9p -N18+pY27bQ1QnRL/qBnluUNiZO3h5M5ldsJaPpMGApSwvpGBkAZbwbCPimlWwACS -Kol4bH7xh6vlnhloqQaXgQzRNjRGwQdfZ0nlU0uL5fwYZ2zz2+XN1RLitPmxq18f -yjuNQ6AeBTMClHJDhWxdxF9b04jHwpEAwFIRGUVHC2OMxa8iBI0NdLEf5U7rk116 -CJUHoAzObHSXa/TsMMTKFeBWWwIDAQABo00wSzAdBgNVHSUEFjAUBggrBgEFBQcD -AQYIKwYBBQUHAwIwCwYDVR0PBAQDAgO4MB0GA1UdDgQWBBSKdFK+RE25B/Ll6lCF -ZaW7l0RGEjANBgkqhkiG9w0BAQsFAAOCAQEAH/WnEXr0zk5T+t2jtQ8eZq9iRnPt -qNft9br5gOU9jxGLVtCLKjpOg+EIir5iV7Vj6nCgo/HHbjQauJyiRrvJV0wlKCXa -Vaqp+mtS1DA3/rsj7YoN4jI8YO/VSyIckUTSJ877mCJ0JouNdqJrppeG5FER7KcX -wjIcYLmWXfJJ0sabr7MZ8vGcbLKC1AVq9vZLTAaxy36SppPbh01AD+sEqGIOYtJK -rkhWhChG6iF3YuGGv4qQomdPzKjsz/9aW5LflvcFp1w/HUyes3fuXfoPjYv5U0sZ -FCJrUKo26PBEwogZQj8A05TS7LuHXnnsTa/EZ4mI3toWBAdZ91KuLFbBwQ== ------END CERTIFICATE----- diff --git a/java/osx/osx-broker/src/test/resources/keystore/client/identity.jks b/java/osx/osx-broker/src/test/resources/keystore/client/identity.jks deleted file mode 100644 index e089f8ae9f11257714e8637a839b597878f1c657..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2551 zcmY+EX*3iH8^l60X8v8bkIyLLpo` zQpfQK+gz8gNm6z7qQ0JeGYnN`W!R} z4?u(7oZ(a`8tD122zU>G23ntCi?cCfV*B3}J1ady4jS+aiU#aM<(XLjkI$Y90U(Hh zcBkhxKG}fL>HeQxCQj<2JWLFXbgj1mXh3HhtAunf@zT+|4{FG=(HkZo0$X&^e7mC^l670n(ksuY|RlRyv|3DklknGo@;BUNE?BcQWCwU$@oz8v} zJ`1oXt{ARfbwY$e72xw_mJS1vlKz_{;q8HyHpg*34WO@ zE0{=-?+aP!fw(5*4sR(pwAq#uPWq-t4a{|HMk-t9?q832bJ;8nULCAjYBAW4?Q`9s zX1y$g0<*&ZJPDL-=8+X(95lE5#aZ#1Py&e3(f+jkEN98d4zbklaevp$mC@2p8+$qL z<9*yIYmtmT@tYI`h-w$#z;%tXWo?t-T#k8dicXxvkwy9Wkbjgn{=QpCv8mY&xWh=;!f?9 zdt*sC?9oeTtT9oH*O zD%db!Sgssb@RC^LNj~MDxn9JjqrOF#@L>hUBJuGtcf!rDGJ3ijqKxJ)Y6qNi=$Q)8 zRK%+etiSQmZ!cp|*v6|QJm&lJnm_(kWl7aq@5Ukfz{#vF(W#cf44giC-?)J}e>*n3 z1H=>7xAXl?rRaWbB$~tE-4~*fHEwzQ`Y;8N`71>`0k%G9FXr>}i{2A8WWhmR?C^kV zL}~gO=IS!mnoVd#$ugj+u2PnoggeJ&wKNZwVm^Yhl%*9I3UbRM!Yro6Lbp8z{N8X( zHs}_>JTdt^;!1-rofPj4_xjzO$rGr&nBsapGsIA6qNv7jPqS{!1eVQYgGW<8Leffp zZg$Slp05~{`aTP}CTp+s8aBKyp<%FO99yg4n0{YBx%xr{uqNw6#4wIHLS5jg$C%ef zi>3vUY3(xwW<~LH@OK9$#s0ZNfxCL9Lfk|Z3y$7yTHV9YOk&w14-C8hyGsscp|N6= zB$j^E2&gC-1R`D0+|1tENYwx}lQi`*_iA$oJ8D#Ig1u~tefK)dHJ$pNhR<;gk6d?e zAgWO(_(w4MZ)`Jb5|E9XH;mI4YPU9b8iA0lx{Rg_(!%b7mqluxd%yJ`@t_vf3p1a@ z*kFFdh?X)%wWmYIHhxN0{Kh_g8dkg6vkukaoaM-g8rv}*gUDvYj0+_!>uJtd z={0p5oeU?NL(2ClxejC!PqKIIsI4b3CVZfKwdzV@NMKDZtI+`TZE-Wi>zQO~evgq$ zIt9SVzBNHg0Vl7 zzbflNh5lA02%^j>%1-A@=S}BH=Sug8&L1lAPspnb0f4PMeO$$rHK9ZW(QfbHD(jyPc6TJsMC$6&$Z@Ds_P*wH}_5!ZxGEce0zM!tg$BmALz zBz9xH2iLd~LfD(pObrEYfTH2{o z2e|~LaPDF9jji?7Zno_m=c8uW`JVjMyDSb*u5rNHZVgv4tV^wglFW0MWe9cA?1nph z;7PBfJ&AGHe3MC79=MHOn%jC^wz!NuKY_69$39IMY(G{IHd%e7v?fGQ3HowkUWtw; 
z7g?|pQA*+hRwB|?y2zfl@}_QYMY&S;fd~9AgxeMD5{otnvbSdabzbX)AaIq!X?HW0 zOXMri1(S&=D^wIeg6tto>!$R$q1WU(`;(Z|o8VBFXxoGQFQ6`L#LVfkpM%mZTjd|p z?rM#Pi@~SAB=v_TZRy}6ZQSMUcti|Rc64Xe)u329$_tVDZ0;%57Q@C_Z}X$#)kjRA z?>&N!ycuX$=R~RS#IVXpQT2otrontmgSz?#w@rN+H=~e@4FJn4lX^$Ap!&fq#Qxqv-5`9| zrNy4yp1epQ-OEti`}0*pE|*qOckyof>@V8mMR=x-iF|jB=3WnNGLB+Bb?m`we!i?@ z4CnH^z=7-IGF&{B9GrXl{->GLL|51RkW}2Vn`B~(;Of{{x%EomF%b4OcRSOB+4qM{sBZn0VM7F5(3 zA}-p7On}wxkQ-}sKSHxA=$R&-kYcu2w$5=(&l&ohk1*5a&1<`$k0;`M9)}{#4CyVC zkdmTUcpvrK-B#gkdplX98!+9-6zx_c8GylsAHZd;6?Eo<`At0ZI#dD*Vgf2~GBOAP z=>U+>&Q_0r7{z>{4+x$G?lZ$G{pQpOZ*~)*Hm7l)H%D_mE&G`$# C*2wt) diff --git a/java/osx/osx-broker/src/test/resources/keystore/client/truststore.jks b/java/osx/osx-broker/src/test/resources/keystore/client/truststore.jks deleted file mode 100644 index 3528c2de87e7edf0220573a5bc75a49d62d98a91..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 976 zcmezO_TO6u1_mY|W(3o0#i>PQsYO6;_oeu{AO_Y5JyQcq1_tJd22ISp22D(*3z(T0 znV49L-JWF{@Un4gwRyCC=VfH%W@Rw&FyuDiWMd9xVH0NZ^Dz`P5C(C$gn2?TO7l`u zi;_}{iYg6-31LqkKL-6p0+ zW>Mn2#vrZ{luJdMni!Rky}`)Jz}&>h&tTBR$i>ve$jGpvVrB39?S1MMkGlDmzg_y{ z;uLij%@ZP~OO=GzJn;0765@nc-|n!L`Mu+GQ$3Fy9(Cp$g^B+!n`u0r=049w z<&C>l&}4Rv7fG^h-#Pa5UgAHwRpF1M%d|5=S_TdYbDjsEx6xU%!r`BBV&!kS73o&% zwtTxaZ%gfxvzmG8ZYSjB3XKJ{&F1n}6vWLnU9&~vMczJ!lqd-ym%mGym>C%u7uOrq z8pr};NLH0a%s_;VLz|6}m6e^D5zb;V;08(av#>BTG4C)i0C5Bv8UM4W8K^XI=Hw?Q z=49j-moy0$B^DQe(?LO|9$2`Yr5>2tAg1%Nh_Q%xMRBG4|B|vi=wP0JCjy5O{_1xY`bzw+^=6jQy)Cu*|?VZ;N+e)Z@&o{sP4YM zYgcB`Tx*lFRlY2(NZPCEjk0zI`>Cu{fa`y|vy32Ns%uLQ#Gnt#>m3vqw8!zdR^4QG6-Fj%t uF8^cayY^1!jcr(ecIxzP|9=Q5?z`K4P%+4P-VIgWLkHK{JuKT=-UR?#2VzP9 diff --git a/java/osx/osx-broker/src/test/resources/keystore/server/identity.jks b/java/osx/osx-broker/src/test/resources/keystore/server/identity.jks deleted file mode 100644 index 5caf90b73575457ef13148d23097bf95c4b18dd2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2607 zcmY+EX*d)L7stoUn2fR0kZoonTV!n6k|mR6GS&=2!YJz)yF_N}iY%q9QA8+=VsdTK z*te`%!?k5;2q~i1eV_Ne_qiX=InVR^pL0I`&=l6QKqe5H0$qSWWnv9u_t=?EFy&LA zBVY>j<1y|`4sR0ngZTKOG8-ykB^@Vf%%>b zon)j-e$rfr26c`CXUk~M7N&q8K&u&;0`3S@G8=D_?^v%GcS0el^{|FQgh~EV9=mNd z4-h_9+yc9B9g5Ad4quag{LQMfp`Lv4&X7Huayq-hT<3h12J8Op9_kYK)3lyxNKfr) zr0RALEh4$QV5ON)C)vZUE!O4%WA!_7)t74#AC}bSxqfHd81KRg*n51c>o6m$>4N3d z-FOI73AVFOJU(~mZd1W{;n)tMM>e1NPwJYQrr|2|51EHp_Kvx zs$KGq9Qk3Id~IqMh}Xi+-pFB%-^f?K)iMO@eDz)y*M;WOFnPFX>TN!$&CQQ9lEUJ|9^4#yF0z={2lKFtf;8!*wqYHgINu8p=g zR}^>l+hHX2{%e^D07rS<>rkBp7yoS#NDE75Y3!cHIn6IEk|Y+0!}QkF_T3bgggfoB z6-GzbZ0~}DS@M)3t9*l=rixv5gjh|UvA5nlP|pkKq~cGSlh6O)z-(;>QI}l3^GNjm z>R~@zXW0j%ILvOQKnUPnoJGu9?YqSZIFH&yrPw030&X4z7f{>1^NR-6zz!MllT#lvxhKF`9U5XF_z2BI!choz0P~|yIT!AV4iDj zwgcv$Tj82euX>lKA8Tjqh`2rMGMG~oJ$agaF;5o=|D{W^-C74};In0{=!lrWSr&bX zvz-v1m=}hOKUXU1 zHwv84Ut7v4+LAp=5AmOn@{$IHp_bzp6USxp&xp6>8nn~{-==>cKb%K|#7m)085K@s znAFY8?&&|^ zU3S~y{2cU*_Q`F(86Uxv)a_ywbYj>1Hazlyc9+s{l6!P^mRz33iIhp)MkHxE>GE6rY=^!?>z^ zHfjlzSwg>lzC?@d0!(L-)Xv+@uGvI@`zT_KJ7806Oxyiz)Hzbl2(+NY$|m-A0#!i= zE%di4p?nG)qHF+v00H0!2>Pp|QU8P|75KnBR_@*eq=G6MjaF7rQ9Q4rqZm5N#TlV^ARy<_QturhbwWHwc+Mj;kLG|WF zMhy%tNYrZx%Tvuou*G3l_~c8D#J*u-#?R~ZsKq@lRFA$egYz((Wg>miNbw~9&^C7&C z=j}vFsM^0Q0J2LeO_+WwL zvor_^Sh0V2+p$K&M~c<#gf9$(VtA^y{o4yFW|;OHB$Bs>YrpJH{^Q#h-3H_0b#y4pnYD z_rN_PGK*a|v{Z%WSNbpG#Xk<2H9x>32YLJRcEp_11nyid~5 zoAYV-f+!u2csS;M#aliFdZ#8x-&Ad!kW4&dW3fWHpWJ9UY*FpmLG?n>a#%pq_<$ky#3VjicLPH_UXE{JX zL1qA$FOuP>fP4MnNMm8-iY9-*S&Ayf7tHq-rJ8%2^Z5yGW?PI)iM&AHOz$foh$%Mz EKd~;uHUIzs diff --git 
a/java/osx/osx-broker/src/test/resources/keystore/server/server.cer b/java/osx/osx-broker/src/test/resources/keystore/server/server.cer deleted file mode 100644 index 99bb69028a..0000000000 --- a/java/osx/osx-broker/src/test/resources/keystore/server/server.cer +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDjTCCAnWgAwIBAgIEc0bmaTANBgkqhkiG9w0BAQsFADBIMQswCQYDVQQGEwJO -TDEVMBMGA1UEChMMVGh1bmRlcmJlcnJ5MRIwEAYDVQQLEwlBbXN0ZXJkYW0xDjAM -BgNVBAMTBUhha2FuMB4XDTIzMTExNTA5NDUyNloXDTMzMTExMjA5NDUyNlowSDEL -MAkGA1UEBhMCTkwxFTATBgNVBAoTDFRodW5kZXJiZXJyeTESMBAGA1UECxMJQW1z -dGVyZGFtMQ4wDAYDVQQDEwVIYWthbjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC -AQoCggEBALB4qY3vt44neOKLDqftpfjRlCcEKcgUNaUiE6zgSU9aEglawb1WwDas -vhrqoDJB76lJSM8kKrdZOLOQTo0WTonhlhysCeZklJBTgUZaPzvqrdDB9ierGPNS -q43jjwVeMeoD2cedRn1DGxssYDOJNTJyfT1xAQ7/oTx17C56OuXvD8MzRACqAFXa -/g+HQLUD70GXNS5IbAwnA9hxYf/TNijjlkeeRCTsRzpSkwco6GIdhvcIvI3SD8m1 -IPwZRJbMUiowQGCc51PPPCysqED+M2F5+x6oZzqutPbanrR9pM0pbidGyB8ecTMQ -KzadDXhwXp01rLQY6G6+AGRaGBJE/aUCAwEAAaN/MH0wHQYDVR0lBBYwFAYIKwYB -BQUHAwEGCCsGAQUFBwMCMAsGA1UdDwQEAwIDuDAwBgNVHREBAf8EJjAkgglsb2Nh -bGhvc3SCEXJhc3BiZXJyeXBpLmxvY2FshwR/AAABMB0GA1UdDgQWBBRKWgpk//Rk -p1Lpa9hsTC37zpZAWjANBgkqhkiG9w0BAQsFAAOCAQEAkIwbl5rNiWKqgpF/cdM9 -qZRe+vpSleDjuYGtA8GTjKzt9hIwJbvfurppcp07NM13MqcFMPvxHkE8JO9eIJZp -Hnbq4DRAfCY3Xd34AmU7uMCGYYIihKDKic88htbU6u0RGsC29htuH6O6SNAQlPJb -JmzU7YRN72jTiR2ksKTWfF3AKWUQzcTXziZSVrWo4gLc1nEEoD06Lwemj3SjR7Me -8ZqCMy1PbM+9Xo9wMis31n1emEX7ldnloUJlpbSod1yCR2hw2xVR6+KTdqyMKpvJ -u+gxftM+QQMCQ9VpA4JfI4wEkzOkjBpIswQLhcK0uk/Gz4q9lw1dgK/NlZe2//gQ -IQ== ------END CERTIFICATE----- diff --git a/java/osx/osx-broker/src/test/resources/keystore/server/truststore.jks b/java/osx/osx-broker/src/test/resources/keystore/server/truststore.jks deleted file mode 100644 index a474fb28259011a129effe45ccf2bf4440cbbf70..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 922 zcmezO_TO6u1_mY|W(3o0$vK&+c_lz@_oet}YZzD~^h^yb85o$O4Vsw44Vswj7cet1 zGBL3fypa27z{|#|)#lOmotKf3o0Y-9&5+xGlZ`o)g-w{r&&N>EfFH!+66SEsDap)B zNlb?cabpTG3v&dQ=A>5UCgvH)iSrs68yXq{Z8kA8FpCoBH3o5wpj@ii)5NHR>;*ofTgB_xZ&S3!L>TxywH4Z>)$BNnEJRz2Lk` zw9Z7KIlFoNGtJ}emiF$>IhR^?D9Ua!8`G2x`zAI{V2eJu zp}#9L>;S_gt{eBiQ7T;_;k;w!QS0Z|45|ge7^nk>{X#h zTYhd_9WQ^%y4QJu9IG+Ylp^QWoY*7r(U&`pADYN;AV^Ts)tx)J=jeJRmR{bHjq*?Z zUQdp#;+V?5fahFJ$@J_mZw!u{5`7RB&BV;ez_{4gz}r9;7?HB7EMf*CY#iEbjI6Be z%#3grlL0qKnxBP*nTdG^L@ghS7>h_(NzgtQ-<|BAp1ulbOcV=DEMZ%_#q>igq_o3KY3m{HhcYM$xk2W>@6vxlpV0j?8OI{LmiS%_6(P&Tza#+J+AVN@A@O@ogMdX liLtOpeh*rw6L#>RTFVM)n - - - - - + + + + + org.slf4j slf4j-api - com.fasterxml.jackson.core - jackson-databind + com.fasterxml.jackson.core + jackson-databind @@ -57,8 +57,8 @@ - org.apache.zookeeper - zookeeper + org.apache.zookeeper + zookeeper slf4j-log4j12 @@ -93,6 +93,12 @@ 1.9.4 compile + + commons-codec + commons-codec + 1.15 + compile + diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java index 23a53bc5e9..3710490682 100644 --- a/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/config/MetaInfo.java @@ -101,30 +101,44 @@ public class MetaInfo { public static String PROPERTY_TRANSFER_FILE_PATH_PRE = "mapped"+ File.separator+".fate"+ File.separator+"transfer_file"; @Config(confKey = "index.mapped.file.size", pattern = Dict.POSITIVE_INTEGER_PATTERN) public static Integer PROPERTY_INDEX_MAP_FILE_SIZE = 1 << 21; - @Config(confKey = 
"server.cert.chain.file") - public static String PROPERTY_SERVER_CERT_CHAIN_FILE; - @Config(confKey = "server.private.key.file") - public static String PROPERTY_SERVER_PRIVATE_KEY_FILE; - @Config(confKey = "server.ca.file") - public static String PROPERTY_SERVER_CA_FILE; - @Config(confKey = "server.ssl.store.flag") - public static Boolean PROPERTY_USE_STORE = true; - @Config(confKey = "server.keystore.file") - public static String PROPERTY_SERVER_KEYSTORE_FILE; - @Config(confKey = "server.keystore.file.password") - public static String PROPERTY_SERVER_KEYSTORE_FILE_PASSWORD; - @Config(confKey = "server.trust.keystore.file") - public static String PROPERTY_SERVER_TRUST_KEYSTORE_FILE; - @Config(confKey = "server.trust.keystore.file.password") - public static String PROPERTY_SERVER_TRUST_FILE_PASSWORD; + @Config(confKey = "https.server.cert.chain.file") + public static String PROPERTY_HTTPS_SERVER_CERT_CHAIN_FILE; + @Config(confKey = "https.server.private.key.file") + public static String PROPERTY_HTTPS_SERVER_PRIVATE_KEY_FILE; + @Config(confKey = "https.server.ca.file") + public static String PROPERTY_HTTPS_SERVER_CA_FILE; + @Config(confKey = "https.server.keystore.file") + public static String PROPERTY_HTTPS_SERVER_KEYSTORE_FILE = ""; + @Config(confKey = "https.server.keystore.file.password") + public static String PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD = ""; + @Config(confKey = "https.server.trust.keystore.file") + public static String PROPERTY_HTTPS_SERVER_TRUST_KEYSTORE_FILE = ""; + @Config(confKey = "https.client.trust.keystore.file.password") + public static String PROPERTY_HTTPS_SERVER_TRUST_FILE_PASSWORD = ""; + @Config(confKey = "https.server.hostname.verifier.skip") + public static boolean PROPERTY_HTTPS_HOSTNAME_VERIFIER_SKIP = true; + @Config(confKey = "grpc.server.cert.chain.file") + public static String PROPERTY_GRPC_SERVER_CERT_CHAIN_FILE; + @Config(confKey = "grpc.server.private.key.file") + public static String PROPERTY_GRPC_SERVER_PRIVATE_KEY_FILE; + @Config(confKey = "grpc.server.ca.file") + public static String PROPERTY_GRPC_SERVER_CA_FILE; + @Config(confKey = "grpc.server.keystore.file") + public static String PROPERTY_GRPC_SERVER_KEYSTORE_FILE; + @Config(confKey = "grpc.server.keystore.file.password") + public static String PROPERTY_GRPC_SERVER_KEYSTORE_FILE_PASSWORD; + @Config(confKey = "grpc.server.trust.keystore.file") + public static String PROPERTY_GRPC_SERVER_TRUST_KEYSTORE_FILE; + @Config(confKey = "grpc.server.trust.keystore.file.password") + public static String PROPERTY_GRPC_SERVER_TRUST_FILE_PASSWORD; + @Config(confKey = "custom.local.host") public static String PROPERTY_CUSTOMER_LOCAL_HOST; @Config(confKey = "bind.host") public static String PROPERTY_BIND_HOST = "0.0.0.0"; @Config(confKey = "open.grpc.tls.server", pattern = Dict.BOOLEAN_PATTERN) public static Boolean PROPERTY_OPEN_GRPC_TLS_SERVER = false; - @Config(confKey = "open.grpc.tls.use.keystore", pattern = Dict.BOOLEAN_PATTERN) - public static Boolean PROPERTY_OPEN_TLS_USE_KEYSTORE = false; + @Config(confKey = "grpc.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) public static Integer PROPERTY_GRPC_PORT = 9370; @Config(confKey = "grpc.tls.port", pattern = Dict.POSITIVE_INTEGER_PATTERN) @@ -151,16 +165,10 @@ public class MetaInfo { public static String PROPERTY_HTTP_SSL_TRUST_STORE_TYPE = "PKCS12"; @Config(confKey = "http.ssl.trust.store.provider") public static String PROPERTY_HTTP_SSL_TRUST_STORE_PROVIDER = "SUN"; - @Config(confKey = "http.ssl.key.store.path") - public static String 
PROPERTY_HTTP_SSL_KEY_STORE_PATH = "";
-    @Config(confKey = "http.ssl.key.store.alias")
-    public static String PROPERTY_HTTP_SSL_KEY_STORE_ALIAS = "";
-    @Config(confKey = "http.ssl.key.store.password")
-    public static String PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD = "";
-    @Config(confKey = "http.ssl.trust.store.path")
-    public static String PROPERTY_HTTP_SSL_TRUST_STORE_PATH = "";
-    @Config(confKey = "http.ssl.trust.store.password")
-    public static String PROPERTY_HTTP_SSL_TRUST_STORE_PASSWORD = "";
+    @Config(confKey = "http.ssl.client.key.store.alias")
+    public static String PROPERTY_HTTP_SSL_CLIENT_KEY_STORE_ALIAS = "";
+    @Config(confKey = "http.ssl.server.key.store.alias")
+    public static String PROPERTY_HTTP_SSL_SERVER_KEY_STORE_ALIAS = "";
     @Config(confKey = "http.ssl.hostname.verify")
     public static Boolean PROPERTY_HTTP_SSL_HOSTNAME_VERIFY = false;
     @Config(confKey = "http.context.path")
@@ -253,6 +261,8 @@ public class MetaInfo {
     public static Boolean PROPERTY_ROUTER_CHANGE_NEED_TOKEN= false;
     @Config(confKey = "router.change.token.validator")
     public static String PROPERTY_ROUTER_CHANGE_TOKEN_VALIDATOR= Dict.DEFAULT;
+    @Config(confKey = "batch.sink.push.executor.timeout")
+    public static Integer BATCH_SINK_PUSH_EXECUTOR_TIMEOUT = 60*1000;
 
 
     public static boolean isCluster() {
diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java
index bf3e33445f..05429649d6 100644
--- a/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java
+++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/constant/Dict.java
@@ -275,6 +275,8 @@ public class Dict {
     public final static String URL = "url";
     public final static String USE_SSL = "useSSL";
+    public final static String USE_KEYSTORE = "useKeyStore";
+
     public final static String CA_FILE = "caFile";
     public final static String CERT_CHAIN_FILE = "certChainFile";
     public final static String PRIVATE_KEY_FILE = "privateKeyFile";
diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/router/RouterInfo.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/router/RouterInfo.java
index 3c495635eb..5f87b3f8bc 100644
--- a/java/osx/osx-core/src/main/java/org/fedai/osx/core/router/RouterInfo.java
+++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/router/RouterInfo.java
@@ -20,6 +20,7 @@
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import lombok.Data;
+import org.apache.commons.codec.digest.DigestUtils;
 import org.fedai.osx.core.context.Protocol;
@@ -36,7 +37,8 @@ public boolean equals(Object obj) {
             return true;
         }
     }
-
+    @JsonIgnore
+    boolean isCycle = false;
     @JsonInclude(JsonInclude.Include.NON_NULL)
     private Protocol protocol;
     @JsonIgnore
@@ -73,24 +75,51 @@ public boolean equals(Object obj) {
     private String trustStorePassword;
 
+    public String md5TlsInfo(){
+        if (useSSL){
+            StringBuilder sb = new StringBuilder();
+            if(useKeyStore) {
+                sb.append(keyStoreFilePath).append(keyStorePassword).append(trustStoreFilePath).append(trustStorePassword);
+            }else{
+                sb.append(certChainFile).append(privateKeyFile).append(caFile);
+            }
+            return DigestUtils.md5Hex(sb.toString());
+        }else{
+            return "";
+        }
+    }
 
     public String toKey() {
         StringBuffer sb = new StringBuffer();
         if (Protocol.grpc.equals(protocol)||protocol==null) {
             sb.append(host).append("_").append(port);
-            if (useSSL)
-                sb.append("_").append("tls");
+            if (useSSL){
+                sb.append("_").append("tls").append(md5TlsInfo());
+            }
         } else {
             sb.append(url);
+            if (useSSL){
+                sb.append("_").append("tls").append(md5TlsInfo());
+            }
         }
         return sb.toString();
     }
 
     @Override
     public String toString() {
-        return toKey();
+        StringBuffer sb = new StringBuffer();
+        if (Protocol.grpc.equals(protocol)||protocol==null) {
+            sb.append(host).append("_").append(port);
+            if (useSSL){
+                sb.append("_").append("tls");
+            }
+        } else {
+            sb.append(url);
+        }
+        return sb.toString();
     }
 
     @JsonIgnore
     public String getResource() {
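The digest above exists so that toKey(), which connection caching keys on, changes whenever the TLS material changes: rotating a certificate then yields a fresh channel, while toString() deliberately omits the digest so log lines stay stable. A hedged, self-contained sketch of that behavior; the class name and all values are illustrative, and the setters are generated by Lombok's @Data:

import org.fedai.osx.core.router.RouterInfo;

public class RouterKeySketch {
    public static void main(String[] args) {
        RouterInfo before = new RouterInfo();
        before.setHost("192.168.0.1");
        before.setPort(9370);
        before.setUseSSL(true);
        before.setUseKeyStore(false);
        before.setCaFile("/data/certs/ca.crt");            // placeholder paths
        before.setCertChainFile("/data/certs/server.crt");
        before.setPrivateKeyFile("/data/certs/server.pem");

        RouterInfo after = new RouterInfo();
        after.setHost("192.168.0.1");
        after.setPort(9370);
        after.setUseSSL(true);
        after.setUseKeyStore(false);
        after.setCaFile("/data/certs/ca-2024.crt");        // rotated CA file
        after.setCertChainFile("/data/certs/server.crt");
        after.setPrivateKeyFile("/data/certs/server.pem");

        // Different TLS material gives a different pool key, so a new channel is built
        System.out.println(before.toKey().equals(after.toKey()));       // false
        // toString() omits the digest, so the log representation is unchanged
        System.out.println(before.toString().equals(after.toString())); // true
    }
}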
diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java
index e34137d192..e97c118b03 100644
--- a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java
+++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OSXCertUtils.java
@@ -38,8 +38,8 @@ public static void x509ToPkCS12(Certificate[] chain, Key privateKey, String file
         try (OutputStream os = new FileOutputStream(filePath)) {
             KeyStore keyStore = KeyStore.getInstance(type);
             keyStore.load(null);
-            keyStore.setKeyEntry(alias, privateKey, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), chain);
-            keyStore.store(os, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray());
+            keyStore.setKeyEntry(alias, privateKey, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray(), chain);
+            keyStore.store(os, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray());
         }
     }
 
@@ -67,7 +67,7 @@ public static SSLContext getSSLContext(String caPath, String clientCertPath, Str
         TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)};
         // Load client certificate
         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
-        kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray());
+        kmf.init(keyStore, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray());
         sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom());
         return sslContext;
     }
@@ -92,7 +92,7 @@ public static SSLContext getSSLContext(RouterInfo routerInfo) throws Exception {
         TrustManager[] tm = {OsxX509TrustManager.getInstance(keyStore)};
         // Load client certificate
         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
-        kmf.init(keyStore, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray());
+        kmf.init(keyStore, MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray());
         sslContext.init(kmf.getKeyManagers(), tm, new SecureRandom());
         return sslContext;
     }
@@ -110,17 +110,24 @@ private static KeyStore loadKeyStore(String keyStorePath, String keyStorePasswor
     public static KeyStore getKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception {
         KeyStore keyStore = KeyStore.getInstance("PKCS12");
         keyStore.load(null);
-        keyStore.setKeyEntry(MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS, importPrivateKey(clientKeyPath), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray(), new Certificate[]{importCert(clientCertPath), importCert(caPath)});
+        keyStore.setKeyEntry(MetaInfo.PROPERTY_HTTP_SSL_CLIENT_KEY_STORE_ALIAS, importPrivateKey(clientKeyPath), MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray(), new Certificate[]{importCert(clientCertPath), importCert(caPath)});
+        return keyStore;
+    }
+
+    public static KeyStore getKeyStore2(String caPath, String clientCertPath, String clientKeyPath) throws Exception {
+        KeyStore keyStore = KeyStore.getInstance("PKCS12");
+        keyStore.load(null);
+        keyStore.setKeyEntry(MetaInfo.PROPERTY_HTTP_SSL_SERVER_KEY_STORE_ALIAS,
importPrivateKey(clientKeyPath), MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray(), new Certificate[]{importCert(clientCertPath), importCert(caPath)}); return keyStore; } public static KeyStore getTrustStore(String keyStorePath, String trustStoreType) throws Exception { KeyStore keyStore = KeyStore.getInstance(trustStoreType.toUpperCase()); - keyStore.load(new FileInputStream(new File(keyStorePath)), MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray()); + keyStore.load(new FileInputStream(new File(keyStorePath)), MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray()); return keyStore; } - public static String createKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { + /* public static String createKeyStore(String caPath, String clientCertPath, String clientKeyPath) throws Exception { PrivateKey privateKey = importPrivateKey(clientKeyPath); // Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; Certificate[] certificates = {importCert(clientCertPath), importCert(caPath)}; @@ -129,7 +136,7 @@ public static String createKeyStore(String caPath, String clientCertPath, String FileUtils.createNewFile(pfxFile); OSXCertUtils.x509ToPkCS12(certificates, privateKey, pfxPath, MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_ALIAS); return pfxPath; - } + }*/ public static Certificate importCert(String certFile) throws Exception { try (FileInputStream certStream = new FileInputStream(certFile)) { @@ -175,7 +182,7 @@ private static String readFileContent(String filePath) throws Exception { //determine whether the string is null and get the default string character array private static char[] toCharArray(int index, String... str) { - return str.length <= index || str[index] == null ? MetaInfo.PROPERTY_HTTP_SSL_KEY_STORE_PASSWORD.toCharArray() : str[index].toCharArray(); + return str.length <= index || str[index] == null ? 
MetaInfo.PROPERTY_HTTPS_SERVER_KEYSTORE_FILE_PASSWORD.toCharArray() : str[index].toCharArray(); } public static String getTempStorePath() { diff --git a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java index 84d2cdd17f..f0a2900d35 100644 --- a/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java +++ b/java/osx/osx-core/src/main/java/org/fedai/osx/core/utils/OsxX509TrustManager.java @@ -1,5 +1,6 @@ package org.fedai.osx.core.utils; +import org.fedai.osx.core.config.MetaInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -240,7 +241,7 @@ public static HostnameVerifier2 getInstance() { @Override public boolean verify(String s, SSLSession sslSession) { - return true; + return MetaInfo.PROPERTY_HTTPS_HOSTNAME_VERIFIER_SKIP; } } } diff --git a/java/osx/pom.xml b/java/osx/pom.xml index 060e161ff7..3f7b81a790 100644 --- a/java/osx/pom.xml +++ b/java/osx/pom.xml @@ -18,8 +18,8 @@ 1.8 UTF-8 UTF-8 - 31.1-jre - 1.59.0 + 32.1.3-jre + 1.60.0 1.18.24 3.21.12 0.6.1 @@ -30,8 +30,8 @@ true 2.17.2 1.7.30 - 3.8.0 - 9.4.50.v20221201 + 3.9.1 + 9.4.52.v20230823 2.11.0 5.4.0 3.4.4 @@ -48,9 +48,10 @@ 1.10.0 4.13.2 5.12.1 - 3.8.0 + 3.10.0 6.0.0 0.10.2 + 4.1.100.Final @@ -159,11 +160,6 @@ lombok ${lombok.version} - - com.lmax - disruptor - ${disruptor.version} - org.apache.commons @@ -182,11 +178,6 @@ test - - com.google.guava - guava - ${guava.version} - commons-configuration @@ -194,11 +185,6 @@ ${commons-configuration.version} - - commons-io - commons-io - ${commons-io.version} - org.apache.commons @@ -241,41 +227,12 @@ grpc-api ${grpc.version} - - io.grpc - grpc-core - ${grpc.version} - - - io.grpc - grpc-netty-shaded - ${grpc.version} - - - io.grpc - grpc-protobuf - ${grpc.version} - - - io.grpc - grpc-stub - ${grpc.version} - + com.google.protobuf protobuf-java-util ${protobuf.version} - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - io.grpc grpc-core @@ -286,6 +243,13 @@ grpc-netty-shaded ${grpc.version} + + + io.netty + netty-handler + ${netty.version} + + io.grpc grpc-protobuf From e164cec9f6d46bfe9dbd297896ca2412be5a765c Mon Sep 17 00:00:00 2001 From: v_wbxiongli <740332065@qq.com> Date: Thu, 14 Dec 2023 14:49:55 +0800 Subject: [PATCH 42/42] update doc Signed-off-by: v_wbxiongli <740332065@qq.com> --- doc/2.0/osx/bfia-x.proto | 53 +++++++++++++ doc/2.0/osx/bfia-y.proto | 61 +++++++++++++++ doc/2.0/osx/osx-tls.md | 84 +++++++++++++++++++- doc/2.0/osx/osx.md | 161 +++++++++++++++++++++++++-------------- 4 files changed, 297 insertions(+), 62 deletions(-) create mode 100644 doc/2.0/osx/bfia-x.proto create mode 100644 doc/2.0/osx/bfia-y.proto diff --git a/doc/2.0/osx/bfia-x.proto b/doc/2.0/osx/bfia-x.proto new file mode 100644 index 0000000000..33b1082070 --- /dev/null +++ b/doc/2.0/osx/bfia-x.proto @@ -0,0 +1,53 @@ +/* + * Copyright 2023 The BFIA Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package org.ppc.ptp;
+
+// PTP private transfer protocol
+// Common header name encoding: at layer 4 (no header available) the fields are packed into the message head as binary; at layer 7 they are carried as protocol headers.
+// x-ptp-version            required  protocol version
+// x-ptp-tech-provider-code required  provider (vendor) code
+// x-ptp-trace-id           required  trace ID
+// x-ptp-token              required  authentication token
+// x-ptp-source-node-id     required  source node ID
+// x-ptp-target-node-id     required  target node ID
+// x-ptp-source-inst-id     required  source institution ID
+// x-ptp-target-inst-id     required  target institution ID
+// x-ptp-session-id         required  communication session ID, unique across the network
+
+// Inbound message of the transport layer
+message Inbound {
+  map<string, string> metadata = 1; // header; optional, reserved for extension; a dict whose serialization is implemented uniformly by the transport layer
+  bytes payload = 2;                // message body carrying upper-layer content; its serialization is pluggable by the upper layer via SPI
+}
+
+// Outbound message of the transport layer
+message Outbound {
+  map<string, string> metadata = 1; // header; optional, reserved for extension; a dict whose serialization is implemented uniformly by the transport layer
+  bytes payload = 2;                // message body carrying upper-layer content; its serialization is pluggable by the upper layer via SPI
+  string code = 3;                  // status code
+  string message = 4;               // status description
+}
+
+// If interconnection takes an asynchronous transport protocol as its reference standard, the interconnection protocol headers are carried in metadata, which also carries the message attributes of the asynchronous scenario.
+// If interconnection takes another protocol as its reference standard, the interconnection protocol headers are likewise carried in metadata.
+// If interconnection takes gRPC as its reference standard, the interconnection protocol headers are carried in the HTTP/2 headers.
+
+service PrivateTransferProtocol {
+  rpc transport (stream Inbound) returns (stream Outbound);
+  rpc invoke (Inbound) returns (Outbound);
+}
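The service above is easiest to read through a concrete call. Below is a minimal client sketch of the unary `invoke` RPC, assuming Java stubs generated from `bfia-x.proto` by the standard protoc gRPC plugin; the address, port, and every `x-ptp-*` value are illustrative placeholders, not OSX defaults, and here the headers travel in the `metadata` map (under the gRPC reference standard they could instead travel as HTTP/2 headers).

```java
// Minimal sketch of the unary invoke RPC from bfia-x.proto.
// All addresses, ports, and x-ptp-* values are placeholders.
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.ppc.ptp.Inbound;
import org.ppc.ptp.Outbound;
import org.ppc.ptp.PrivateTransferProtocolGrpc;

public class PtpInvokeDemo {
    public static void main(String[] args) {
        ManagedChannel channel = ManagedChannelBuilder.forAddress("127.0.0.1", 9370)
                .usePlaintext() // plaintext for brevity; see osx-tls.md for TLS setup
                .build();
        PrivateTransferProtocolGrpc.PrivateTransferProtocolBlockingStub stub =
                PrivateTransferProtocolGrpc.newBlockingStub(channel);

        // Carry the required x-ptp-* headers in the metadata map.
        Inbound request = Inbound.newBuilder()
                .putMetadata("x-ptp-version", "1.0")
                .putMetadata("x-ptp-tech-provider-code", "FATE")
                .putMetadata("x-ptp-trace-id", "trace-0001")
                .putMetadata("x-ptp-token", "token-0001")
                .putMetadata("x-ptp-session-id", "session-0001")
                .putMetadata("x-ptp-source-node-id", "9999")
                .putMetadata("x-ptp-target-node-id", "10000")
                .putMetadata("x-ptp-source-inst-id", "9999")
                .putMetadata("x-ptp-target-inst-id", "10000")
                .setPayload(ByteString.copyFromUtf8("hello"))
                .build();

        Outbound response = stub.invoke(request);
        System.out.println(response.getCode() + " " + response.getMessage());
        channel.shutdown();
    }
}
```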
diff --git a/doc/2.0/osx/bfia-y.proto b/doc/2.0/osx/bfia-y.proto
new file mode 100644
index 0000000000..ca0d973996
--- /dev/null
+++ b/doc/2.0/osx/bfia-y.proto
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2023 The BFIA Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package org.ppc.ptp;
+
+// PTP private transfer protocol
+// Common header name encoding: at layer 4 (no header available) the fields are packed into the message head as binary; at layer 7 they are carried as protocol headers.
+// x-ptp-tech-provider-code required  provider (vendor) code
+// x-ptp-trace-id           required  trace ID
+// x-ptp-token              required  authentication token
+// x-ptp-session-id         required  communication session ID, unique across the network
+// x-ptp-target-node-id     required  target node ID, unique across the network
+// x-ptp-target-inst-id     optional  target institution ID, unique across the network
+
+message PeekInbound {
+  string topic = 1; // optional  session topic, unique within a channel; used to isolate transfers on the same channel
+}
+
+message PopInbound {
+  string topic = 1;  // optional  session topic, unique within a channel; used to isolate transfers on the same channel
+  int32 timeout = 2; // optional  blocking timeout, default 120s
+}
+
+message PushInbound {
+  string topic = 1;                 // optional  session topic, unique within a channel; used to isolate transfers on the same channel
+  bytes payload = 2;                // binary message body
+  map<string, string> metadata = 3; // optional  reserved for extensibility
+}
+
+message ReleaseInbound {
+  string topic = 1;  // optional  session topic, unique within a channel; used to isolate transfers on the same channel
+  int32 timeout = 2; // optional  blocking timeout, default 120s
+}
+
+message TransportOutbound {
+  map<string, string> metadata = 1; // optional, reserved for extension; a dict whose serialization is implemented uniformly by the transport layer
+  bytes payload = 2;                // binary message body
+  string code = 3;                  // status code
+  string message = 4;               // status description
+}
+
+service PrivateTransferTransport {
+  rpc peek (PeekInbound) returns (TransportOutbound);
+  rpc pop (PopInbound) returns (TransportOutbound);
+  rpc push (PushInbound) returns (TransportOutbound);
+  rpc release (ReleaseInbound) returns (TransportOutbound);
+}
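To make the queue semantics concrete, here is a minimal sketch of a push followed by a blocking pop on one topic, assuming Java stubs generated from `bfia-y.proto`; host, port, and topic are placeholders, and the `x-ptp-*` headers are omitted for brevity (they would be attached as in the `invoke` sketch above), so this illustrates the API shape rather than a working exchange against a real broker.

```java
// Minimal sketch of the queue-style transport from bfia-y.proto:
// a producer pushes a payload to a topic, a consumer pops it back.
// Host, port, and topic are placeholders; x-ptp-* headers omitted.
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.ppc.ptp.PopInbound;
import org.ppc.ptp.PrivateTransferTransportGrpc;
import org.ppc.ptp.PushInbound;
import org.ppc.ptp.TransportOutbound;

public class PtpQueueDemo {
    public static void main(String[] args) {
        ManagedChannel channel = ManagedChannelBuilder.forAddress("127.0.0.1", 9370)
                .usePlaintext()
                .build();
        PrivateTransferTransportGrpc.PrivateTransferTransportBlockingStub stub =
                PrivateTransferTransportGrpc.newBlockingStub(channel);

        // Producer side: enqueue one message on the topic.
        TransportOutbound pushAck = stub.push(PushInbound.newBuilder()
                .setTopic("demo-topic")
                .setPayload(ByteString.copyFromUtf8("hello"))
                .build());
        System.out.println("push: " + pushAck.getCode());

        // Consumer side: blocking pop using the documented 120s default timeout.
        TransportOutbound popped = stub.pop(PopInbound.newBuilder()
                .setTopic("demo-topic")
                .setTimeout(120)
                .build());
        System.out.println("pop: " + popped.getPayload().toStringUtf8());

        channel.shutdown();
    }
}
```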
diff --git a/doc/2.0/osx/osx-tls.md b/doc/2.0/osx/osx-tls.md
index fe097a69fa..2521b617e6 100644
--- a/doc/2.0/osx/osx-tls.md
+++ b/doc/2.0/osx/osx-tls.md
@@ -130,11 +130,89 @@ two-way TLS:
 
 #### After completing the steps above you will have the following certificates:
 
-	The server folder contains: identity.jks, server.cer, truststore.jks.
+	The server folder contains: identity.jks, server.cer, truststore.jks. identity.jks is the private-key keystore and truststore.jks the trusted-certificate keystore; server.cer is the server certificate (not needed when the keystore form is used).
 
-	The client folder contains: identity.jks, client.cer, truststore.jks.
+	The client folder contains: identity.jks, client.cer, truststore.jks. identity.jks is the private-key keystore and truststore.jks the trusted-certificate keystore; client.cer is the client certificate (not needed when the keystore form is used).
 
 ## 2) Option 2: store the private key, certificate, and trusted certificate as separate files
 
-#### Commands to generate ca.key, ca.crt, client.crt, client.csr, client.key, client.pem, server.crt, server.csr, server.key, server.pem:
+```
+ca.key
+Generate the CA's own private key (root ca.key):
+# openssl genrsa -out ca.key 2048
+
+ca.crt
+Generate the CA's self-signed certificate from its own private key; it contains the CA's public key:
+# openssl req -x509 -new -nodes -key ca.key -subj "/CN=osx" -days 5000 -out ca.crt
+
+server.key
+The server's private key and certificate (signed by the CA above).
+Generate the server private key:
+# openssl genrsa -out server.key 2048
+Convert it to PKCS#8 format for use by Java programs:
+# openssl pkcs8 -topk8 -inform PEM -outform PEM -in server.key -out server_pkcs8.key -nocrypt
+
+server.csr
+Generate the certificate signing request (CSR):
+# openssl req -new -key server.key -subj "/CN=osx" -out server.csr
+
+server.crt
+The CA signs the server's CSR with its own private key, producing the server certificate server.crt.
+If your server is also reached by IP address, add a subjectAltName entry as needed:
+# openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out server.crt -days 5000 \
+    -extfile <(printf "subjectAltName=DNS:grpcpro1.com,IP:your_server_ip")
+
+To verify client certificates, the client first needs a certificate of its own. Building on the example above, generate the client's private key and certificate.
+client.key
+# openssl genrsa -out client.key 2048
+Convert it to PKCS#8 format for use by Java programs:
+# openssl pkcs8 -topk8 -inform PEM -outform PEM -in client.key -out client_pkcs8.key -nocrypt
+
+client.csr
+# openssl req -new -key client.key -subj "/CN=osx" -out client.csr
+
+client.crt
+# openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out client.crt -days 5000
+```
+
+Artifacts:
+CA:
+private key: ca.key
+certificate: ca.crt
+
+Server:
+private keys: server.key, server_pkcs8.key (configure the PKCS#8 one)
+certificate: server.crt
+
+Client:
+private keys: client.key, client_pkcs8.key (configure the PKCS#8 one)
+certificate: client.crt
diff --git a/doc/2.0/osx/osx.md b/doc/2.0/osx/osx.md
index 23bc9b6e9f..47bc71107e 100644
--- a/doc/2.0/osx/osx.md
+++ b/doc/2.0/osx/osx.md
@@ -40,7 +40,7 @@ FATE 1.x maintains multiple communication stacks, including eggroll and spark+pulsar+nginx / spark
 - The routing configuration is essentially the same as the original eggroll's, which lowers the porting effort
-
+ 
 ## Component design:
@@ -74,9 +74,11 @@
 
+#### Interconnection protocol details:
+
+For the interface between transport components, see [bfia-x.proto](./bfia-x.proto)
-
+For the interface between the algorithm layer and the transport component, see [bfia-y.proto](./bfia-y.proto)
 
 #### Transport modes:
@@ -191,7 +193,7 @@ eggroll.cluster.manager.port = 4670
 ## Deployment:
 
 1. The deployment machine needs JDK 1.8+ installed
-2. Extract osx.tar.gz
+2. Extract osx.tar.gz 
 3. Enter the deployment directory and run: sh service.sh start
@@ -202,13 +204,13 @@ eggroll.cluster.manager.port = 4670
 
 ![alllog](../images/alllog.png)
 
-	flow.log : request log
+flow.log : request log
 
-	broker.log : full log
+broker.log : full log
 
-	broker-error: error log
+broker-error: error log
 
-	broker-debug: debug log, disabled by default
+broker-debug: debug log, disabled by default
 
 After a normal start you will see log lines like the following in broker.log:
@@ -278,117 +280,158 @@ eggroll.cluster.manager.port = 4670
 
 ### Certificates:
 
-#### OSX configuration:
+1) Option 1: store the private key, certificate, and trusted certificate in keystores (see [osx-tls.md](./osx-tls.md) for the certificate-generation commands)
+
+OSX configuration:
 
 - grpcs:
 
-  broker.property configuration (keystore form, i.e. option 1):
+  Server-side broker.property configuration (keystore form, i.e. option 1):
 
   ```
-  # enable the grpcs server
+  # enable the grpcs server and set the grpcs port
+  grpc.tls.port=9371
   open.grpc.tls.server= true
-  # whether to use the keystore form (default false)
-  open.grpc.tls.use.keystore= true
-  # server-side keystore path and password
-  server.keystore.file=
-  server.keystore.file.password=
-  # server-side trust keystore path and password
-  server.trust.keystore.file=
-  server.trust.keystore.file.password=
+  # server-side keystore path and password
+  grpc.server.keystore.file= server/identity.jks
+  grpc.server.keystore.file.password= XXXXXX
+  # server-side trust keystore path and password
+  grpc.server.trust.keystore.file= server/truststore.jks
+  grpc.server.trust.keystore.file.password= XXXXXX
   ```
 
-  Client-side routing-table configuration:
+  Client-side routing-table configuration; add the following under the corresponding partyid:
 
   ```
   "default": [
     {
       "protocol": "grpc",
-      "keyStoreFile": "D:/webank/osx/test3/client/identity.jks",
-      "keyStorePassword": "123456",
-      "trustStoreFile": "D:/webank/osx/test3/client/truststore.jks",
-      "trustStorePassword": "123456",
+      "keyStoreFile": "absolute path of the private-key keystore",
+      "keyStorePassword": "private-key keystore password",
+      "trustStoreFile": "absolute path of the truststore",
+      "trustStorePassword": "truststore password",
       "useSSL": true,
+      "useKeyStore" : true,
       "port": 9885,
       "ip": "127.0.0.1"
     }
   ]
   ```
 
 - https:
 
-  broker.property configuration (keystore form, i.e. option 1):
+  Server-side broker.property configuration (keystore form, i.e. option 1):
 
   ```
-  # grpcs port
+  # https port
   https.port=8092
-  # enable the grpcs server
+  # enable the https server
   open.https.server= true
-  # server-side keystore path and password
-  server.keystore.file=
-  server.keystore.file.password=
-  # server-side trust keystore path and password
-  server.trust.keystore.file=
-  server.trust.keystore.file.password=
+  # keystore type
+  http.ssl.trust.store.type=JKS
+  # https server-side keystore path and password
+  https.server.keystore.file= server/identity.jks
+  https.server.keystore.file.password= XXXXXX
+  # https server-side trust keystore path and password
+  https.server.trust.keystore.file= server/truststore.jks
+  https.server.trust.keystore.file.password= XXXXXX
   ```
 
-  Client-side routing-table configuration:
+  Client-side routing-table configuration; add the following under the corresponding partyid:
 
   ```
-
+  "default": [
+    {
+      "protocol": "http",
+      "url": "https://ip:8092/v1/interconn/chan/invoke",
+      "keyStoreFile": "absolute path of the private-key keystore",
+      "keyStorePassword": "private-key keystore password",
+      "trustStoreFile": "absolute path of the truststore",
+      "trustStorePassword": "truststore password",
+      "useSSL": true,
+      "useKeyStore" : true
+    }
+  ]
   ```
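Before moving on to option 2: the keystore pair configured above is consumed with ordinary JSSE calls. The sketch below is not OSX's actual implementation; the paths and the password "XXXXXX" are placeholders. It shows how a client might assemble an `SSLContext` from `identity.jks` and `truststore.jks`:

```java
// Minimal sketch: build an SSLContext from the identity.jks / truststore.jks
// pair generated in osx-tls.md. Paths and password are placeholders.
import java.io.FileInputStream;
import java.security.KeyStore;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;

public class KeystoreSslContextDemo {
    public static SSLContext build() throws Exception {
        char[] password = "XXXXXX".toCharArray();

        // identity.jks holds this side's private key and certificate chain.
        KeyStore identity = KeyStore.getInstance("JKS");
        try (FileInputStream in = new FileInputStream("client/identity.jks")) {
            identity.load(in, password);
        }
        KeyManagerFactory kmf =
                KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(identity, password);

        // truststore.jks holds the certificates this side is willing to trust.
        KeyStore trust = KeyStore.getInstance("JKS");
        try (FileInputStream in = new FileInputStream("client/truststore.jks")) {
            trust.load(in, password);
        }
        TrustManagerFactory tmf =
                TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(trust);

        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
        return sslContext;
    }
}
```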
-2) Option 2: store the private key, certificate, and trusted certificate as separate files
-   Generation commands:
-
+2) Option 2: store the private key, certificate, and trusted certificate as separate files (see [osx-tls.md](./osx-tls.md) for the certificate-generation commands)
 
-#### OSX configuration:
+OSX configuration:
 
 - grpcs:
 
-  broker.property configuration (non-keystore form, i.e. option 2):
+  Server-side broker.property configuration (non-keystore form, i.e. option 2):
 
   ```
-  # enable the grpcs server
+  # enable the grpcs server and set the grpcs port
+  grpc.tls.port=9371
   open.grpc.tls.server= true
-  # whether to use the keystore form
-  open.grpc.tls.use.keystore= false
-
-
-
+  # server certificate (public key)
+  grpc.server.cert.chain.file= server.crt
+  # server private key
+  grpc.server.private.key.file= server.pem
+  # trusted CA certificate
+  grpc.server.ca.file= ca.crt
   ```
 
-  Client-side routing-table configuration:
+  Client-side routing-table configuration; add the following under the corresponding partyid:
 
   ```
-
+  "default": [
+    {
+      "protocol": "grpc",
+      "certChainFile": "certificate (public key) path",
+      "privateKeyFile": "private key path",
+      "caFile": "trusted CA certificate path",
+      "useSSL": true,
+      "useKeyStore" : false,
+      "port": 9371,
+      "ip": "127.0.0.1"
+    }
+  ]
   ```
 
 - https:
 
-  broker.property configuration (non-keystore form, i.e. option 2):
+  Server-side broker.property configuration (non-keystore form, i.e. option 2):
 
   ```
-
+  # https port
+  https.port=8092
+  # enable the https server
+  open.https.server= true
+  # server certificate (public key)
+  https.server.cert.chain.file= server.crt
+  # server private key
+  https.server.private.key.file= server.pem
+  # trusted CA certificate
+  https.server.ca.file= ca.crt
   ```
 
-  Client-side routing-table configuration:
+  Client-side routing-table configuration; add the following under the corresponding partyid:
 
   ```
-
+  "default": [
+    {
+      "protocol": "http",
+      "url": "https://ip:8092/v1/interconn/chan/invoke",
+      "certChainFile": "certificate (public key) path",
+      "privateKeyFile": "private key path",
+      "caFile": "trusted CA certificate path",
+      "useSSL": true,
+      "useKeyStore" : false
+    }
+  ]
   ```
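Finally, a sketch of how a client could consume the option-2 files directly with grpc-netty-shaded (the grpc artifact pinned in the pom above): `GrpcSslContexts` accepts the PEM certificate, the PKCS#8 private key, and the CA file as-is. The host, port, and file paths are placeholders, and this is an illustration rather than OSX's own routing code.

```java
// Minimal sketch: mutual-TLS gRPC channel from the option-2 PEM files.
// Host, port, and paths are placeholders; note the PKCS#8 key is required.
import java.io.File;
import io.grpc.ManagedChannel;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;

public class MutualTlsChannelDemo {
    public static ManagedChannel build() throws Exception {
        SslContext sslContext = GrpcSslContexts.forClient()
                .trustManager(new File("ca.crt"))            // CA that signed the server certificate
                .keyManager(new File("client.crt"),
                            new File("client_pkcs8.key"))    // client certificate + PKCS#8 private key
                .build();
        return NettyChannelBuilder.forAddress("127.0.0.1", 9371) // grpc.tls.port from broker.property
                .sslContext(sslContext)
                .build();
    }
}
```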